torch-wrapper / lr_scheduler.py
class LRScheduler:
    name = 'LRScheduler'

    @property
    def cfg(self) -> dict:
        """Keyword arguments that reconstruct this wrapper via ``from_config``."""
        raise NotImplementedError(type(self))

    def get_pytorch(self, optimizer):
        """Build and return the underlying ``torch.optim.lr_scheduler`` object."""
        raise NotImplementedError(type(self))

    def get_pl(self, optimizer):
        """returns a lr_scheduler_config according to
        ``pytorch_lightning.LightningModule.configure_optimizers``

        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        sch = self.get_pytorch(optimizer)
        lr_scheduler_config = {
            "scheduler": sch,
            # The unit of the scheduler's step size; could also be 'step'.
            # 'epoch' updates the scheduler on epoch end, whereas 'step'
            # updates it after an optimizer update.
            "interval": 'epoch',
            # 1 corresponds to updating the learning rate after every epoch/step.
            "frequency": 1,
        }
        return lr_scheduler_config
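

# Usage sketch (assumption: a ``pytorch_lightning``-based setup; the ``MyModule``
# class and ``self.lr_scheduler`` attribute below are illustrative, not part of
# this module):
#
#     class MyModule(pl.LightningModule):
#         def configure_optimizers(self):
#             optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
#             # ``self.lr_scheduler`` is an ``LRScheduler`` instance
#             return {
#                 "optimizer": optimizer,
#                 "lr_scheduler": self.lr_scheduler.get_pl(optimizer),
#             }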


# Registry mapping ``LRScheduler.name`` -> scheduler class, populated by ``register_new``.
_memory = {}


def from_config(name: str, cfg: dict) -> LRScheduler:
    """Instantiate the registered ``LRScheduler`` subclass ``name`` with ``cfg`` as kwargs."""
    return _memory[name](**cfg)


def register_new(cls: type):
    """Register an ``LRScheduler`` subclass under its ``name`` for use with ``from_config``."""
    if issubclass(cls, LRScheduler):
        name = cls.name
        if name in _memory:
            raise KeyError(f'{name} already registered')
        _memory[name] = cls
    else:
        raise TypeError(f"{cls} is not a subclass of {LRScheduler.__name__}")


class CosineAnnealingLR(LRScheduler):
    name = 'CosineAnnealingLR'

    def __init__(self, T_max: float, eta_min: float = 0):
        super().__init__()
        self._cfg = {'T_max': T_max, 'eta_min': eta_min}

    @property
    def cfg(self) -> dict:
        return self._cfg

    def get_pytorch(self, optimizer):
        from torch.optim.lr_scheduler import CosineAnnealingLR as C
        return C(optimizer=optimizer, **self._cfg)


register_new(CosineAnnealingLR)
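

# Example (sketch, illustrative values): the registry lets a scheduler be
# rebuilt from its name and its ``cfg`` dict, e.g. after round-tripping
# through a YAML/JSON config file:
#
#     sched = CosineAnnealingLR(T_max=100, eta_min=1e-6)
#     clone = from_config(CosineAnnealingLR.name, sched.cfg)
#     assert clone.cfg == sched.cfg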


class _ExpDecayLambda:
    """Callable computing the multiplicative factor ``1 / (1 + decay * step)``."""

    def __init__(self, decay):
        self.decay = decay

    def __call__(self, step):
        return 1 / (1. + self.decay * step)


class TFLikeExpDecay(LRScheduler):
    name = 'TFLikeExpDecay'

    def __init__(self, decay: float, interval: str = 'step'):
        super().__init__()
        self._cfg = {'decay': decay, 'interval': interval}

    @property
    def cfg(self) -> dict:
        return self._cfg

    def get_pytorch(self, optimizer):
        # noinspection PyUnresolvedReferences
        from torch.optim.lr_scheduler import MultiplicativeLR
        # MultiplicativeLR multiplies the current lr by ``lr_lambda(step)`` on
        # every scheduler step, so the decay factor is applied cumulatively.
        return MultiplicativeLR(optimizer, lr_lambda=_ExpDecayLambda(self._cfg['decay']))

    def get_pl(self, optimizer):
        """returns a lr_scheduler_config according to
        ``pytorch_lightning.LightningModule.configure_optimizers``

        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        sch = self.get_pytorch(optimizer)
        lr_scheduler_config = {
            "scheduler": sch,
            # The unit of the scheduler's step size, taken from the config:
            # 'epoch' updates the scheduler on epoch end, whereas 'step'
            # updates it after an optimizer update.
            "interval": self._cfg['interval'],
            # 1 corresponds to updating the learning rate after every epoch/step.
            "frequency": 1,
        }
        return lr_scheduler_config


register_new(TFLikeExpDecay)
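

if __name__ == '__main__':
    # Minimal smoke test (assumes torch is installed; values are illustrative).
    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    wrapper = from_config('TFLikeExpDecay', {'decay': 0.01})
    sch = wrapper.get_pytorch(opt)
    for step in range(3):
        opt.step()
        sch.step()
        # Each scheduler step multiplies the current lr by 1 / (1 + decay * step).
        print(step, opt.param_groups[0]['lr'])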