# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Union, Tuple, List, Any, Optional
import torch
from functools import partial, wraps
import contextlib
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map, tree_map_only
from torch.fx.experimental import const_fold
from torch.fx.experimental.proxy_tensor import make_fx
from .pytree_hacks import tree_map_, treespec_pprint
import torch.autograd.forward_ad as fwAD
from .vmap import vmap, doesnt_support_saved_tensors_hooks, get_chunk_sizes
from torch._C._functorch import (
_wrap_for_grad,
_unwrap_for_grad,
_grad_increment_nesting,
_grad_decrement_nesting,
_jvp_increment_nesting,
_jvp_decrement_nesting,
_wrap_functional_tensor,
_unwrap_functional_tensor,
_func_decrement_nesting,
_func_increment_nesting,
_assert_wrapped_functional,
_propagate_functional_input_mutation,
set_inplace_requires_grad_allowed,
get_inplace_requires_grad_allowed
)
from torch._functorch.utils import exposed_in
argnums_t = Union[int, Tuple[int, ...]]
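# Context manager that toggles the functorch flag controlling whether calling
# ``requires_grad_()`` in-place is allowed (see set_inplace_requires_grad_allowed);
# the previous setting is restored on exit.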
@contextlib.contextmanager
def enable_inplace_requires_grad(enabled=True):
prev_state = get_inplace_requires_grad_allowed()
set_inplace_requires_grad_allowed(enabled)
try:
yield
finally:
set_inplace_requires_grad_allowed(prev_state)
def _create_differentiable(inps, level=None):
def create_differentiable(x):
if isinstance(x, torch.Tensor):
with enable_inplace_requires_grad():
return x.requires_grad_()
raise ValueError(f'Thing passed to transform API must be Tensor, '
f'got {type(x)}')
return tree_map(create_differentiable, inps)
def _undo_create_differentiable(inps, level=None):
def unwrap_tensors(x):
if isinstance(x, torch.Tensor):
return _unwrap_for_grad(x, level)
# TODO: Remove the following hack for namedtuples
if isinstance(x, tuple):
return tree_map(unwrap_tensors, tuple(x))
raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")
return tree_map(unwrap_tensors, inps)
def _is_differentiable(maybe_tensor):
if not isinstance(maybe_tensor, torch.Tensor):
return False
return maybe_tensor.requires_grad
def _any_differentiable(tensor_or_tuple_of_tensors):
    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
return any(tuple(map(_is_differentiable, flat_args)))
def _wrap_tensor_for_grad(maybe_tensor, level):
if not isinstance(maybe_tensor, torch.Tensor):
return maybe_tensor
return _wrap_for_grad(maybe_tensor, level)
def _wrap_all_tensors(tensor_pytree, level):
return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)
def _as_tuple(val):
if isinstance(val, tuple):
return val
return (val,)
# Version of autograd.grad that handles outputs that don't depend on inputs
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
if grad_outputs is None:
diff_outputs = tuple(out for out in outputs if out.requires_grad)
else:
result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
if len(result) == 0:
diff_outputs, grad_outputs = (), ()
else:
diff_outputs, grad_outputs = zip(*result)
if len(diff_outputs) == 0:
return tuple(torch.zeros_like(inp) for inp in inputs)
grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=True)
grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
for gi, inp in zip(grad_inputs, inputs))
return grad_inputs
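# Illustrative sketch of the behavior above (hypothetical names x, out1, out2):
# outputs that do not require grad are filtered out, and unused inputs receive
# zeros instead of None.
#
#   x = torch.randn(3, requires_grad=True)
#   out1 = x.sin()
#   out2 = torch.zeros(3)  # does not require grad; silently filtered out
#   grads = _autograd_grad((out1, out2), (x,), (torch.ones(3), torch.ones(3)))
#   # grads[0] matches x.cos(), whereas calling torch.autograd.grad directly on
#   # (out1, out2) would error because out2 has no grad_fn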
# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
# with torch.no_grad():
# c = x ** 2
# return x - c
#
# The thing to consider is if enable_grad is on/off before grad gets called.
#
# Case 1: enable_grad is on.
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad.
#
# Case 2: enable_grad is off
# with torch.no_grad():
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad, but not the
# outer one. This is because `grad` is a "function transform": its result
# should not depend on the result of a context manager outside of `f`.
#
# This gives us the following desired behavior:
# - (nested) grad transforms must obey torch.no_grad inside them
# - (nested) grad transforms should not obey torch.no_grad outside them
#
# To achieve this behavior, upon entering grad/vjp:
# - we save the current ("previous") is_grad_enabled (*)
# - we unconditionally enable grad.
#
# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
# off the stack:
# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
# active, all subsequent grad transforms must obey it).
# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
# then we temporarily restore the previous `is_grad_enabled`. This is
# because we're crossing the boundary from a `grad` outside the
# no_grad to a `grad` inside the no_grad.
#
# NB: vjp has some interesting behavior because the vjp's callable can be called
# under a different grad_mode than the forward computation...
#
# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but
# it respects c10::AutoFwGradMode. We've implemented the same logic for
# our jvp transform (it will have special handling if FwGradMode is disabled).
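#
# Illustrative example of the desired behavior (hypothetical value of x,
# with f defined at the top of this NOTE):
#   x = torch.tensor(2.)
#   grad(f)(x)        # tensor(1.): the inner no_grad makes c a constant
#   with torch.no_grad():
#       grad(f)(x)    # also tensor(1.): the outer no_grad is not obeyed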
@exposed_in("torch.func")
def vjp(func: Callable, *primals, has_aux: bool = False):
"""
Standing for the vector-Jacobian product, returns a tuple containing the
results of ``func`` applied to ``primals`` and a function that, when
given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with
respect to ``primals`` times ``cotangents``.
Args:
func (Callable): A Python function that takes one or more arguments. Must
return one or more Tensors.
primals (Tensors): Positional arguments to ``func`` that must all be
Tensors. The returned function will also be computing the
derivative with respect to these arguments
has_aux (bool): Flag indicating that ``func`` returns a
``(output, aux)`` tuple where the first element is the output of
the function to be differentiated and the second element is
other auxiliary objects that will not be differentiated.
Default: False.
Returns:
Returns a ``(output, vjp_fn)`` tuple containing the output of ``func``
applied to ``primals`` and a function that computes the vjp of
``func`` with respect to all ``primals`` using the cotangents passed
to the returned function. If ``has_aux is True``, then instead returns a
``(output, vjp_fn, aux)`` tuple.
The returned ``vjp_fn`` function will return a tuple of each VJP.
When used in simple cases, :func:`vjp` behaves the same as :func:`grad`
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))
However, :func:`vjp` can support functions with multiple outputs by
passing in the cotangents for each of the outputs
>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
:func:`vjp` can even support outputs being Python structs
>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by :func:`vjp` will compute the partials with
respect to each of the ``primals``
>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
``primals`` are the positional arguments for ``f``. All kwargs use their
default value
>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>> return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
.. note::
Using PyTorch ``torch.no_grad`` together with ``vjp``.
Case 1: Using ``torch.no_grad`` inside a function:
>>> def f(x):
>>> with torch.no_grad():
>>> c = x ** 2
>>> return x - c
In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``.
Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:
>>> # xdoctest: +SKIP(failing)
>>> with torch.no_grad():
>>> vjp(f)(x)
In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
outer one. This is because ``vjp`` is a "function transform": its result
should not depend on the result of a context manager outside of ``f``.
"""
return _vjp_with_argnums(func, *primals, has_aux=has_aux)
@doesnt_support_saved_tensors_hooks
def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
# This is the same function as vjp but also accepts an argnums argument
# All args are the same as vjp except for the added argument
# argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
# If None, computes the gradients with respect to all inputs (used for vjp). Default: None
#
# WARN: Users should NOT call this function directly and should just be calling vjp.
# It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers.
#
# NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev
#
# Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs
# for only the primal elements given by argnums.
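    #
    # Illustrative example (hypothetical f, x, y): with f = lambda x, y: x * y,
    #   _, vjp_fn = _vjp_with_argnums(f, x, y, argnums=(1,))
    #   vjp_fn(cotangent)  # one-element tuple: only the VJP with respect to y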
level = _grad_increment_nesting()
try:
# See NOTE [grad and vjp interaction with no_grad]
with torch.enable_grad():
primals = _wrap_all_tensors(primals, level)
if argnums is None:
diff_primals = _create_differentiable(primals, level)
else:
diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
tree_map_(partial(_create_differentiable, level=level), diff_primals)
primals_out = func(*primals)
if has_aux:
if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
raise RuntimeError(
"vjp(f, *primals): output of function f should be a tuple: (output, aux) "
"if has_aux is True"
)
primals_out, aux = primals_out
aux = _undo_create_differentiable(aux, level)
flat_primals_out, primals_out_spec = tree_flatten(primals_out)
assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)')
flat_diff_primals, primals_spec = tree_flatten(diff_primals)
results = _undo_create_differentiable(primals_out, level)
for primal_out in flat_primals_out:
assert isinstance(primal_out, torch.Tensor)
if primal_out.is_floating_point() or primal_out.is_complex():
continue
raise RuntimeError("vjp(f, ...): All outputs of f must be "
"floating-point or complex Tensors, got Tensor "
f"with dtype {primal_out.dtype}")
def wrapper(cotangents, retain_graph=True, create_graph=None):
if create_graph is None:
create_graph = torch.is_grad_enabled()
flat_cotangents, cotangents_spec = tree_flatten(cotangents)
if primals_out_spec != cotangents_spec:
raise RuntimeError(
f'Expected pytree structure of cotangents to be the same '
f'as pytree structure of outputs to the function. '
f'cotangents: {treespec_pprint(cotangents_spec)}, '
f'primal output: {treespec_pprint(primals_out_spec)}')
result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
retain_graph=retain_graph, create_graph=create_graph)
return tree_unflatten(result, primals_spec)
finally:
_grad_decrement_nesting()
if has_aux:
return results, wrapper, aux
else:
return results, wrapper
def _safe_zero_index(x):
assert len(x) == 1
return x[0]
# jacrev and jacfwd don't support complex functions
# Helper function to throw appropriate error.
def error_if_complex(func_name, args, is_input):
flat_args, _ = tree_flatten(args)