# Package version: 0.0.24
import os
import re
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint
__all__ = ['SaveEveryEpoch', 'SaveBestK']
class CheckpointWithLast:
    """Mixin interface for checkpoint callbacks that can locate a resume point.

    Subclasses implement :meth:`get_recover_ckpt` to report where the most
    recent (or otherwise appropriate) checkpoint lives on disk.
    """

    def get_recover_ckpt(self) -> Path:
        """Return a ckpt path from which training can be continued.

        Raises
        -------
        FileNotFoundError
            if the ckpt file does not exist.
        """
        raise NotImplementedError(type(self))
class SaveEveryEpoch(_ModelCheckpoint, CheckpointWithLast):
    """Checkpoint callback that writes a checkpoint after every epoch and keeps all of them."""

    def __init__(
        self, dirpath,
    ):
        # save_top_k=-1: saves all models
        # save_top_k=k: only save the best k model (or the last k from the prev epoches)
        #
        # "{epoch}" inside the filename is expanded by pytorch-lightning to
        # "epoch=<actual number>" (convenient but a bit confusing; same for
        # global_step).  When a file with the same name already exists in the
        # folder, pytorch-lightning appends -v1, -v2, ... automatically.
        super().__init__(
            every_n_epochs=1,
            save_top_k=-1,
            save_last=True,
            dirpath=str(dirpath),
            filename="ckpt_{epoch}",
        )

    def get_recover_ckpt(self) -> Path:
        """Return the path to ``last.ckpt`` inside ``self.dirpath``.

        Raises
        -------
        FileNotFoundError
            if ``last.ckpt`` has not been written yet.
        """
        candidate = Path(self.dirpath) / 'last.ckpt'
        if not candidate.exists():
            raise FileNotFoundError('Recovery ckpt not found')
        return candidate
class SaveBestK(_ModelCheckpoint, CheckpointWithLast):
    """Checkpoint callback that keeps the best ``k`` checkpoints of a monitored value."""

    def __init__(
        self, dirpath, monitor_value: str, save_top_k=1, save_last=True,
        recover_from='last'
    ):
        """
        Parameters
        ----------
        dirpath
            where the checkpoint will be stored.
        monitor_value
            the name of the value to be monitored. Must be logged via "self.log" in ``pl.LightningModule``.
        save_top_k
            save the top k ckpt of the monitored value.
        save_last
            If true, will save the last epoch (regardless it's the best or not). The last epoch's model is used for
            ``self.recover_ckpt``.
            If false, will only save best k models. The recover method is determined from `recover_from`.
        recover_from
            'last': recover from the last epoch
            (if the last epoch is not saved, the recover from the last of best k epochs.)
            'best': recover from the best epoch.

        Raises
        ------
        ValueError
            if ``recover_from`` is neither ``'last'`` nor ``'best'``.
        """
        # Validate before running the (side-effecting) parent __init__ so a bad
        # argument cannot leave a partially initialized callback behind.
        if recover_from not in ('last', 'best'):
            raise ValueError(recover_from)
        # see comments in SaveEveryEpoch for what filename templating does
        super().__init__(
            every_n_epochs=1,
            save_top_k=save_top_k,
            save_last=save_last,
            monitor=monitor_value,
            dirpath=str(dirpath),
            save_on_train_epoch_end=True,
            filename='ckpt_{epoch}_{' + monitor_value + '}'
        )
        self._save_last = save_last
        self._monitor_value = monitor_value
        self._recover_from = recover_from

    @staticmethod
    def _list_best_helper(files, monitor_value):
        """Parse checkpoint filenames produced by this callback.

        Returns a list of ``{'epoch': int, 'value': float, 'f': str}`` dicts,
        one per filename in *files* that matches the expected pattern; other
        names are silently skipped.
        """
        # Value pattern accepts an optional sign and exponent so negative or
        # very small monitored values (formatted e.g. as "-0.5" or "1e-05")
        # are still recognized.  Anchored with "$" so stray suffixes after
        # ".ckpt" do not match.
        pattern = re.compile(
            r'ckpt_epoch=(\d+)_' + re.escape(monitor_value)
            + r'=(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)\.ckpt$'
        )
        ret = []
        for f in files:
            m = pattern.match(f)
            if m:
                epoch, value = m.groups()
                ret.append({'epoch': int(epoch), 'value': float(value), 'f': f})
        return ret

    @staticmethod
    def _list_best_helper_last_epoch(best_epochs):
        """Return the filename of the entry with the highest epoch number."""
        return max(best_epochs, key=lambda x: x['epoch'])['f']

    @staticmethod
    def _list_best_helper_best_val(best_epochs):
        """Return the filename of the entry with the lowest monitored value.

        NOTE(review): assumes lower-is-better, which matches ModelCheckpoint's
        default ``mode='min'`` — confirm if a max-mode monitor is ever used.
        """
        return min(best_epochs, key=lambda x: x['value'])['f']

    def _list_best(self):
        """List the parsed best-k checkpoints currently present in ``self.dirpath``."""
        return self._list_best_helper(os.listdir(Path(self.dirpath)), self._monitor_value)

    def get_recover_ckpt(self) -> Path:
        """Return the checkpoint path to resume from, per ``recover_from``.

        Raises
        ------
        FileNotFoundError
            if no suitable checkpoint file exists in ``self.dirpath``.
        """
        if self._recover_from == 'last':
            if self._save_last:
                ret = Path(self.dirpath) / 'last.ckpt'
                if ret.exists():
                    return ret
                raise FileNotFoundError('Recovery file not found')
            # last.ckpt is not being written; fall back to the most recent of
            # the best-k checkpoints.
            best_epochs = self._list_best()
            if len(best_epochs) == 0:
                raise FileNotFoundError('Recovery file not found')
            return Path(self.dirpath) / self._list_best_helper_last_epoch(best_epochs)
        elif self._recover_from == 'best':
            best_epochs = self._list_best()
            if len(best_epochs) == 0:
                raise FileNotFoundError('Recovery file not found')
            return Path(self.dirpath) / self._list_best_helper_best_val(best_epochs)
        else:
            # unreachable: __init__ validates recover_from
            raise NotImplementedError(self._recover_from)