# torch-wrapper / pl / bug_fix.py
import torch
import pytorch_lightning as pl


def fix_pl_issue_12274(trainer: pl.Trainer, ckpt_path: str):
    """
    Summary:
    pytorch lightning's step number got reset to 0 after reloading from checkpoint.

    Related ref:

    1. issue page: https://github.com/Lightning-AI/lightning/issues/12274
    2. affected 3rd party logger: Neptune
       Source code of Neptune's pytorch_lightning logger, saying that `step`
       is not guaranteed to be strictly increasing from pytorch_lightning.
       `pytorch_lightning.loggers.neptune.NeptuneLogger.log_metrics`
    3. pytorch_lightning GitHub issue with global_step in other 3rd party loggers, might be related:
        https://github.com/Lightning-AI/lightning/issues/13163

    """
    if pl.__version__ != '1.6.4':
        raise RuntimeError(f'pytorch_lightning version={pl.__version__}, not supported.')
    # Load the checkpoint on CPU so no tensors are moved to GPU just to read metadata.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    global_step_offset = checkpoint["global_step"]
    # Seed the internal counter the loggers read, so logged steps continue
    # from the checkpoint's global_step instead of restarting at 0.
    trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
    del checkpoint  # free the checkpoint's memory
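

# For illustration, here is a minimal sketch of the same offset-restoring idea
# using plain-Python stand-ins, so the mechanism can be seen without loading
# pytorch_lightning. `DummyEpochLoop` is hypothetical and not part of any
# library; it only mimics the counter that the fix above overwrites.
class DummyEpochLoop:
    """Stand-in for pl's epoch loop; holds the counter loggers read as `step`."""

    def __init__(self):
        # After a restart this counter begins at 0, which is the bug:
        # logged steps restart from 0 instead of the checkpoint's step.
        self._batches_that_stepped = 0


def _sketch():
    checkpoint = {"global_step": 1234}  # shape of what torch.load() returns
    epoch_loop = DummyEpochLoop()
    # The fix: seed the counter with the step count stored in the checkpoint,
    # so subsequent logging continues from 1234 rather than 0.
    epoch_loop._batches_that_stepped = checkpoint["global_step"]
    return epoch_loop._batches_that_stepped  # -> 1234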