from collections import OrderedDict, namedtuple
import itertools
import warnings
import functools
import weakref
import torch
from ..parameter import Parameter
import torch.utils.hooks as hooks
from torch import Tensor, device, dtype
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
from ...utils.hooks import RemovableHandle
# Public names exported from this module via ``from ... import *``.
__all__ = ['register_module_forward_pre_hook', 'register_module_forward_hook',
           'register_module_full_backward_pre_hook', 'register_module_backward_hook',
           'register_module_full_backward_hook', 'register_module_buffer_registration_hook',
           'register_module_module_registration_hook', 'register_module_parameter_registration_hook', 'Module']
# Gradient type used in the backward-hook annotations of this module: either a
# single Tensor or a tuple of Tensors.
_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
T = TypeVar('T', bound='Module')
class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
    """A ``(missing_keys, unexpected_keys)`` pair with a compact success repr.

    When both fields are empty, the repr (and str) is a short success
    message. Otherwise the regular namedtuple representation is used.
    """

    def __repr__(self):
        # Any leftover key on either side means we show the full tuple.
        if self.missing_keys or self.unexpected_keys:
            return super().__repr__()
        return '<All keys matched successfully>'

    __str__ = __repr__
def _addindent(s_, numSpaces):
    """Indent every line of ``s_`` after the first by ``numSpaces`` spaces.

    A string without any newline is returned unchanged. The first line is
    never indented. Every following line (including empty ones) is
    prefixed with the padding.
    """
    head, *tail = s_.split('\n')
    if not tail:
        # Single-line input: nothing to indent.
        return s_
    pad = numSpaces * ' '
    return '\n'.join([head] + [pad + line for line in tail])
