from typing import Any, Dict, Iterable, List, no_type_check, Type
import torch
__all__: List[str] = []

# WeakTensorKeyDictionary to store relevant metadata for the Tensor/Parameter
# without changing its lifetime.
# NOTE: An alternative is to add the metadata as an attribute of the tensor,
# but that would serialize the metadata whenever the Tensor is serialized.
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
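
# Minimal sketch (not executed here) of the weak-mapping behavior relied on above:
# an entry vanishes once the tensor that keys it is garbage collected, so this
# bookkeeping never extends a parameter's lifetime.
#
#     d = torch.utils.weak.WeakTensorKeyDictionary()
#     t = torch.randn(2)
#     d[t] = "meta-data"
#     assert t in d
#     del t               # dropping the last reference to the tensor...
#     assert len(d) == 0  # ...also drops its entry from the dictionary
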
@no_type_check
def _apply_optimizer_in_backward(
optimizer_class: Type[torch.optim.Optimizer],
params: Iterable[torch.nn.Parameter],
optimizer_kwargs: Dict[str, Any],
) -> None:
"""
Upon ``backward()``, parameters will fire the corresponding optimizer.
Note - gradients for these parameters will be set to None after ``backward()``.
This means that any other (non applied) optimizer over this parameter will be
a no-op.
Args:
optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
params: (Iterator[nn.Parameter]): parameters to apply optimizer state to
optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor
Example::
params_generator = model.parameters()
param_1 = next(params_generator)
remainder_params = list(params_generator)
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})
model(...).sum().backward() # after backward, parameters will already
# have their registered optimizer applied.
"""
@no_type_check
def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
        # view_as creates a node in the autograd graph that gives us access to the
        # parameter's AccumulateGrad autograd function object. We register a hook
        # on this object to fire the optimizer when the gradient for this parameter
        # is ready (i.e. has been accumulated into the .grad field).
# Don't create a new acc_grad if we already have one
# i.e. for shared parameters or attaching multiple optimizers to a param.
if param not in param_to_acc_grad_map:
param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]
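            # Illustrative breakdown of the chain above (no additional behavior):
            #   param.view_as(param).grad_fn  -> backward node of the temporary view
            #   .next_functions[0][0]         -> param's AccumulateGrad node, the node
            #                                    autograd runs to write into param.grad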
optimizer = optimizer_class([param], **optimizer_kwargs)
if not hasattr(param, "_in_backward_optimizers"):
param._in_backward_optimizers = [] # type: ignore[attr-defined]
# TODO: investigate whether we really need these attributes.
param._optimizer_classes = [] # type: ignore[attr-defined]
param._optimizer_kwargs = [] # type: ignore[attr-defined]
param._in_backward_optimizers.append(optimizer) # type: ignore[attr-defined]
param._optimizer_classes.append(optimizer_class) # type: ignore[attr-defined]
param._optimizer_kwargs.append(optimizer_kwargs) # type: ignore[attr-defined]

        def optimizer_hook(*_unused) -> None:
for opt in param._in_backward_optimizers: # type: ignore[attr-defined]
opt.step()
param.grad = None

        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
if param not in param_to_optim_hook_handle_map:
param_to_optim_hook_handle_map[param] = []
param_to_optim_hook_handle_map[param].append(handle)

    for param in params:
        _apply_optimizer_in_backward_to_param(param)
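

# Purely illustrative usage sketch (not part of this module's API): apply SGD in
# backward to a toy linear model and check that each parameter's gradient was
# consumed by its hook after the optimizer stepped. The model and hyperparameters
# below are arbitrary placeholders.
if __name__ == "__main__":
    _model = torch.nn.Linear(4, 2)
    _apply_optimizer_in_backward(torch.optim.SGD, _model.parameters(), {"lr": 0.01})

    _model(torch.randn(8, 4)).sum().backward()

    # The per-parameter SGD steps already ran inside backward(), and each hook
    # reset .grad to None, so a later "outer" optimizer.step() would be a no-op
    # for these parameters.
    assert all(p.grad is None for p in _model.parameters())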