import functools
import torch
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight
from torch.utils._pytree import tree_flatten
# The dependency on `functional_call` means that this can't be exposed in utils
# without creating a circular dependency.
def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum", batch_first=True):
    r"""
    call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum", batch_first=True)

    ``call_for_per_sample_grads`` returns a function that is invoked like the forward
    function of ``module`` and will produce the same result. Then, when backward is invoked,
    the parameters of ``module`` will have a ``grad_sample`` field populated with the per sample
    gradients instead of the regular gradients.

    Args:
        module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
            parameters will compute per sample gradients, located in a ``grad_sample``
            field when ``backward`` is invoked
        batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have
            the same batch size, which is the size of the batch dimension (see ``batch_first``). Otherwise, it must
            be passed manually. This argument is keyword-only. Default: None
        loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation.
            If "mean", per sample gradients will be scaled by the batch size to offset the cross-batch interaction
            from running mean across a batch. Must be "mean" or "sum". Default: "sum"
        batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the
            first dimension. If False, it's the second dimension. Default: True.

    Returns:
        A callable, wrapped with ``functools.wraps(module.forward)``, that runs ``module`` through
        ``torch.func.functional_call`` with every ``requires_grad`` parameter replaced by an ``ExpandedWeight``.

    Raises:
        RuntimeError: if ``loss_reduction``, ``module`` or ``batch_size`` is invalid, or if any parameter of
            ``module`` still holds a non-None ``grad_sample`` from an earlier run. The returned callable also
            raises ``RuntimeError`` when ``batch_size`` is None and the batch size cannot be inferred from its
            tensor arguments.

    Examples::
        >>> # xdoctest: +SKIP
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model)(batched_input).sum()
        >>> res.backward()
        >>> assert model.weight.shape == (3, 4)
        >>> assert model.weight.grad_sample.shape == (5, 3, 4)
        >>> assert model.weight.grad is None
        >>> assert model.bias.shape == (3,)
        >>> assert model.bias.grad_sample.shape == (5, 3)
        >>> assert model.bias.grad is None

        An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they
        would be if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale
        all grad_outputs by 1 / batch_size from cross batch interaction.

        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model, batch_size=5, loss_reduction="mean")(batched_input).mean()
        >>> res.backward()

    Note::
        Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
        rewrites that wrap an `nn.Linear` module. See Opacus for an example
    """

    def maybe_build_expanded_weight(og_tensor, batch_size):
        # Only tensors that require grad are wrapped in an ExpandedWeight; frozen parameters are
        # passed through to functional_call unchanged.
        if og_tensor.requires_grad:
            return ExpandedWeight(og_tensor, batch_size, loss_reduction)
        return og_tensor

    def compute_batch_size(*args, **kwargs):
        # Infer the batch size from every tensor leaf of the (pytree-flattened) inputs. All tensor
        # leaves must agree on the size of their batch dimension.
        args_and_kwargs = tree_flatten(args)[0] + tree_flatten(kwargs)[0]
        batch_dim = 0 if batch_first else 1
        inferred = None
        for arg in args_and_kwargs:
            if not isinstance(arg, torch.Tensor):
                continue
            # A tensor with too few dims has no batch dimension to read; report it clearly instead
            # of letting `arg.shape[batch_dim]` fail with a bare IndexError.
            if arg.dim() <= batch_dim:
                raise RuntimeError(f"When computing batch size, found a tensor input with {arg.dim()} dimension(s), "
                                   f"which has no batch dimension at dim {batch_dim}. Please specify it "
                                   "explicitly using the batch size kwarg in call_for_per_sample_grads")
            arg_batch_size = arg.shape[batch_dim]
            if inferred is not None and inferred != arg_batch_size:
                raise RuntimeError("When computing batch size, found at least one input with batch size "
                                   f"{inferred} and one with batch size {arg_batch_size}. Please specify it "
                                   "explicitly using the batch size kwarg in call_for_per_sample_grads")
            inferred = arg_batch_size
        if inferred is None:
            raise RuntimeError("Unable to find a tensor in the passed args and kwargs. They may not be pytree-able "
                               "and so ExpandedWeights cannot compute the batch size from the inputs. Please specify "
                               "it explicitly")
        return inferred

    if loss_reduction not in ["sum", "mean"]:
        raise RuntimeError(f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}")
    if not isinstance(module, torch.nn.Module):
        raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
    if not (batch_size is None or isinstance(batch_size, int)):
        raise RuntimeError(f"Batch size passed must be None or an integer, got {type(batch_size).__name__}")
    if batch_size is not None and batch_size < 1:
        raise RuntimeError(f"Batch size must be positive, got {batch_size}")

    # NOTE: this stale-grad_sample check runs once, when the wrapper is built, not on every call
    # of the returned wrapper.
    for name, weight in module.named_parameters():
        if getattr(weight, "grad_sample", None) is not None:
            # Report the parameter by name; interpolating the tensor itself would dump its contents.
            raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
                               f"calls without clearing gradients. Please clear out the grad_sample parameter of "
                               f"'{name}' or post an issue to pytorch/pytorch to prioritize correct behavior")

    @functools.wraps(module.forward)
    def wrapper(*args, **kwargs):
        wrapper_batch_size = batch_size
        if wrapper_batch_size is None:
            wrapper_batch_size = compute_batch_size(*args, **kwargs)

        params = {name: maybe_build_expanded_weight(value, wrapper_batch_size)
                  for (name, value) in module.named_parameters()}
        return torch.func.functional_call(module, params, args, kwargs)

    return wrapper