import torch
from typing import Tuple, List
from . import forward_ad as fwAD
from torch._vmap_internals import _vmap
__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
# Utility functions
def _as_tuple_nocheck(x):
if isinstance(x, tuple):
return x
elif isinstance(x, list):
return tuple(x)
else:
return x,
def _as_tuple(inp, arg_name=None, fn_name=None):
# Ensures that inp is a tuple of Tensors
# Returns whether or not the original inp was a tuple and the tupled version of the input
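    # For illustration (t stands for an arbitrary Tensor):
    #   _as_tuple(t, "inputs", "vjp")      -> (False, (t,))
    #   _as_tuple((t, t), "inputs", "vjp") -> (True, (t, t))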
if arg_name is None and fn_name is None:
return _as_tuple_nocheck(inp)
is_inp_tuple = True
if not isinstance(inp, tuple):
inp = (inp,)
is_inp_tuple = False
for i, el in enumerate(inp):
if not isinstance(el, torch.Tensor):
if is_inp_tuple:
raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the"
" value at index {} has type {}.".format(arg_name, fn_name, i, type(el)))
else:
raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the"
" given {} has type {}.".format(arg_name, fn_name, arg_name, type(el)))
return is_inp_tuple, inp
def _tuple_postprocess(res, to_unpack):
# Unpacks a potentially nested tuple of Tensors
# to_unpack should be a single boolean or a tuple of two booleans.
# It is used to:
# - invert _as_tuple when res should match the inp given to _as_tuple
# - optionally remove nesting of two tuples created by multiple calls to _as_tuple
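    # For illustration (t, t1, t2 stand for arbitrary Tensors):
    #   _tuple_postprocess((t,), False)                 -> t
    #   _tuple_postprocess(((t,),), (False, False))     -> t
    #   _tuple_postprocess(((t1, t2),), (False, True))  -> (t1, t2)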
if isinstance(to_unpack, tuple):
assert len(to_unpack) == 2
if not to_unpack[1]:
res = tuple(el[0] for el in res)
if not to_unpack[0]:
res = res[0]
else:
if not to_unpack:
res = res[0]
return res
def _grad_preprocess(inputs, create_graph, need_graph):
# Preprocess the inputs to make sure they require gradient
# inputs is a tuple of Tensors to preprocess
# create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
# need_graph specifies if we internally want gradients to flow back to the Tensors in res
# Note that we *always* create a new Tensor object to be able to see the difference between
# inputs given as arguments and the same Tensors automatically captured by the user function.
# Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
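    # For illustration, with a dense Tensor x:
    #   create_graph=True and x.requires_grad -> x.view_as(x) (new object, keeps graph history)
    #   any other case                        -> x.detach().requires_grad_(need_graph)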
res = []
for inp in inputs:
if create_graph and inp.requires_grad:
# Create at least a new Tensor object in a differentiable way
if not inp.is_sparse:
# Use .view_as() to get a shallow copy
res.append(inp.view_as(inp))
else:
# We cannot use view for sparse Tensors so we clone
res.append(inp.clone())
else:
res.append(inp.detach().requires_grad_(need_graph))
return tuple(res)
def _grad_postprocess(inputs, create_graph):
# Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
# request it.
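    # For illustration: a flat tuple of Tensors has each element detached unless
    # create_graph=True; a nested tuple of tuples (e.g. a jacobian) is handled recursively.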
if isinstance(inputs[0], torch.Tensor):
if not create_graph:
return tuple(inp.detach() for inp in inputs)
else:
return inputs
else:
return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
def _validate_v(v, other, is_other_tuple):
# This assumes that other is the correct shape, and v should match
# Both are assumed to be tuples of Tensors
if len(other) != len(v):
if is_other_tuple:
raise RuntimeError("v is a tuple of invalid length: should be {} but got {}.".format(len(other), len(v)))
else:
raise RuntimeError("The given v should contain a single Tensor.")
for idx, (el_v, el_other) in enumerate(zip(v, other)):
if el_v.size() != el_other.size():
prepend = ""
if is_other_tuple:
prepend = "Entry {} in ".format(idx)
raise RuntimeError("{}v has invalid size: should be {} but got {}.".format(
prepend, el_other.size(), el_v.size()))
def _check_requires_grad(inputs, input_type, strict):
# Used to make all the necessary checks to raise nice errors in strict mode.
if not strict:
return
if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
raise RuntimeError("Invalid input_type to _check_requires_grad")
for i, inp in enumerate(inputs):
if inp is None:
# This can only be reached for grad_inputs.
raise RuntimeError("The output of the user-provided function is independent of input {}."
" This is not allowed in strict mode.".format(i))
if not inp.requires_grad:
if input_type == "hessian":
raise RuntimeError("The hessian of the user-provided function with respect to input {}"
" is independent of the input. This is not allowed in strict mode."
" You should ensure that your function is thrice differentiable and that"
" the hessian depends on the inputs.".format(i))
elif input_type == "jacobian":
raise RuntimeError("While computing the hessian, found that the jacobian of the user-provided"
" function with respect to input {} is independent of the input. This is not"
" allowed in strict mode. You should ensure that your function is twice"
" differentiable and that the jacobian depends on the inputs (this would be"
" violated by a linear function for example).".format(i))
elif input_type == "grad_inputs":
raise RuntimeError("The gradient with respect to input {} is independent of the inputs of the"
" user-provided function. This is not allowed in strict mode.".format(i))
else:
raise RuntimeError("Output {} of the user-provided function does not require gradients."
" The outputs must be computed in a differentiable manner from the input"
" when running in strict mode.".format(i))
def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None, is_grads_batched=False):
    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs must be a tuple.
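    # For illustration, with a differentiable output y, input x and cotangent v:
    #   _autograd_grad((y, None), (x,), (v, None)) only forwards (y,) and (v,) to
    #   torch.autograd.grad; if no output is differentiable, (None,) * len(inputs)
    #   is returned without calling the autograd engine.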
assert isinstance(outputs, tuple)
if grad_outputs is None:
grad_outputs = (None,) * len(outputs)
assert isinstance(grad_outputs, tuple)
assert len(outputs) == len(grad_outputs)
new_outputs: Tuple[torch.Tensor, ...] = tuple()
new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
for out, grad_out in zip(outputs, grad_outputs):
if out is not None and out.requires_grad:
new_outputs += (out,)
new_grad_outputs += (grad_out,)
if len(new_outputs) == 0:
# No differentiable output, we don't need to call the autograd engine
return (None,) * len(inputs)
else:
return torch.autograd.grad(new_outputs, inputs, new_grad_outputs, allow_unused=True,
create_graph=create_graph, retain_graph=retain_graph,
is_grads_batched=is_grads_batched)
def _fill_in_zeros(grads, refs, strict, create_graph, stage):
# Used to detect None in the grads and depending on the flags, either replace them
# with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage indicates which backward call we are considering, so that we can give a good error message
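    # For illustration, with strict=False a None gradient for refs[i] is replaced
    # by torch.zeros_like(refs[i]); with strict=True one of the stage-specific
    # errors below is raised instead.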
if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage))
res: Tuple[torch.Tensor, ...] = tuple()
for i, grads_i in enumerate(grads):
if grads_i is None:
if strict:
if stage == "back":
raise RuntimeError("The output of the user-provided function is independent of "
"input {}. This is not allowed in strict mode.".format(i))
elif stage == "back_trick":
raise RuntimeError("The gradient with respect to the input is independent of entry {}"
" in the grad_outputs when using the double backward trick to compute"
" forward mode gradients. This is not allowed in strict mode.".format(i))
elif stage == "double_back":
raise RuntimeError("The jacobian of the user-provided function is independent of "
"input {}. This is not allowed in strict mode.".format(i))
else:
raise RuntimeError("The hessian of the user-provided function is independent of "
"entry {} in the grad_jacobian. This is not allowed in strict "
"mode as it prevents from using the double backward trick to "
"replace forward mode AD.".format(i))
grads_i = torch.zeros_like(refs[i])
else:
if strict and create_graph and not grads_i.requires_grad:
if "double" not in stage:
raise RuntimeError("The jacobian of the user-provided function is independent of "
"input {}. This is not allowed in strict mode when create_graph=True.".format(i))
else:
raise RuntimeError("The hessian of the user-provided function is independent of "
"input {}. This is not allowed in strict mode when create_graph=True.".format(i))
res += (grads_i,)
return res
# Public API
def vjp(func, inputs, v=None, create_graph=False, strict=False):
r"""Function that computes the dot product between a vector ``v`` and the
Jacobian of the given function at the point given by the inputs.
Args:
func (function): a Python function that takes Tensor inputs and returns
a tuple of Tensors or a Tensor.
inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed. Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element; if it is not provided, it defaults to
            a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
strict (bool, optional): If ``True``, an error will be raised when we
detect that there exists an input such that all the outputs are
independent of it. If ``False``, we return a Tensor of zeros as the
vjp for said inputs, which is the expected mathematical value.
Defaults to ``False``.
Returns:
output (tuple): tuple with:
func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
vjp (tuple of Tensors or Tensor): result of the dot product with
the same shape as the inputs.
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def exp_reducer(x):
... return x.exp().sum(dim=1)
>>> inputs = torch.rand(4, 4)
>>> v = torch.ones(4)
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> vjp(exp_reducer, inputs, v)
(tensor([5.7817, 7.2458, 5.7830, 6.7782]),
tensor([[1.4458, 1.3962, 1.3042, 1.6354],
[2.1288, 1.0652, 1.5483, 2.5035],
[2.2046, 1.1292, 1.1432, 1.3059],
[1.3225, 1.6652, 1.7753, 2.0152]]))
>>> vjp(exp_reducer, inputs, v, create_graph=True)
(tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
tensor([[1.4458, 1.3962, 1.3042, 1.6354],
[2.1288, 1.0652, 1.5483, 2.5035],
[2.2046, 1.1292, 1.1432, 1.3059],
[1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))
>>> def adder(x, y):
... return 2 * x + 3 * y
>>> inputs = (torch.rand(2), torch.rand(2))
>>> v = torch.ones(2)
>>> vjp(adder, inputs, v)
(tensor([2.4225, 2.3340]),
(tensor([2., 2.]), tensor([3., 3.])))
"""
with torch.enable_grad():
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
outputs = func(*inputs)
is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vjp")
_check_requires_grad(outputs, "outputs", strict=strict)
if v is not None:
_, v = _as_tuple(v, "v", "vjp")
v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
_validate_v(v, outputs, is_outputs_tuple)
else:
if len(outputs) != 1 or outputs[0].nelement() != 1:
raise RuntimeError("The vector v can only be None if the "
"user-provided function returns "
"a single Tensor with a single element.")
enable_grad = True if create_graph else torch.is_grad_enabled()
with torch.set_grad_enabled(enable_grad):
grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")
# Cleanup objects and return them to the user
outputs = _grad_postprocess(outputs, create_graph)
vjp = _grad_postprocess(vjp, create_graph)
return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vjp, is_inputs_tuple)
def jvp(func, inputs, v=None, create_graph=False, strict=False):
r"""Function that computes the dot product between the Jacobian of
the given function at the point given by the inputs and a vector ``v``.
Args:
func (function): a Python function that takes Tensor inputs and returns
a tuple of Tensors or a Tensor.
inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element; if it is not provided, it defaults to
            a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
strict (bool, optional): If ``True``, an error will be raised when we
detect that there exists an input such that all the outputs are
independent of it. If ``False``, we return a Tensor of zeros as the
jvp for said inputs, which is the expected mathematical value.
Defaults to ``False``.
Returns:
output (tuple): tuple with:
func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
jvp (tuple of Tensors or Tensor): result of the dot product with
the same shape as the output.
Note:
``autograd.functional.jvp`` computes the jvp by using the backward of
the backward (sometimes called the double backwards trick). This is not
the most performant way of computing the jvp. Please consider using
:func:`torch.func.jvp` or the
:ref:`low-level forward-mode AD API <forward-mode-ad>` instead.
Example: