torch-wrapper / pl / callbacks / __init__.py

from pathlib import Path

from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint

__all__ = ['ModelCheckpoint', 'BestKCheckpoint']


class CheckpointWithLast:
    """A model checkpoint which know where the last iteration's checkpoint is stored."""

    @property
    def last_ckpt(self) -> Path:
        raise NotImplementedError(type(self))


class ModelCheckpoint(_ModelCheckpoint, CheckpointWithLast):
    """Checkpoint callback that saves a checkpoint every epoch and keeps all of them."""

    def __init__(self, dirpath):
        super().__init__(
            # save_top_k=-1: save all models
            # save_top_k=k: only save the best k models (or the last k from the previous epochs)
            every_n_epochs=1, save_top_k=-1, save_last=True,
            dirpath=str(dirpath),
            # "{epoch}" in the filename is automatically expanded to "epoch=<number>",
            # which is convenient but a bit confusing; the same holds for "{global_step}".
            # Also, when saving into a folder that already contains a checkpoint with the
            # same name, pytorch-lightning automatically appends -v1, -v2, etc.
            filename="ckpt_{epoch}",
        )

    @property
    def last_ckpt(self) -> Path:
        # ``save_last=True`` makes pytorch-lightning maintain ``last.ckpt`` in ``dirpath``
        return Path(self.dirpath) / 'last.ckpt'
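
# A minimal usage sketch for ModelCheckpoint, kept as a comment so nothing runs at
# import time. ``MyModule`` and the trainer settings are assumptions, and passing
# ``ckpt_path`` to ``trainer.fit`` requires pytorch-lightning >= 1.5:
#
#   import pytorch_lightning as pl
#
#   ckpt_cb = ModelCheckpoint(dirpath='checkpoints/run1')
#   trainer = pl.Trainer(max_epochs=10, callbacks=[ckpt_cb])
#   trainer.fit(MyModule())
#   # resume later from the most recent checkpoint:
#   trainer.fit(MyModule(), ckpt_path=str(ckpt_cb.last_ckpt))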


class BestKCheckpoint(_ModelCheckpoint, CheckpointWithLast):
    def __init__(self, dirpath, monitor_value: str, save_top_k: int = 1):
        """

        Parameters
        ----------
        dirpath
            where the checkpoint will be stored.
        monitor_value
            the name of the value to be monitored. Must be logged via "self.log" in ``pl.LightningModule``.
        save_top_k
            save the top k ckpt of the monitored value.
        """
        # see the comments in ModelCheckpoint.__init__ for the shared arguments
        super().__init__(
            every_n_epochs=1,
            save_top_k=save_top_k,
            save_last=True,
            monitor=monitor_value,
            dirpath=str(dirpath),
            # run the checkpointing logic at the end of the training epoch rather
            # than after validation
            save_on_train_epoch_end=True,
            # expands to e.g. "ckpt_epoch=3_val_loss=0.12.ckpt"
            filename='ckpt_{epoch}_{' + monitor_value + '}',
        )

    @property
    def last_ckpt(self) -> Path:
        # same location as in ModelCheckpoint: ``save_last=True`` maintains ``last.ckpt``
        return Path(self.dirpath) / 'last.ckpt'
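

if __name__ == '__main__':
    # Minimal self-contained sketch showing how BestKCheckpoint is wired up.
    # The toy module, random data, and hyperparameters below are assumptions,
    # not part of this package; the key passed to ``self.log`` must match
    # ``monitor_value``.
    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # logged per epoch so the value is available when checkpoints are ranked
            self.log('train_loss', loss, on_step=False, on_epoch=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    data = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)),
                      batch_size=16)
    best_cb = BestKCheckpoint('checkpoints/toy', monitor_value='train_loss',
                              save_top_k=2)
    trainer = pl.Trainer(max_epochs=3, callbacks=[best_cb], logger=False)
    trainer.fit(_ToyModule(), data)
    print('last checkpoint:', best_cb.last_ckpt)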