# Repository URL to install this package: (not present in this copy)
# Version: 0.0.16
# File: torch-wrapper/lr_scheduler.py
class LRScheduler:
    """Base description of a learning-rate scheduler.

    Concrete subclasses override the class attribute ``name`` (the key that
    ``register_new`` stores them under), the ``cfg`` property and
    ``get_pytorch``. Neither member is implemented here: both raise
    ``NotImplementedError`` carrying the concrete class of the instance.
    """

    name = 'LRScheduler'

    @property
    def cfg(self) -> dict:
        """Constructor keyword arguments of this scheduler; subclasses must override."""
        concrete = type(self)
        raise NotImplementedError(concrete)

    def get_pytorch(self, optimizer):
        """Build the PyTorch scheduler for *optimizer*; subclasses must override."""
        concrete = type(self)
        raise NotImplementedError(concrete)
_memory = {}
def from_config(name: str, cfg: dict) -> LRScheduler:
    """Instantiate the scheduler class registered under *name*.

    :param name: registry key, i.e. the ``name`` attribute of a class that was
        passed to ``register_new``.
    :param cfg: keyword arguments forwarded unchanged to that class's
        constructor (for example the ``cfg`` of a previously built scheduler).
    :return: a new instance of the registered class.
    :raises KeyError: if no class is registered under *name*.
    """
    try:
        cls = _memory[name]
    except KeyError:
        known = ', '.join(sorted(repr(key) for key in _memory)) or 'none'
        raise KeyError(
            f'{name!r} is not a registered LRScheduler (registered: {known})'
        ) from None
    # Only the lookup is guarded: errors raised by the constructor itself
    # (e.g. an unexpected keyword in cfg) propagate unchanged.
    return cls(**cfg)
def register_new(cls: type):
    """Register *cls* in the scheduler registry under ``cls.name``.

    The class is returned unchanged, so this function can also be used as a
    class decorator (``@register_new``).

    :param cls: a subclass of ``LRScheduler`` whose ``name`` is not yet
        registered.
    :return: *cls* itself.
    :raises TypeError: if *cls* is not a class derived from ``LRScheduler``.
    :raises KeyError: if a class is already registered under ``cls.name``.
    """
    # Check for a class first: issubclass() would otherwise raise its own,
    # less helpful TypeError for non-class arguments.
    if not (isinstance(cls, type) and issubclass(cls, LRScheduler)):
        # LRScheduler.__class__.__name__ would be 'type' (the metaclass);
        # the base class's own __name__ is what the message should show.
        raise TypeError(f"{cls} is not a {LRScheduler.__name__}")
    name = cls.name
    if name in _memory:
        raise KeyError(f'{name} already registered')
    _memory[name] = cls
    return cls
class CosineAnnealingLR(LRScheduler):
    """Serialisable description of ``torch.optim.lr_scheduler.CosineAnnealingLR``.

    Stores the scheduler's hyper-parameters so the PyTorch scheduler can be
    built later with ``get_pytorch`` once an optimizer exists.

    :param T_max: forwarded to PyTorch as ``T_max``.
    :param eta_min: forwarded to PyTorch as ``eta_min``; defaults to 0.
    """

    name = 'CosineAnnealingLR'

    def __init__(self, T_max: float, eta_min: float = 0):
        super().__init__()
        self._cfg = {'T_max': T_max, 'eta_min': eta_min}

    @property
    def cfg(self) -> dict:
        """Constructor keyword arguments; ``CosineAnnealingLR(**s.cfg)`` rebuilds *s*.

        A shallow copy is returned so that callers mutating the result cannot
        change the arguments later used by ``get_pytorch``.
        """
        return dict(self._cfg)

    def get_pytorch(self, optimizer):
        """Build a ``torch.optim.lr_scheduler.CosineAnnealingLR`` for *optimizer*."""
        # Deferred import: torch is only imported here, when a scheduler is
        # actually built.
        from torch.optim.lr_scheduler import CosineAnnealingLR as _TorchCosine
        return _TorchCosine(optimizer=optimizer, **self._cfg)


register_new(CosineAnnealingLR)