r"""This tracks hooks common to all modules that are executed immediately before
registering the buffer/module/parameter"""
# Global registries of buffer / module / parameter registration hooks.
# Each maps a ``RemovableHandle.id`` to the hook callable. Entries are added by
# the matching ``register_module_*_registration_hook`` function. The
# ``RemovableHandle`` is constructed with the dict itself, so
# ``handle.remove()`` can drop the entry. OrderedDict keeps registration order.
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_module_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()
class _WrappedHook:
    """Callable wrapper around a hook, optionally bound weakly to a module.

    The wrapper copies the hook's metadata (``__name__``, ``__doc__``,
    ``__wrapped__``, ...) via :func:`functools.update_wrapper`.

    If ``module`` is given, only a weak reference to it is stored. Calling the
    wrapper then passes the live module as the hook's first argument, and
    raises ``RuntimeError`` if the module has already been garbage-collected.

    Copying/pickling goes through ``__getstate__``/``__setstate__``. These
    serialize the hook, the ``with_module`` flag and (if bound) a strong
    reference to the module, then re-create the weak reference on restore.
    """

    def __init__(self, hook: Callable, module: Optional["Module"] = None):
        # Copy the hook's metadata FIRST. update_wrapper also merges
        # ``hook.__dict__`` into ours. Wrapping another _WrappedHook would
        # otherwise clobber ``self.hook`` with the inner, unbound callable and
        # silently drop the inner module argument.
        functools.update_wrapper(self, hook)
        self.hook: Callable = hook
        self.with_module: bool = False
        if module is not None:
            self.module: weakref.ReferenceType["Module"] = weakref.ref(module)
            self.with_module = True

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the hook, prepending the bound module when there is one.

        Raises:
            RuntimeError: if bound to a module that no longer exists.
        """
        if self.with_module:
            module = self.module()
            if module is None:
                raise RuntimeError("You are trying to call the hook of a dead Module!")
            return self.hook(module, *args, **kwargs)
        return self.hook(*args, **kwargs)

    def __getstate__(self) -> Dict:
        # Weak references cannot be pickled. Store the referent itself
        # (possibly None if it died); __setstate__ rejects a dead module.
        result = {"hook": self.hook, "with_module": self.with_module}
        if self.with_module:
            result["module"] = self.module()
        return result

    def __setstate__(self, state: Dict):
        """Restore from ``__getstate__`` output.

        Raises:
            RuntimeError: if the state was captured from a dead module.
        """
        # Re-apply the metadata that __init__ copied from the hook but which
        # __getstate__ does not carry. Do it before assigning our own
        # attributes so the hook's __dict__ cannot overwrite them.
        functools.update_wrapper(self, state["hook"])
        self.hook = state["hook"]
        self.with_module = state["with_module"]
        if self.with_module:
            if state["module"] is None:
                raise RuntimeError("You are trying to revive the hook of a dead Module!")
            self.module = weakref.ref(state["module"])
r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
# Each registry maps a ``RemovableHandle.id`` to the hook callable. OrderedDict
# keeps registration order.
# Filled by ``register_module_full_backward_pre_hook``.
_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict()
# Shared by ``register_module_backward_hook`` and
# ``register_module_full_backward_hook``.
_global_backward_hooks: Dict[int, Callable] = OrderedDict()
# Which kind of hook ``_global_backward_hooks`` holds:
#   None  -> no global backward hook registered yet
#   False -> legacy hooks (``register_module_backward_hook``)
#   True  -> full hooks (``register_module_full_backward_hook``)
# Registering the other kind afterwards raises RuntimeError.
# NOTE(review): nothing visible here resets this flag when handles are removed.
_global_is_full_backward_hook: Optional[bool] = None
# Filled by ``register_module_forward_pre_hook``.
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
# Filled by ``register_module_forward_hook``.
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
# Suffix '_extra_state'; its users are outside this view (presumably
# state-dict key naming for per-module extra state -- verify).
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
def register_module_buffer_registration_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a buffer registration hook shared by every module.

    .. warning ::

        This adds global state to the ``nn.Module`` module.

    The hook runs each time :func:`register_buffer` is invoked. It must
    have the signature::

        hook(module, name, buffer) -> None or new buffer

    The hook may modify its input, or return a single replacement value.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle whose ``handle.remove()`` unregisters the hook
    """
    registry = _global_buffer_registration_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def register_module_module_registration_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a submodule registration hook shared by every module.

    .. warning ::

        This adds global state to the ``nn.Module`` module.

    The hook runs each time :func:`register_module` is invoked. It must
    have the signature::

        hook(module, name, submodule) -> None or new submodule

    The hook may modify its input, or return a single replacement value.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle whose ``handle.remove()`` unregisters the hook
    """
    table = _global_module_registration_hooks
    handle = hooks.RemovableHandle(table)
    table[handle.id] = hook
    return handle
def register_module_parameter_registration_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a parameter registration hook shared by every module.

    .. warning ::

        This adds global state to the ``nn.Module`` module.

    The hook runs each time :func:`register_parameter` is invoked. It must
    have the signature::

        hook(module, name, param) -> None or new parameter

    The hook may modify its input, or return a single replacement value.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle whose ``handle.remove()`` unregisters the hook
    """
    store = _global_parameter_registration_hooks
    handle = hooks.RemovableHandle(store)
    store[handle.id] = hook
    return handle
def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a forward pre-hook shared by every module.

    .. warning ::

        This adds global state to the ``nn.Module`` module and is only
        intended for debugging/profiling purposes.

    The hook runs each time before :func:`forward` is invoked. It must have
    the signature::

        hook(module, input) -> None or modified input

    ``input`` holds only the positional arguments given to the module.
    Keyword arguments are passed to ``forward`` but not to the hook.
    The hook may modify the input by returning either a tuple or a single
    value. A single value is wrapped into a tuple (unless it already is one).

    This hook takes precedence over the module-specific hooks registered
    with ``register_forward_pre_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle whose ``handle.remove()`` unregisters the hook
    """
    pre_hooks = _global_forward_pre_hooks
    handle = hooks.RemovableHandle(pre_hooks)
    pre_hooks[handle.id] = hook
    return handle
def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a forward hook shared by every module.

    .. warning ::

        This adds global state to the ``nn.Module`` module and is only
        intended for debugging/profiling purposes.

    The hook runs each time after :func:`forward` has computed an output.
    It must have the signature::

        hook(module, input, output) -> None or modified output

    ``input`` holds only the positional arguments given to the module.
    Keyword arguments are passed to ``forward`` but not to the hook.
    The hook may replace the output. It may also modify the input in place,
    but that has no effect on ``forward``, which has already run.

    This hook is executed before the module-specific hooks registered with
    ``register_forward_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle whose ``handle.remove()`` unregisters the hook
    """
    post_hooks = _global_forward_hooks
    handle = hooks.RemovableHandle(post_hooks)
    post_hooks[handle.id] = hook
    return handle
