import warnings
from .base_scheduler import BaseScheduler
__all__ = ["LambdaSL"]
class LambdaSL(BaseScheduler):
    """Sets the sparsity level of each parameter group to the final sl
    times a given function. When last_epoch=-1, sets initial sl as zero.

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        sl_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list (or tuple) of
            such functions, one for each group in sparsifier.groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``sl_lambda`` is a list or tuple whose length differs
            from ``len(sparsifier.groups)``.

    Example:
        >>> # Assuming sparsifier has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> # xdoctest: +SKIP
        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
        self.sparsifier = sparsifier

        num_groups = len(sparsifier.groups)
        if not isinstance(sl_lambda, (list, tuple)):
            # A single callable is shared by every group.
            self.sl_lambdas = [sl_lambda] * num_groups
        elif len(sl_lambda) != num_groups:
            raise ValueError("Expected {} sl_lambdas, but got {}".format(
                num_groups, len(sl_lambda)))
        else:
            # Copy so later mutation of the caller's sequence has no effect here.
            self.sl_lambdas = list(sl_lambda)

        # NOTE(review): sl_lambdas is assigned before the base initializer runs,
        # presumably because BaseScheduler.__init__ may already call get_sl() --
        # confirm against BaseScheduler before reordering.
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        """Compute the sparsity level of every group for the current epoch.

        Emits a warning when ``self._get_sl_called_within_step`` is falsy,
        pointing the caller to ``get_last_sl()`` instead.

        Returns:
            list: One value per group, ``base_sl * sl_lambda(self.last_epoch)``,
            pairing ``self.sl_lambdas`` with ``self.base_sl`` positionally.
        """
        # `_get_sl_called_within_step` and `base_sl` are not set in this class;
        # presumably they are maintained by BaseScheduler -- verify there.
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.")
        return [base_sl * lmbda(self.last_epoch)
                for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)]