import inspect
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch._decomp
from torch import Tensor
decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten
# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType,
# IF any inputs have forward grad
# AND there is no forward AD formula implemented
# AND the function is actually differentiable
# run the decomposition
# See run_jit_decomposition_with_args_for_jvp
# We currently use python decompositions that we torchscript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So an open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have an explicit whitelist controlling when
# to enable the decomposition.
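#
# As a rough illustration only (nothing here is exercised at import time, and the
# actual dispatch happens on the C++ side), the intended effect is that forward-mode
# AD through an op without a dedicated forward AD formula falls back to the scripted
# decomposition registered at the bottom of this file:
#
#   import torch
#   import torch.autograd.forward_ad as fwAD
#
#   with fwAD.dual_level():
#       x = fwAD.make_dual(torch.randn(3, 3), torch.randn(3, 3))
#       y = torch.trace(x)  # expected to run via the aten.trace decomposition below
#       tangent = fwAD.unpack_dual(y).tangent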
def maybe_register_decomposition(op):
    def decorator(f):
        try:
            return register_decomposition(op)(f)
        except Exception:
            # If registration fails for any reason, fall back to the undecorated function.
            return f

    return decorator

# Functions where we need a special decomposition for jvp but there's another version that
# should be used more generally (e.g., for jvp we need to recompute the mean and variance for
# the backward of a normalization function; without jvp, it should use the saved values)
def register_decomposition_for_jvp(fn):
    return register_decomposition(fn, registry=decomposition_table_for_jvp)

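# Usage sketch (illustrative only; the concrete registrations appear further down in
# this file): decorating a decomposition with register_decomposition_for_jvp, e.g.
#
#   @register_decomposition_for_jvp(aten.native_layer_norm_backward)
#   def native_layer_norm_backward(...):
#       ...
#
# puts it into decomposition_table_for_jvp rather than the general decomposition_table,
# so it is only consulted by _register_jit_decomposition_for_jvp below.
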
def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]

    if use_python:
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)

        # Create a string wrapping the function from the signature
        # example output:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #   return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]

            return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"

        f_str = get_function_def(sig)
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)

# The only decompositions here are temporary or hacks for the purposes of jvp
# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))

@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer

def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    # For most norm decompositions this is the same as the core version, except that
    # here we recompute the mean and variance so that they track gradients through input
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    # rstd was computed as 1 / sqrt(var + eps), so recover eps as 1 / rstd**2 - var
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd

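# A small illustrative check (an added example, not part of the original set of
# decompositions; it is defined purely for documentation and is never called): the
# recomputed rstd matches the value autograd would have saved, but now requires grad
# through `input`, which is the whole point of recompute_mean_var above.
def _recompute_mean_var_example() -> None:
    x = torch.randn(4, 8, requires_grad=True)
    var = torch.var(x, dim=1, unbiased=False, keepdim=True)
    saved_rstd = (1 / torch.sqrt(var + 1e-5)).detach()  # what would have been saved
    mean, rstd = recompute_mean_var(x, saved_rstd, [1], keepdim=True)
    assert torch.allclose(rstd, saved_rstd)
    assert mean.requires_grad and rstd.requires_grad
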
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)

def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r

@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)

_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)