Repository URL to install this package:
|
Version:
0.0.9 ▾
|
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint
# Public API of this module: only the preconfigured checkpoint callback below.
__all__ = ["ModelCheckpoint"]
class ModelCheckpoint(_ModelCheckpoint):
    """Lightning ``ModelCheckpoint`` preconfigured to keep a checkpoint for every epoch.

    Defaults (each can be overridden by keyword):

    * ``every_n_epochs=1`` -- save once per epoch.
    * ``save_top_k=-1`` -- keep every saved checkpoint. A non-negative ``k``
      keeps only ``k`` of them; which ones depends on ``monitor`` and on the
      installed Lightning version (see the Lightning docs).
    * ``save_last=True`` -- additionally maintain a rolling "last" checkpoint
      (see :attr:`last_ckpt`).
    * ``filename="ckpt_{epoch}"`` -- by default Lightning expands ``{epoch}`` to
      ``epoch=<n>``, giving files such as ``ckpt_epoch=3.ckpt``. This is
      convenient but can be surprising; ``{step}`` behaves the same way. If a
      file with the same name already exists in ``dirpath``, Lightning appends
      ``-v1``, ``-v2``, ... to the new file's name.

    Args:
        dirpath: Directory the checkpoints are written to. Path-like values are
            converted to ``str``. ``None`` is forwarded unchanged so that
            Lightning can choose its default location, instead of creating a
            directory literally named ``"None"``.
        filename: Filename template forwarded to Lightning.
        every_n_epochs: Forwarded to Lightning.
        save_top_k: Forwarded to Lightning.
        save_last: Forwarded to Lightning.
        **kwargs: Any other ``ModelCheckpoint`` option (e.g. ``monitor``),
            forwarded unchanged.
    """

    def __init__(
        self,
        dirpath,
        *,
        filename="ckpt_{epoch}",
        every_n_epochs=1,
        save_top_k=-1,
        save_last=True,
        **kwargs,
    ):
        super().__init__(
            # Lightning expects a string path; leave None alone so the base
            # class resolves its own default directory.
            dirpath=None if dirpath is None else str(dirpath),
            filename=filename,
            every_n_epochs=every_n_epochs,
            save_top_k=save_top_k,
            save_last=save_last,
            **kwargs,
        )

    @property
    def last_ckpt(self):
        """Return the expected path of the rolling "last" checkpoint.

        The name is built from the base class's ``CHECKPOINT_NAME_LAST`` and
        ``FILE_EXTENSION`` (``last`` and ``.ckpt`` unless a subclass overrides
        them). This is where the file is *expected*; whether it exists is not
        checked. Lightning may write a versioned name instead (e.g.
        ``last-v1.ckpt``) when a stale file is already present. After a save,
        ``self.last_model_path`` holds the actual path.

        Raises:
            RuntimeError: If ``dirpath`` has not been resolved yet, which only
                happens when ``None`` was passed and no trainer has set it up.
        """
        if self.dirpath is None:
            raise RuntimeError(
                "ModelCheckpoint.dirpath is not set yet; "
                "it is resolved when the trainer sets up this callback."
            )
        return Path(self.dirpath) / f"{self.CHECKPOINT_NAME_LAST}{self.FILE_EXTENSION}"