def register_module_backward_hook(
    hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
) -> RemovableHandle:
    r"""Register a (legacy) backward hook shared by every module.

    Deprecated in favor of
    :func:`torch.nn.modules.module.register_module_full_backward_hook`;
    the behavior of this function will change in future versions.

    Raises:
        RuntimeError: if a full backward hook has already been registered
            globally. The two kinds cannot be mixed.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle whose ``handle.remove()`` unregisters the hook
    """
    global _global_is_full_backward_hook
    # The shared global backward registry holds only one kind of hook.
    already_full = _global_is_full_backward_hook is True
    if already_full:
        raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a "
                           "global Module hook. Please use only one of them.")
    _global_is_full_backward_hook = False

    registry = _global_backward_hooks
    handle = hooks.RemovableHandle(registry)
    registry[handle.id] = hook
    return handle
def register_module_full_backward_pre_hook(
    hook: Callable[['Module', _grad_t], Union[None, _grad_t]]
) -> RemovableHandle:
    r"""Register a backward pre-hook shared by every module.

    .. warning ::

        This adds global state to the ``nn.Module`` module and is only
        intended for debugging/profiling purposes.

    The hook runs each time the gradients for the module are computed.
    It must have the signature::

        hook(module, grad_output) -> Tensor or None

    :attr:`grad_output` is a tuple. The hook should not modify its
    arguments. It may optionally return a new gradient with respect to the
    output, which is then used in place of :attr:`grad_output` in subsequent
    computations. Entries of :attr:`grad_output` are ``None`` for all
    non-Tensor arguments.

    For technical reasons, when this hook is applied to a Module, its forward
    function receives a view of each Tensor passed to the Module, and the
    caller receives a view of each Tensor returned by the Module's forward
    function.

    Global hooks are called before hooks registered with
    ``register_backward_pre_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle whose ``handle.remove()`` unregisters the hook
    """
    pre_hooks = _global_backward_pre_hooks
    handle = hooks.RemovableHandle(pre_hooks)
    pre_hooks[handle.id] = hook
    return handle
def register_module_full_backward_hook(
hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
) -> RemovableHandle:
r"""Registers a backward hook common to all the modules.
.. warning ::
This adds global state to the `nn.module` module
and it is only intended for debugging/profiling purposes.
The hook will be called every time the gradients with respect to a module
are computed, i.e. the hook will execute if and only if the gradients with
respect to module outputs are computed. The hook should have the following
signature::
hook(module, grad_input, grad_output) -> Tensor or None
The :attr:`grad_input` and :attr:`grad_output` are tuples. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of :attr:`grad_input` in
subsequent computations. :attr:`grad_input` will only correspond to the inputs given
as positional arguments and all kwarg arguments will not appear in the hook. Entries
in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module's forward function.
Global hooks are called before hooks registered with `register_backward_hook`
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
global _global_is_full_backward_hook
if _global_is_full_backward_hook is False:
raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a "
"global Module hook. Please use only one of them.")
_global_is_full_backward_hook = True
handle = hooks.RemovableHandle(_global_backward_hooks)
_global_backward_hooks[handle.id] = hook
Loading ...