torch-wrapper / lr_scheduler.py
class LRScheduler:
    """Base class for serializable learning-rate scheduler configurations."""

    name = 'LRScheduler'

    @property
    def cfg(self) -> dict:
        """Keyword arguments needed to reconstruct this scheduler."""
        raise NotImplementedError(type(self))

    def get_pytorch(self, optimizer):
        """Build the matching torch.optim.lr_scheduler instance for `optimizer`."""
        raise NotImplementedError(type(self))


# Registry mapping scheduler names to their LRScheduler subclasses.
_memory = {}


def from_config(name: str, cfg: dict) -> LRScheduler:
    """Instantiate the registered scheduler `name` from its config dict."""
    return _memory[name](**cfg)


def register_new(cls: type):
    """Register an LRScheduler subclass under its `name` attribute."""
    if not issubclass(cls, LRScheduler):
        raise TypeError(f"{cls} is not a {LRScheduler.__name__}")
    name = cls.name
    if name in _memory:
        raise KeyError(f'{name} already registered')
    _memory[name] = cls


class CosineAnnealingLR(LRScheduler):
    name = 'CosineAnnealingLR'

    def __init__(self, T_max: float, eta_min: float = 0):
        super().__init__()
        self._cfg = {'T_max': T_max, 'eta_min': eta_min}

    @property
    def cfg(self) -> dict:
        return self._cfg

    def get_pytorch(self, optimizer):
        # Import lazily so this module can be imported without torch installed.
        from torch.optim.lr_scheduler import CosineAnnealingLR as C
        return C(optimizer=optimizer, **self._cfg)


register_new(CosineAnnealingLR)
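

# Minimal usage sketch: round-trip a scheduler through its config dict, then
# drive the resulting PyTorch scheduler. Assumes torch is installed; the
# Linear model and SGD optimizer below are illustrative stand-ins, not part
# of this module.
if __name__ == '__main__':
    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Serialize a scheduler to (name, cfg) and rebuild it from the registry.
    scheduler = CosineAnnealingLR(T_max=100, eta_min=1e-5)
    rebuilt = from_config(scheduler.name, scheduler.cfg)
    pytorch_scheduler = rebuilt.get_pytorch(optimizer)

    for _ in range(100):
        optimizer.step()          # update weights
        pytorch_scheduler.step()  # anneal the learning rate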