import torch
import contextlib
from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List, Set
from torch.utils.hooks import RemovableHandle
from torch.utils._python_dispatch import TorchDispatchMode
from collections import defaultdict
import weakref
import abc
# Public API of this module: the names bound by a star-import
# (``from <module> import *``).
__all__ = [
"saved_tensors_hooks",
"save_on_cpu",
"disable_saved_tensors_hooks",
"register_multi_grad_hook",
"allow_mutation_on_saved_tensors",
"Node",
]
class Node(abc.ABC):
    r"""Abstract interface of an autograd graph node.

    The class is not meant to be instantiated: every method below is abstract.
    ``isinstance`` / ``issubclass`` checks against it are answered by
    :meth:`__subclasshook__`, which accepts the classes exposed by
    ``torch._C._functions`` as well as subclasses of
    ``torch.autograd.function.BackwardCFunction``.
    """

    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name of the node.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional['Node'], int], ...]:
        r"""Tuple of ``(node, int)`` pairs; the node entry may be ``None``."""
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata of the node as a dict."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        # Private hook point taking a tensor; the concrete behaviour is
        # supplied by the implementing node classes.
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook on this node.

        The hook is invoked each time a gradient with respect to the Node is
        computed. Its expected signature is::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook must not modify its arguments. It may return a new gradient,
        which is then used in place of :attr:`grad_inputs`.

        Returns:
            A handle whose ``handle.remove()`` method removes the hook from
            the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when
            this hook is executed, and how its execution is ordered relative
            to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook on this node.

        The hook is invoked each time a gradient with respect to the Node is
        computed. Its expected signature is::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook must not modify its argument. It may return a new gradient,
        which is then used in place of :attr:`grad_outputs`.

        Returns:
            A handle whose ``handle.remove()`` method removes the hook from
            the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when
            this hook is executed, and how its execution is ordered relative
            to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        # Only ``Node`` itself gets the virtual-subclass treatment; classes
        # deriving from it fall back to the default ABC machinery.
        if cls is not Node:
            return NotImplemented
        # A class counts as a built-in node when ``torch._C._functions``
        # exposes that very class object under the class's own name.
        is_builtin_node = (
            C is not None
            and C is getattr(torch._C._functions, C.__name__, None)
        )
        if is_builtin_node or issubclass(C, torch.autograd.function.BackwardCFunction):
            return True
        return NotImplemented
class saved_tensors_hooks:
    """Context-manager installing a pack / unpack hook pair for saved tensors.

    Within the context, the way intermediary results of an operation are
    stored for backward is customised: they are *packed* when saved and
    *unpacked* when retrieved.

    ``pack_hook`` is called every time an operation saves a tensor for
    backward. This covers intermediary results saved through
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` as
    well as those recorded by PyTorch-defined operations. Whatever
    ``pack_hook`` returns is stored in the computation graph in place of the
    original tensor.

    ``unpack_hook`` is called when the saved tensor has to be accessed, i.e.
    while running :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It receives the *packed* object produced by
    ``pack_hook`` and must return a tensor with the same content as the
    original tensor (the one handed to the matching ``pack_hook`` call).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the value returned by ``pack_hook`` is a valid input to
    ``unpack_hook``.

    In general, ``unpack_hook(pack_hook(t))`` should equal ``t`` in terms of
    value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting
        this context-manager, only the inner-most pair of hooks will be
        applied.
    """

    def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
        # Keep both callables on the instance; they are only installed once
        # the context is entered.
        self.pack_hook, self.unpack_hook = pack_hook, unpack_hook

    def __enter__(self):
        # Push this pair as the default saved-tensors hooks.
        push_hooks = torch._C._autograd._push_saved_tensors_default_hooks
        push_hooks(self.pack_hook, self.unpack_hook)

    def __exit__(self, *args: Any):
        # Pop the default saved-tensors hooks pushed on entry; exception
        # details in ``args`` are ignored.
        torch._C._autograd._pop_saved_tensors_default_hooks()
class save_on_cpu(saved_tensors_hooks):
    """Context-manager that keeps tensors saved by the forward pass on CPU
    and brings them back for backward.

    Operations executed inside this context have the intermediary results
    they save in the graph moved to CPU during the forward pass; those
    results are copied back to their original device when the backward pass
    needs them. If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned
                           memory during packing and copied to GPU
                           asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
    """

    def __init__(self, pin_memory=False):
        def pack_to_cpu(tensor):
            # The packed object is ``(original device, CPU copy)`` so that
            # unpacking knows where to send the data back.
            if pin_memory:
                # Pinned memory is only requested when CUDA is available and
                # the tensor is not sparse.
                use_pinned = torch.cuda.is_available() and not tensor.is_sparse
                staging = torch.empty(
                    tensor.size(),
                    dtype=tensor.dtype,
                    layout=tensor.layout,
                    pin_memory=use_pinned,
                )
                staging.copy_(tensor)
                return (tensor.device, staging)
            return (tensor.device, tensor.cpu())

        def unpack_from_cpu(packed):
            origin, stored = packed
            # The transfer back is non-blocking exactly when pinning was
            # requested.
            return stored.to(origin, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)
@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message: str):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    On exit the previous state is restored: if the feature was already
    disabled when the context was entered, it is disabled again with the
    earlier error message; otherwise it is re-enabled.

    Args:
        error_message (str): When saved tensors default hooks are used while
                             they are disabled, a RuntimeError with this
                             error message gets raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass
    """
    # Query the previous state *before* entering the ``try``: if this call
    # raised inside the ``try``, the ``finally`` clause would reference an
    # unbound ``maybe_prev_message`` and mask the real error. Nothing has been
    # changed yet at this point, so there is nothing to restore on failure.
    maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
    try:
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            # The hooks were enabled on entry: turn them back on.
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            # The hooks were already disabled on entry: reinstate the outer
            # error message.
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
def register_multi_grad_hook(tensors: Sequence[torch.Tensor], fn: Callable[[Sequence[Optional[torch.Tensor]]], None]):
r"""Registers a multi-grad backward hook.
The hook will be called after gradients with respect to every tensor in
:attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
is not part of the graph, or if a tensor is not needed to compute the gradients
for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
this tensor will be ignored and the hook will not wait for its gradient to be
computed.
After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
called with those gradients. ``None`` will be passed for tensors that did not
have their gradients computed.
The hook should not modify its arguments.
This function returns a handle with a method ``handle.remove()`` that removes the hook.
.. note::
See :ref:`backward-hooks-execution` for more information on how when this hook
is executed, and how its execution is ordered relative to other hooks.
Example::
>>> import torch
>>>
>>> a = torch.rand(2, 3, requires_grad=True)
>>> b = torch.rand(2, 3, requires_grad=True)
>>> c = a * b
>>> d = a * b
>>>
>>> def fn(grads):
... print([g is not None for g in grads])
...
>>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
>>>
>>> c.sum().backward(retain_graph=True)
[True, True, True, False]
>>> c.sum().backward(inputs=(a,), retain_graph=True)
[True, False, True, False]
>>>
"""
count: Dict[int, int] = dict()
nb_calls = None
buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()
def get_grad_fn(t):
Loading ...