import torch
from torch.types import _TensorOrTensors
import torch.testing
from torch.overrides import is_tensor_like
import collections
from itertools import product
import warnings
from typing import Callable, Union, Optional, Iterable, List, Tuple, Dict
from torch._vmap_internals import vmap, _vmap
import functools
# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
# since they have been exposed since before we added `__all__` and we already maintain BC for them.
# We should eventually deprecate them and remove them from `__all__`.
__all__ = ["gradcheck", "gradgradcheck", "GradcheckError", "get_numerical_jacobian",
"get_analytical_jacobian", "get_numerical_jacobian_wrt_specific_input"]
class GradcheckError(RuntimeError):
r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`"""
pass
def _is_sparse_compressed_tensor(obj: torch.Tensor):
return obj.layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
def _is_sparse_any_tensor(obj: torch.Tensor):
return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo
def _is_float_or_complex_tensor(obj):
return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex())
def _allocate_jacobians_with_inputs(input_tensors: Tuple, numel_output) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from inputs. For each float or complex tensor `t`
    # in `input_tensors` that requires grad, returns a new zero-filled tensor with
    # `t.numel()` rows and `numel_output` columns. Each new tensor is strided and
    # has the same dtype and device as the corresponding input.
out: List[torch.Tensor] = []
for t in input_tensors:
if _is_float_or_complex_tensor(t) and t.requires_grad:
out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
return tuple(out)
def _allocate_jacobians_with_outputs(output_tensors: Tuple, numel_input, dtype=None,
device=None) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from outputs. For each float or complex tensor `t`
    # in `output_tensors`, returns a new zero-filled tensor with `numel_input` rows
    # and `t.numel()` columns, using the given `dtype` and `device` (defaulting to
    # those of the output) and a strided layout.
out: List[torch.Tensor] = []
options = {"dtype": dtype, "device": device, "layout": torch.strided}
for t in output_tensors:
if _is_float_or_complex_tensor(t):
out.append(t.new_zeros((numel_input, t.numel()), **options))
return tuple(out)
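# A minimal shape sketch for the two allocators above (illustrative only; the
# example tensors are hypothetical): for an input with 6 elements and an output
# with 4 elements, both helpers allocate a zero-filled (6, 4) block, i.e. rows
# index input elements and columns index output elements.
#
#   inp = torch.rand(2, 3, dtype=torch.double, requires_grad=True)
#   out = torch.rand(4, dtype=torch.double)
#   _allocate_jacobians_with_inputs((inp,), out.numel())[0].shape   # -> (6, 4)
#   _allocate_jacobians_with_outputs((out,), inp.numel())[0].shape  # -> (6, 4)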
def _iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]],
only_requiring_grad: bool = False) -> Iterable[torch.Tensor]:
if is_tensor_like(x):
# mypy doesn't narrow type of `x` to torch.Tensor
if x.requires_grad or not only_requiring_grad: # type: ignore[union-attr]
yield x # type: ignore[misc]
elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
for elem in x:
for result in _iter_tensors(elem, only_requiring_grad):
yield result
def _iter_tensor(x_tensor):
# (Only used for slow gradcheck) Returns a generator that yields the following
# elements at each iteration:
# 1) a tensor: the same tensor is returned across all iterations. The tensor
    # is not the same object as the original x_tensor given as input - it is
    # prepared so that it can be modified in-place. Depending on whether the
    # input tensor is strided, sparse, or mkldnn, the returned tensor may or may
# not share storage with x_tensor.
# 2) a tuple of indices that can be used with advanced indexing (yielded in
# dictionary order)
# 3) flattened index that will be used to index into the Jacobian tensor
#
# For a tensor t with size (2, 2), _iter_tensor yields:
# `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3`
#
# where x is the t.data of the original tensor. Perturbing the entry of x
    # at index (1, 1) yields the column at index 3 (the last column) of the overall Jacobian matrix.
if _is_sparse_any_tensor(x_tensor):
def get_stride(size):
dim = len(size)
tmp = 1
stride = [0] * dim
for i in reversed(range(dim)):
stride[i] = tmp
tmp *= size[i]
return stride
x_nnz = x_tensor._nnz()
x_size = list(x_tensor.size())
if x_tensor.layout is torch.sparse_coo:
x_indices = x_tensor._indices().t()
x_values = x_tensor._values()
elif x_tensor.layout is torch.sparse_csr:
x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.crow_indices(), x_tensor.col_indices()).t()
x_values = x_tensor.values()
elif x_tensor.layout is torch.sparse_csc:
x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True).t()
x_values = x_tensor.values()
elif x_tensor.layout is torch.sparse_bsr:
x_block_values = x_tensor.values()
x_blocksize = x_block_values.size()[1:3]
x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.crow_indices(), x_tensor.col_indices()) \
.repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) \
.mul_(torch.tensor(x_blocksize).reshape(2, 1)) \
.add_(torch.stack(torch.where(torch.ones(x_blocksize))).repeat(1, x_nnz)).t()
x_values = x_block_values.flatten(0, 2)
x_nnz = x_values.size(0)
elif x_tensor.layout is torch.sparse_bsc:
x_block_values = x_tensor.values()
x_blocksize = x_block_values.size()[1:3]
x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True) \
.repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) \
.mul_(torch.tensor(x_blocksize).reshape(2, 1)) \
.add_(torch.stack(torch.where(torch.ones(x_blocksize))).repeat(1, x_nnz)).t()
x_values = x_block_values.flatten(0, 2)
x_nnz = x_values.size(0)
else:
raise NotImplementedError(f'_iter_tensor for {x_tensor.layout} input')
x_stride = get_stride(x_size)
# Use .data here to get around the version check
x_values = x_values.data
for i in range(x_nnz):
x_value = x_values[i]
for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
indices = x_indices[i].tolist() + list(x_idx)
d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
yield x_value, x_idx, d_idx
elif x_tensor.layout == torch._mkldnn: # type: ignore[attr-defined]
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
# this is really inefficient, but without indexing implemented, there's
# not really a better way than converting back and forth
x_tensor_dense = x_tensor.to_dense()
yield x_tensor_dense, x_idx, d_idx
else:
# Use .data here to get around the version check
x_tensor = x_tensor.data
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
yield x_tensor, x_idx, d_idx
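# A minimal sketch of how `_iter_tensor` is typically consumed (illustrative
# only; the tensor below is hypothetical): each yielded triple identifies one
# scalar entry to perturb and the flat index of the Jacobian column it fills.
#
#   t = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
#   for x, idx, d_idx in _iter_tensor(t):
#       entry = x[idx]     # scalar entry that slow gradcheck perturbs in-place
#       column = d_idx     # 0, 1, 2, 3 in dictionary order of `idx`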
def _get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3,
is_forward_ad=False) -> List[Tuple[torch.Tensor, ...]]:
"""Computes the numerical Jacobian of `fn(inputs)` with respect to `target`. If
not specified, targets are the input. Returns M * N Jacobians where N is the
number of tensors in target that require grad and M is the number of non-integral
outputs.
Args:
fn: the function to compute the jacobian for
inputs: inputs to `fn`
outputs: provide precomputed outputs to avoid one extra invocation of fn
target: the Tensors wrt whom Jacobians are calculated (default=`inputs`)
eps: the magnitude of the perturbation during finite differencing
(default=`1e-3`)
is_forward_ad: if this numerical jacobian is computed to be checked wrt
forward AD gradients (this is used for error checking only)
Returns:
A list of M N-tuples of tensors
Note that `target` may not even be part of `input` to `fn`, so please be
**very careful** in this to not clone `target`.
"""
jacobians: List[Tuple[torch.Tensor, ...]] = []
if outputs is None:
outputs = _as_tuple(fn(*_as_tuple(inputs)))
if not is_forward_ad and any(o.is_complex() for o in outputs):
raise ValueError("Expected output to be non-complex. get_numerical_jacobian no "
"longer supports functions that return complex outputs.")
if target is None:
target = inputs
inp_indices = [i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad]
for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
jacobians += [get_numerical_jacobian_wrt_specific_input(fn, inp_idx, inputs, outputs, eps,
input=inp, is_forward_ad=is_forward_ad)]
return jacobians
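# A minimal sketch of the structure returned above (illustrative only; `fn` and
# the inputs below are hypothetical):
#
#   def fn(x, y):
#       return (x * y).sum(), x + y
#   x = torch.rand(3, dtype=torch.double, requires_grad=True)
#   y = torch.rand(3, dtype=torch.double, requires_grad=True)
#   jacobians = _get_numerical_jacobian(fn, (x, y))
#   # jacobians[i][j] is the Jacobian of output j wrt input i, stored with shape
#   # (inputs[i].numel(), outputs[j].numel()); here jacobians[0][0].shape == (3, 1)
#   # and jacobians[1][1].shape == (3, 3).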
def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
"""Deprecated API to compute the numerical Jacobian for a given fn and its inputs.
Args:
fn: the function to compute the Jacobian for (must take inputs as a tuple)
input: input to `fn`
target: the Tensors wrt whom Jacobians are calculated (default=`input`)
eps: the magnitude of the perturbation during finite differencing
(default=`1e-3`)
Returns:
A list of Jacobians of `fn` (restricted to its first output) with respect to
each input or target, if provided.
Note that `target` may not even be part of `input` to `fn`, so please be
**very careful** in this to not clone `target`.
"""
warnings.warn("get_numerical_jacobian was part of PyTorch's private API and not "
"meant to be exposed. We are deprecating it and it will be removed "
"in a future version of PyTorch. If you have a specific use for "
"this or feature request for this to be a stable API, please file "
"us an issue at https://github.com/pytorch/pytorch/issues/new")
if grad_out != 1.0: # grad_out param is only kept for backward compatibility reasons
raise ValueError("Expected grad_out to be 1.0. get_numerical_jacobian no longer "
"supports values of grad_out != 1.0.")
def fn_pack_inps(*inps):
return fn(inps)
jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps)
return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians)
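# A minimal usage sketch for the deprecated entry point above (illustrative only;
# `fn` and the inputs below are hypothetical). Note that `fn` receives its inputs
# packed into a single tuple:
#
#   def fn(inputs):
#       x, y = inputs
#       return x * y
#   x = torch.rand(3, dtype=torch.double, requires_grad=True)
#   y = torch.rand(3, dtype=torch.double, requires_grad=True)
#   jacobians = get_numerical_jacobian(fn, (x, y))   # emits the deprecation warning
#   # jacobians[i] is the (3, 3) Jacobian of the first output wrt input i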
def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
# Performs finite differencing by perturbing `entry` in-place by `v` and
    # returns the numerical gradient of each output with respect to the perturbed entry.
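    # Concretely, with f denoting one output of `fn` and x the value of `entry`,
    # this evaluates the central difference
    #     (f(x + v) - f(x - v)) / (2 * norm_v)
    # which approximates the derivative of f along the perturbation direction v.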
orig = entry.clone()
entry.copy_(orig - v)
outa = fn()
entry.copy_(orig + v)
outb = fn()
entry.copy_(orig)
def compute(a, b):
nbhd_checks_fn(a, b)
ret = (b - a) / (2 * norm_v)
return ret.detach().reshape(-1)
return tuple(compute(a, b) for (a, b) in zip(outa, outb))
def _compute_numerical_jvps_wrt_specific_input(jvp_fn, delta, input_is_complex,
is_forward_ad=False) -> List[torch.Tensor]:
# Computing the jacobian only works for real delta
    # For details on the algorithm used here, refer to:
    # Section 3.5.3 of https://arxiv.org/pdf/1701.00392.pdf
# s = fn(z) where z = x for real valued input
# and z = x + yj for complex valued input
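    # Concretely, for a complex input z = x + yj with a real-valued output s, the two
    # jvp_fn calls below approximate ds/dx and ds/dy, which are then combined as
    #     conj_w_d = ds/dx + 1j * ds/dy
    # (referred to below as the conjugate Wirtinger derivative; see the reference
    # above for the exact convention).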
jvps: List[torch.Tensor] = []
ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)
if input_is_complex: # C -> R
ds_dy_tup = jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
            assert not ds_dx.is_complex()
# conjugate wirtinger derivative
conj_w_d = ds_dx + ds_dy * 1j
jvps.append(conj_w_d)
else:
for ds_dx in ds_dx_tup: # R -> R or (R -> C for the forward AD case)
            assert is_forward_ad or not ds_dx.is_complex()
jvps.append(ds_dx)
return jvps
def _combine_jacobian_cols(jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input,
numel) -> Tuple[torch.Tensor, ...]:
    # `jacobians_cols` maps column_idx -> output_idx -> a single column of the Jacobian.
    # We return a tuple that maps output_idx -> the full Jacobian tensor for that output,
    # with shape (input.numel(), output.numel()).
jacobians = _allocate_jacobians_with_outputs(outputs, numel, dtype=input.dtype if input.dtype.is_complex else None)
for i, jacobian in enumerate(jacobians):
for k, v in jacobians_cols.items():
jacobian[k] = v[i]
return jacobians
def _prepare_input(input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor],
fast_mode=False) -> torch.Tensor:
# Prepares the inputs to be passed into the function while including the new
# modified input.
if input.layout == torch._mkldnn: # type: ignore[attr-defined] # no attr _mkldnn
# Convert back to mkldnn
if maybe_perturbed_input is not None:
return maybe_perturbed_input.to_mkldnn()
else:
return input
elif _is_sparse_any_tensor(input):
if fast_mode and maybe_perturbed_input is not None:
# entry is already a "cloned" version of the original tensor
# thus changes to entry are not reflected in the input
return maybe_perturbed_input
else:
return input
else:
# We cannot use entry (input.data) if we want gradgrad to work because
# fn (in the gradgrad case) needs to compute grad wrt input
return input
def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None:
# Check that the returned outputs don't have different dtype or shape when you
# perturb the input
    on_index = f"on index {idx} " if idx is not None else ""
assert output1.shape == output2.shape, \
(f"Expected `func` to return outputs with the same shape"
f" when inputs are perturbed {on_index}by {eps}, but got:"
f" shapes {output1.shape} and {output2.shape}.")
assert output1.dtype == output2.dtype, \
(f"Expected `func` to return outputs with the same dtype"
f" when inputs are perturbed {on_index}by {eps}, but got:"
f" dtypes {output1.dtype} and {output2.dtype}.")
def get_numerical_jacobian_wrt_specific_input(fn, input_idx, inputs, outputs, eps,
input=None, is_forward_ad=False) -> Tuple[torch.Tensor, ...]:
    # Computes the numerical Jacobians with respect to a single input. Returns N
    # Jacobian tensors, where N is the number of outputs. We use a dictionary for
    # `jacobian_cols` because indices aren't necessarily consecutive for sparse inputs.
    # When we perturb only a single element of the input tensor at a time, the jvp
    # is equivalent to a single column of the Jacobian matrix of fn.
jacobian_cols: Dict[int, List[torch.Tensor]] = {}
input = inputs[input_idx] if input is None else input
assert input.requires_grad
for x, idx, d_idx in _iter_tensor(input):
wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
input_to_perturb = x[idx]
nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, idx=idx, eps=eps)
jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(jvp_fn, eps, x.is_complex(), is_forward_ad)
return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel())
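# A brief illustration of why `jacobian_cols` above is a dict rather than a list
# (hypothetical values, for illustration only): for a sparse COO input of size (4,)
# with stored values at flat indices 1 and 3, `_iter_tensor` only visits those two
# entries, so the populated keys are {1, 3}; `_combine_jacobian_cols` then scatters
# the computed columns into the full dense Jacobian block, leaving the rest zero.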
def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtypes=False,
all_u=None) -> Tuple[Tuple[torch.Tensor, ...], ...]:
"""Computes the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect
to `target`. Returns N * M Jacobians where N is the number of tensors in target that require grad and
M is the number of non-integral outputs.
Contrary to other functions here, this function requires "inputs" to actually be used by the function.
The computed value is expected to be wrong if the function captures the inputs by side effect instead of
using the passed ones (many torch.nn tests do this).
Args:
fn: the function to compute the jacobian for
inputs: inputs to `fn`
outputs: provide precomputed outputs to avoid one extra invocation of fn