import torch
import pytorch_lightning as pl
def fix_pl_issue_12274(trainer: pl.Trainer, ckpt_path: str):
"""
Summary:
pytorch lightning's step number got reset to 0 after reloading from checkpoint.
Related ref:
1. issue page: https://github.com/Lightning-AI/lightning/issues/12274
2. affected 3rd party logger: Neptune
Source code of Neptune's pytorch_lightning logger, saying that `step`
is not guaranteed to be strictly increasing from pytorch_lightning.
`pytorch_lightning.loggers.neptune.NeptuneLogger.log_metrics`
3. pytorch_lightning GitHub issue with global_step in other 3rd party loggers, might be related:
https://github.com/Lightning-AI/lightning/issues/13163
"""
    # This workaround touches loop internals of pytorch_lightning 1.6.4;
    # other versions lay out the fit loop differently, so bail out early.
    if pl.__version__ != '1.6.4':
        raise RuntimeError(f'pytorch_lightning version={pl.__version__}, not supported.')
    # Read the saved global step from the checkpoint without moving tensors to GPU.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    global_step_offset = checkpoint["global_step"]
    # `_batches_that_stepped` feeds the `step` value handed to loggers; seeding
    # it with the checkpoint's global step keeps logged steps monotonic.
    trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
    # Free our copy of the checkpoint; Trainer.fit reloads it via ckpt_path.
    del checkpoint
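

# Minimal usage sketch (not part of the fix): the `model`, `train_loader`, and
# default checkpoint path below are hypothetical placeholders, not names from
# this package. The only requirement is the call order: apply the fix after
# constructing the Trainer and before `trainer.fit`, passing the same
# checkpoint path to both.
def resume_training_example(model: pl.LightningModule,
                            train_loader,
                            ckpt_path: str = 'checkpoints/last.ckpt') -> None:
    trainer = pl.Trainer(max_epochs=10)
    fix_pl_issue_12274(trainer, ckpt_path)  # seed the step counter first
    trainer.fit(model, train_loader, ckpt_path=ckpt_path)  # then resume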