Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
torch-wrapper / pl / callbacks / __init__.py
Size: Mime:
from pathlib import Path

from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint

__all__ = ['ModelCheckpoint']


class ModelCheckpoint(_ModelCheckpoint):
    """Lightning ``ModelCheckpoint`` preset that keeps every epoch plus a ``last`` checkpoint.

    Defaults forwarded to the Lightning base class (each can be overridden via
    ``**kwargs``):

    * ``every_n_epochs=1`` -- checkpoint once per epoch.
    * ``save_top_k=-1`` -- keep all checkpoints. A positive ``k`` keeps only the
      best ``k``; NOTE(review): Lightning ranks them by ``monitor``, so ``k > 1``
      presumably needs ``monitor`` set -- confirm for your Lightning version.
    * ``save_last=True`` -- additionally write a ``last`` checkpoint.
    * ``filename="ckpt_{epoch}"`` -- Lightning expands ``{epoch}`` to
      ``epoch=<n>``, producing names like ``ckpt_epoch=3.ckpt``. When a file of
      the same name already exists in ``dirpath``, Lightning appends ``-v1``,
      ``-v2``, ... to the new file.

    Args:
        dirpath: Directory to write checkpoints to (``str`` or path-like; it is
            forwarded as a plain ``str``). ``None`` is forwarded unchanged instead
            of becoming a literal ``"None"`` directory.
        **kwargs: Extra keyword arguments for the Lightning base class; they take
            precedence over the defaults above
            (e.g. ``save_top_k=3, monitor='val_loss'``).
    """

    def __init__(self, dirpath, **kwargs):
        options = {
            'every_n_epochs': 1,
            'save_top_k': -1,
            'save_last': True,
            'filename': 'ckpt_{epoch}',
        }
        # Caller-supplied options win over the presets.
        options.update(kwargs)
        super().__init__(
            # str(None) would be "None" -> a real folder named "None"; keep None as None.
            dirpath=None if dirpath is None else str(dirpath),
            **options,
        )

    @property
    def last_ckpt(self):
        """Return the canonical path of the ``last`` checkpoint inside ``dirpath``.

        The file name is built from the base-class ``CHECKPOINT_NAME_LAST`` and
        ``FILE_EXTENSION`` attributes rather than a hard-coded ``'last.ckpt'``, so
        it follows any override of those attributes.

        NOTE(review): if Lightning versioned the file (e.g. ``last-v1.ckpt``), this
        canonical path will not match it; ``self.last_model_path`` should hold the
        path actually written during the current run -- confirm before relying on it.

        Raises:
            RuntimeError: if ``dirpath`` is ``None`` (no directory configured yet).
        """
        if self.dirpath is None:
            raise RuntimeError(
                'ModelCheckpoint.dirpath is not set; pass a dirpath or attach the '
                'callback to a Trainer before asking for last_ckpt'
            )
        return Path(self.dirpath) / f'{self.CHECKPOINT_NAME_LAST}{self.FILE_EXTENSION}'