# Package: torch-wrapper (version 0.0.26)
# File: torch-wrapper/lr_scheduler.py
# (Captured from a package-repository page; the repository install URL was not included.)
class LRScheduler:
    """Base class for learning-rate scheduler descriptions.

    A subclass stores its constructor arguments as a plain dict (exposed via
    :attr:`cfg`) and knows how to build the matching torch scheduler object
    (:meth:`get_pytorch`). Instances can be rebuilt from ``(name, cfg)``
    through the module-level registry (``register_new`` / ``from_config``).
    """

    # Registry key for this class; subclasses override it.
    name = 'LRScheduler'

    @property
    def cfg(self) -> dict:
        """Constructor arguments of this scheduler; must be overridden."""
        raise NotImplementedError(type(self))

    def get_pytorch(self, optimizer):
        """Build the torch scheduler bound to *optimizer*; must be overridden."""
        raise NotImplementedError(type(self))

    def get_pl(self, optimizer):
        """Return an ``lr_scheduler_config`` dict for PyTorch Lightning.

        The layout follows ``pytorch_lightning.LightningModule.configure_optimizers``:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

        The scheduler built by :meth:`get_pytorch` is stepped once per epoch.
        """
        return dict(
            scheduler=self.get_pytorch(optimizer),
            # 'epoch' steps the scheduler at the end of each epoch; 'step'
            # would step it after every optimizer update instead.
            interval='epoch',
            # Step on every interval rather than every N-th one.
            frequency=1,
        )
# Registry mapping a scheduler's ``name`` class attribute to its class;
# populated by ``register_new``.
_memory: dict = {}


def from_config(name: str, cfg: dict) -> 'LRScheduler':
    """Instantiate the scheduler class registered under *name* with ``**cfg``.

    Args:
        name: Registry key, i.e. the ``name`` attribute of a registered class.
        cfg: Keyword arguments forwarded to that class's constructor.

    Returns:
        A new instance of the registered class.

    Raises:
        KeyError: If *name* is not registered. The message lists the
            registered names so typos are easy to spot.
    """
    try:
        scheduler_cls = _memory[name]
    except KeyError:
        # Only the lookup is guarded, so a KeyError raised by the constructor
        # itself is not misreported as an unknown scheduler name.
        raise KeyError(
            f'{name!r} is not a registered LRScheduler; '
            f'registered: {list(_memory)}'
        ) from None
    return scheduler_cls(**cfg)
def register_new(cls: type) -> type:
    """Register *cls* in the scheduler registry under ``cls.name``.

    Because *cls* is returned unchanged, this can also be used as a class
    decorator.

    Args:
        cls: A subclass of ``LRScheduler``.

    Returns:
        *cls* itself.

    Raises:
        TypeError: If *cls* is not a subclass of ``LRScheduler``.
        KeyError: If ``cls.name`` is already registered.
    """
    if not (isinstance(cls, type) and issubclass(cls, LRScheduler)):
        # Use the base class's own __name__; ``LRScheduler.__class__`` is the
        # metaclass ``type``, which would make the message say "not a type".
        raise TypeError(f"{cls} is not a {LRScheduler.__name__}")
    name = cls.name
    if name in _memory:
        raise KeyError(f'{name} already registered')
    _memory[name] = cls
    return cls
class CosineAnnealingLR(LRScheduler):
    """Config wrapper around ``torch.optim.lr_scheduler.CosineAnnealingLR``.

    Args:
        T_max: Forwarded unchanged as ``T_max`` to the torch scheduler.
        eta_min: Forwarded unchanged as ``eta_min``; defaults to ``0``.
    """

    name = 'CosineAnnealingLR'

    def __init__(self, T_max: float, eta_min: float = 0):
        super().__init__()
        self._cfg = dict(T_max=T_max, eta_min=eta_min)

    @property
    def cfg(self) -> dict:
        """The ``T_max`` / ``eta_min`` values passed to the constructor."""
        return self._cfg

    def get_pytorch(self, optimizer):
        """Build torch's ``CosineAnnealingLR`` for *optimizer* from :attr:`cfg`."""
        # Imported lazily so this module can be imported without torch.
        from torch.optim import lr_scheduler as torch_sched
        return torch_sched.CosineAnnealingLR(optimizer=optimizer, **self._cfg)


register_new(CosineAnnealingLR)
class _ExpDecayLambda:
def __init__(self, decay):
self.decay = decay
def __call__(self, step):
return 1 / (1. + self.decay * step)
class TFLikeExpDecay(LRScheduler):
    """TF/Keras-style decay: ``lr = base_lr / (1 + decay * step)``.

    Despite the class name, the factor is inverse-time, not exponential.

    Args:
        decay: Decay coefficient; larger values shrink the learning rate faster.
        interval: Lightning ``interval`` used by :meth:`get_pl`: ``'step'``
            (after every optimizer update, the default) or ``'epoch'``.
    """

    name = 'TFLikeExpDecay'

    def __init__(self, decay: float, interval='step'):
        super().__init__()
        self._cfg = {'decay': decay, 'interval': interval}

    @property
    def cfg(self) -> dict:
        """The ``decay`` / ``interval`` values passed to the constructor."""
        return self._cfg

    def get_pytorch(self, optimizer):
        """Build a torch scheduler giving ``base_lr / (1 + decay * step)``.

        ``LambdaLR`` multiplies each param group's *initial* lr by the factor.
        ``MultiplicativeLR`` would instead multiply the *current* lr by it on
        every step. That compounds into ``base_lr / prod_k (1 + decay * k)``,
        which collapses toward zero far faster than the TF/Keras schedule.
        """
        from torch.optim.lr_scheduler import LambdaLR
        return LambdaLR(optimizer, lr_lambda=_ExpDecayLambda(self._cfg['decay']))

    def get_pl(self, optimizer):
        """Return a Lightning ``lr_scheduler_config`` for this scheduler.

        Same layout as the base class, but stepped at the configured
        ``interval`` instead of always per epoch.
        """
        lr_scheduler_config = super().get_pl(optimizer)
        lr_scheduler_config['interval'] = self._cfg['interval']
        return lr_scheduler_config


register_new(TFLikeExpDecay)