from collections import OrderedDict, namedtuple
import itertools
import warnings
import functools
import torch
from ..parameter import Parameter
import torch.utils.hooks as hooks
from torch import Tensor, device, dtype
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
from ...utils.hooks import RemovableHandle
# Type of the gradient arguments in the backward-hook signatures below:
# either a single Tensor or a tuple of Tensors.
_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
T = TypeVar('T', bound='Module')
class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
def __repr__(self):
if not self.missing_keys and not self.unexpected_keys:
return '<All keys matched successfully>'
return super(_IncompatibleKeys, self).__repr__()
__str__ = __repr__
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
_global_backward_hooks: Dict[int, Callable] = OrderedDict()
_global_is_full_backward_hook: Optional[bool] = None
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a forward pre-hook that applies to every module.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook is called each time before :func:`forward` is invoked and
    should have the following signature::

        hook(module, input) -> None or modified input

    Only the positional arguments given to the module are contained in the
    input. Keyword arguments are passed to ``forward`` only, never to the hooks.
    The hook may modify the input, returning either a tuple or a single
    modified value; a single value is wrapped into a tuple (unless that value
    is already a tuple).

    This hook has precedence over the specific module hooks registered with
    ``register_forward_pre_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    # The handle is bound to the global registry and the hook is stored under its id.
    removable = hooks.RemovableHandle(_global_forward_pre_hooks)
    _global_forward_pre_hooks[removable.id] = hook
    return removable
def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a global forward hook that applies to all the modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook is called each time after :func:`forward` has computed an
    output and should have the following signature::

        hook(module, input, output) -> None or modified output

    Only the positional arguments given to the module are contained in the
    input. Keyword arguments are passed to ``forward`` only, never to the hooks.
    The hook may modify the output. Modifying the input in place is possible
    but has no effect on forward, because the hook runs after
    :func:`forward` is called.

    This hook will be executed before specific module hooks registered with
    ``register_forward_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    # The handle is bound to the global registry and the hook is stored under its id.
    removable = hooks.RemovableHandle(_global_forward_hooks)
    _global_forward_hooks[removable.id] = hook
    return removable
def register_module_backward_hook(
    hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a regular (non-full) backward hook common to all the modules.

    This function is deprecated in favor of :meth:`nn.module.register_module_full_backward_hook`
    and the behavior of this function will change in future versions.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    Raises:
        RuntimeError: if a full backward hook has already been registered
            globally; the two kinds cannot be mixed.
    """
    global _global_is_full_backward_hook

    full_hooks_in_use = _global_is_full_backward_hook is True
    if full_hooks_in_use:
        raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a "
                           "global Module hook. Please use only one of them.")

    # Record that the global registry now holds regular backward hooks.
    _global_is_full_backward_hook = False

    removable = hooks.RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[removable.id] = hook
    return removable
def register_module_full_backward_hook(
    hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a full backward hook common to all the modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

        The current implementation will not have the presented behavior
        for complex :class:`Module` that perform many operations.
        In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
        contain the gradients for a subset of the inputs and outputs.
        For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
        directly on a specific input or output to get the required gradients.

    The hook is called each time the gradients with respect to module
    inputs are computed. The hook should have the following signature::

        hook(module, grad_input, grad_output) -> Tensor or None

    Both :attr:`grad_input` and :attr:`grad_output` are tuples. The hook should
    not modify its arguments, but it can optionally return a new gradient with
    respect to the input that will be used in place of :attr:`grad_input` in
    subsequent computations. :attr:`grad_input` only corresponds to the inputs
    given as positional arguments; kwarg arguments never appear in the hook.
    Entries in :attr:`grad_input` and :attr:`grad_output` are ``None`` for all
    non-Tensor arguments.

    Global hooks are called before hooks registered with `register_backward_hook`

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    Raises:
        RuntimeError: if a regular backward hook has already been registered
            globally; the two kinds cannot be mixed.
    """
    global _global_is_full_backward_hook

    regular_hooks_in_use = _global_is_full_backward_hook is False
    if regular_hooks_in_use:
        raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a "
                           "global Module hook. Please use only one of them.")

    # Record that the global registry now holds full backward hooks.
    _global_is_full_backward_hook = True

    removable = hooks.RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[removable.id] = hook
    return removable
