torch-wrapper / pl / callbacks / __init__.py
import os
import re
from pathlib import Path

from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint

__all__ = ['SaveEveryEpoch', 'SaveBestK']


class CheckpointWithLast:
    """A model checkpoint which know where the last iteration's checkpoint is stored."""

    def get_recover_ckpt(self) -> Path:
        """Returns a ckpt where the training can be continued from a previous iteration.

        Raises
        -------
        FileNotFoundError
            if the ckpt file does not exist.
        """
        raise NotImplementedError(type(self))
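
# A sketch of the intended contract (illustrative, not part of this module): a
# training script can treat any CheckpointWithLast uniformly when deciding whether
# to resume; `resolve_resume_path` is a hypothetical helper, not defined here:
#
#     def resolve_resume_path(cb: CheckpointWithLast):
#         try:
#             return str(cb.get_recover_ckpt())
#         except FileNotFoundError:
#             return None  # fresh run, nothing to resume from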


class SaveEveryEpoch(_ModelCheckpoint, CheckpointWithLast):
    def __init__(self, dirpath):
        super().__init__(
            # save_top_k=-1 saves a checkpoint for every epoch;
            # save_top_k=k would keep only the best k models (or the last k epochs
            # when no metric is monitored).
            every_n_epochs=1, save_top_k=-1, save_last=True,
            dirpath=str(dirpath),
            # "{epoch}" in the template is automatically expanded to "epoch=" plus the
            # actual epoch number, which is convenient but a bit confusing; "{global_step}"
            # behaves the same way. Also, when a checkpoint with the same name already
            # exists in the folder, pytorch-lightning appends -v1, -v2, etc. to the new file.
            filename="ckpt_{epoch}"
        )

    def get_recover_ckpt(self) -> Path:
        # 'last.ckpt' is the default name pytorch-lightning uses when save_last=True
        ret = Path(self.dirpath) / 'last.ckpt'
        if ret.exists():
            return ret
        raise FileNotFoundError('Recovery ckpt not found')
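
# Usage sketch (illustrative only): SaveEveryEpoch plugs into a pytorch_lightning
# Trainer like any ModelCheckpoint. `MyModule` and `my_data` are hypothetical
# stand-ins, and `resolve_resume_path` is the sketch above; note that older
# pytorch-lightning versions take the resume path as
# Trainer(resume_from_checkpoint=...) instead of Trainer.fit(..., ckpt_path=...):
#
#     ckpt_cb = SaveEveryEpoch(dirpath='checkpoints/run1')
#     trainer = pl.Trainer(callbacks=[ckpt_cb], max_epochs=10)
#     trainer.fit(MyModule(), my_data, ckpt_path=resolve_resume_path(ckpt_cb))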


class SaveBestK(_ModelCheckpoint, CheckpointWithLast):
    def __init__(
        self, dirpath, monitor_value: str, save_top_k=1, save_last=True,
        recover_from='last'
    ):
        """

        Parameters
        ----------
        dirpath
            where the checkpoint will be stored.
        monitor_value
            the name of the value to be monitored. Must be logged via "self.log" in ``pl.LightningModule``.
        save_top_k
            save the top k ckpt of the monitored value.
        save_last
            If true, will save the last epoch (regardless it's the best or not). The last epoch's model is used for
            ``self.recover_ckpt``.
            If false, will only save best k models. The recover method is determined from `recover_from`.
        recover_from
            'last': recover from the last epoch
            (if the last epoch is not saved, the recover from the last of best k epochs.)
            'best': recover from the best epoch.
        """
        # see the comments in SaveEveryEpoch.__init__ for an explanation of these arguments
        super().__init__(
            every_n_epochs=1,
            save_top_k=save_top_k,
            save_last=save_last,
            monitor=monitor_value,
            dirpath=str(dirpath),
            save_on_train_epoch_end=True,
            filename='ckpt_{epoch}_{' + monitor_value + '}'
        )
        self._save_last = save_last
        if recover_from not in ('last', 'best'):
            raise ValueError(recover_from)
        self._monitor_value = monitor_value
        self._recover_from = recover_from

    @staticmethod
    def _list_best_helper(files, monitor_value):
        # Match filenames produced by the template above, e.g.
        # "ckpt_epoch=3_val_loss=0.412.ckpt". Versioned duplicates ("...-v1.ckpt")
        # written by pytorch-lightning are not matched by this pattern.
        pattern = re.compile(
            r'ckpt_epoch=(\d+)_' + re.escape(monitor_value) + r'=(-?\d+(?:\.\d+)?)\.ckpt'
        )
        ret = []
        for f in files:
            m = pattern.match(f)
            if m:
                epoch, value = m.groups()
                ret.append({'epoch': int(epoch), 'value': float(value), 'f': f})
        return ret
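
    # A quick illustration of what _list_best_helper extracts (a doctest-style
    # sketch with made-up filenames and values):
    #
    #     >>> files = ['ckpt_epoch=3_val_loss=0.412.ckpt', 'last.ckpt', 'events.out']
    #     >>> SaveBestK._list_best_helper(files, 'val_loss')
    #     [{'epoch': 3, 'value': 0.412, 'f': 'ckpt_epoch=3_val_loss=0.412.ckpt'}]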

    @staticmethod
    def _list_best_helper_last_epoch(best_epochs):
        # the checkpoint from the most recent epoch
        return sorted(best_epochs, key=lambda x: x['epoch'])[-1]['f']

    @staticmethod
    def _list_best_helper_best_val(best_epochs):
        # the checkpoint with the smallest monitored value, matching ModelCheckpoint's
        # default mode='min' (this class does not expose the `mode` argument)
        return sorted(best_epochs, key=lambda x: x['value'])[0]['f']

    def _list_best(self):
        return self._list_best_helper(os.listdir(Path(self.dirpath)), self._monitor_value)

    def get_recover_ckpt(self) -> Path:
        if self._recover_from == 'last':
            if self._save_last:
                ret = Path(self.dirpath) / 'last.ckpt'
                if ret.exists():
                    return ret
                raise FileNotFoundError('Recovery file not found')
            best_epochs = self._list_best()
            if len(best_epochs) == 0:
                raise FileNotFoundError('Recovery file not found')
            return Path(self.dirpath) / self._list_best_helper_last_epoch(best_epochs)
        elif self._recover_from == 'best':
            best_epochs = self._list_best()
            if len(best_epochs) == 0:
                raise FileNotFoundError('Recovery file not found')
            return Path(self.dirpath) / self._list_best_helper_best_val(best_epochs)
        else:
            raise NotImplementedError(self._recover_from)
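
# Usage sketch (illustrative only): the monitored value must be logged from the
# LightningModule under the same name passed as `monitor_value`; `MyModule`,
# `_compute_loss`, and the 'val_loss' metric are hypothetical stand-ins:
#
#     class MyModule(pl.LightningModule):
#         def validation_step(self, batch, batch_idx):
#             loss = self._compute_loss(batch)
#             self.log('val_loss', loss)  # exposes 'val_loss' to SaveBestK
#
#     ckpt_cb = SaveBestK('checkpoints/run1', monitor_value='val_loss',
#                         save_top_k=3, recover_from='best')
#     trainer = pl.Trainer(callbacks=[ckpt_cb], max_epochs=10)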