Repository URL to install this package:
|
Version:
2.5.0 ▾
|
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path
from shutil import copyfile
import torch
from tqdm import tqdm
from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
# Module-level logger used for the status and error messages emitted by this script.
_log = logging.getLogger(__name__)
def _upgrade(args: Namespace) -> None:
    """Upgrade checkpoint files in place, keeping a ``.bak`` copy of each original.

    ``args.path`` may point to a single checkpoint file or to a directory. A directory is searched
    recursively for files whose name ends in ``args.extension``. Every selected file is first copied
    to a sibling backup file, then loaded, passed through ``migrate_checkpoint`` and saved back over
    the original file.

    Args:
        args: Parsed command line arguments. The attributes read are ``path`` (str), ``extension``
            (str, with or without the leading dot) and ``map_to_cpu`` (bool).

    Raises:
        SystemExit: With exit code 1 if ``args.path`` does not exist or if no checkpoint file was found.

    """
    path = Path(args.path).absolute()
    # Accept both "ckpt" and ".ckpt" on the command line.
    extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}"

    if not path.exists():
        _log.error(
            f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory"
            f" containing checkpoints ending in {extension}."
        )
        # `SystemExit` is what the `exit()` helper raises, but it does not depend on the `site` module.
        raise SystemExit(1)

    files: list[Path] = []
    if path.is_file():
        files = [path]
    elif path.is_dir():
        # Escape the directory part so glob metacharacters in its name ("[", "]", "*", "?") are matched
        # literally; only the trailing "**/*<extension>" part is meant to act as a pattern.
        pattern = str(Path(glob.escape(str(path))) / "**" / f"*{extension}")
        # Keep regular files only: a directory whose name ends in the extension would make `copyfile` fail.
        files = [match for match in map(Path, glob.glob(pattern, recursive=True)) if match.is_file()]

    if not files:
        _log.error(
            f"No checkpoint files with extension {extension} were found in {path}."
            " HINT: Try setting the `--extension` option to specify the right file extension to look for."
        )
        raise SystemExit(1)

    _log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.")
    for file in files:
        # `with_suffix` replaces the last suffix, e.g. "model.ckpt" -> "model.bak".
        backup_file = file.with_suffix(".bak")
        if backup_file.exists():
            # never overwrite backup files - they are the original, untouched checkpoints
            continue
        copyfile(file, backup_file)

    _log.info("Upgrading checkpoints ...")
    map_location = torch.device("cpu") if args.map_to_cpu else None
    for file in tqdm(files):
        # NOTE(review): `torch.load` relies on pickle - only run this on checkpoints from a trusted source.
        # Presumably `pl_legacy_patch` provides what is needed to unpickle old checkpoints - confirm in
        # `pytorch_lightning.utilities.migration`.
        with pl_legacy_patch():
            checkpoint = torch.load(file, map_location=map_location)
        migrate_checkpoint(checkpoint)
        torch.save(checkpoint, file)

    _log.info("Done.")
def main() -> None:
    """Command line entry point: parse the arguments and hand them to ``_upgrade``."""
    cli = ArgumentParser(
        description="A utility to upgrade old checkpoints to the format of the current Lightning version."
        " This will also save a backup of the original files."
    )

    # Positional argument: what to upgrade.
    cli.add_argument("path", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade")

    # Optional switches.
    cli.add_argument(
        "--extension",
        "-e",
        type=str,
        default=".ckpt",
        help="The file extension to look for when searching for checkpoint files in a directory.",
    )
    cli.add_argument(
        "--map-to-cpu",
        action="store_true",
        help="Map all tensors in the checkpoint to CPU. Enable this option if you are converting a GPU checkpoint"
        " on a machine without GPUs.",
    )

    _upgrade(cli.parse_args())
# Run the command line interface only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()