# Trick mypy into not applying contravariance rules to inputs by defining
# forward as a value, rather than a function. See also
# https://github.com/python/mypy/issues/8795
def _forward_unimplemented(self, *input: Any) -> None:
r"""Defines the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError
class Module:
r"""Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.
:ivar training: Boolean represents whether this module is in training or
evaluation mode.
:vartype training: bool
"""
dump_patches: bool = False
r"""This allows better BC support for :meth:`load_state_dict`. In
:meth:`state_dict`, the version number will be saved as in the attribute
`_metadata` of the returned state dict, and thus pickled. `_metadata` is a
dictionary with keys that follow the naming convention of state dict. See
``_load_from_state_dict`` on how to use this information in loading.
If new parameters/buffers are added/removed from a module, this number shall
be bumped, and the module's `_load_from_state_dict` method can compare the
version number and do appropriate changes if the state dict is from before
the change."""
_version: int = 1
training: bool
_is_full_backward_hook: Optional[bool]
def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.

    Logs the API usage once, sets ``training`` to ``True`` and creates the
    empty containers the rest of the class relies on. ``register_buffer`` and
    ``register_parameter`` refuse to run until ``_buffers`` / ``_parameters``
    exist in ``self.__dict__``, so subclasses must call this before
    registering anything.
    """
    torch._C._log_api_usage_once("python.nn_module")

    # NOTE(review): keep the statement order as-is; attribute assignment may go
    # through a __setattr__ override defined elsewhere in the class -- confirm
    # before reordering.
    self.training = True
    self._parameters = OrderedDict()
    # name -> buffer Tensor (or None); filled in by register_buffer.
    self._buffers = OrderedDict()
    # Names of buffers registered with persistent=False.
    self._non_persistent_buffers_set = set()
    self._backward_hooks = OrderedDict()
    # Presumably the per-module counterpart of the module-level
    # _global_is_full_backward_hook tri-state -- TODO confirm.
    self._is_full_backward_hook = None
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()
forward: Callable[..., Any] = _forward_unimplemented
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
r"""Adds a buffer to the module.
This is typically used to register a buffer that should not to be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the module's state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting :attr:`persistent` to ``False``. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module's
:attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (string): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor): buffer to be registered.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
Example::
>>> self.register_buffer('running_mean', torch.zeros(num_features))
"""
if persistent is False and isinstance(self, torch.jit.ScriptModule):
raise RuntimeError("ScriptModule does not support non-persistent buffers")
if '_buffers' not in self.__dict__:
raise AttributeError(
"cannot assign buffer before Module.__init__() call")
elif not isinstance(name, torch._six.string_classes):
raise TypeError("buffer name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
raise KeyError("buffer name can't contain \".\"")
elif name == '':
raise KeyError("buffer name can't be empty string \"\"")
elif hasattr(self, name) and name not in self._buffers:
raise KeyError("attribute '{}' already exists".format(name))
elif tensor is not None and not isinstance(tensor, torch.Tensor):
raise TypeError("cannot assign '{}' object to buffer '{}' "
"(torch Tensor or None required)"
.format(torch.typename(tensor), name))
else:
self._buffers[name] = tensor
if persistent:
self._non_persistent_buffers_set.discard(name)
else:
self._non_persistent_buffers_set.add(name)
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
r"""Adds a parameter to the module.
The parameter can be accessed as an attribute using given name.
Args:
name (string): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter): parameter to be added to the module.
"""
if '_parameters' not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
elif not isinstance(name, torch._six.string_classes):
raise TypeError("parameter name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
raise KeyError("parameter name can't contain \".\"")
elif name == '':
raise KeyError("parameter name can't be empty string \"\"")
elif hasattr(self, name) and name not in self._parameters:
raise KeyError("attribute '{}' already exists".format(name))
Loading ...