from typing import List, Optional
from torch import Tensor
from .module import Module
from .utils import _single, _pair, _triple
from .. import functional as F
from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t,
_ratio_3_t, _ratio_2_t, _size_any_opt_t, _size_2_opt_t, _size_3_opt_t)
# Public API of this module: the names bound by ``from <module> import *``.
# Underscore-prefixed base classes (e.g. ``_MaxPoolNd``, ``_MaxUnpoolNd``) are
# deliberately not listed here.
__all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d',
'LPPool2d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
class _MaxPoolNd(Module):
    """Common base of the ``MaxPool1d`` / ``MaxPool2d`` / ``MaxPool3d`` modules.

    The constructor only records the pooling hyper-parameters on the instance;
    the concrete subclasses provide ``forward``.

    Args:
        kernel_size: size of the pooling window, stored as given.
        stride: stride of the pooling window. When ``None``, ``kernel_size``
            is stored in its place.
        padding: padding setting, stored as given.
        dilation: dilation setting, stored as given.
        return_indices: flag stored for use by subclasses' ``forward``.
        ceil_mode: flag stored for use by subclasses' ``forward``.
    """

    __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
                     'return_indices', 'ceil_mode']
    return_indices: bool
    ceil_mode: bool

    def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None,
                 padding: _size_any_t = 0, dilation: _size_any_t = 1,
                 return_indices: bool = False, ceil_mode: bool = False) -> None:
        super().__init__()
        # An omitted stride means "step by one full window".
        if stride is None:
            stride = kernel_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode

    def extra_repr(self) -> str:
        """Return the hyper-parameter summary shown in the module's ``repr``.

        ``return_indices`` is not part of the summary.
        """
        return (
            f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}'
            f', dilation={self.dilation}, ceil_mode={self.ceil_mode}'
        )
class MaxPool1d(_MaxPoolNd):
    r"""Applies 1D max pooling to an input signal composed of several input planes.

    For an input of size :math:`(N, C, L)` and an output of size :math:`(N, C, L_{out})`,
    the simplest case is described exactly by:

    .. math::
        out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, C_j, stride \times k + m)

    When :attr:`padding` is non-zero, the input is implicitly padded with negative infinity
    on both sides by :attr:`padding` points. :attr:`dilation` is the stride between the
    elements inside one sliding window. This `link`_ gives a nice visualization of the
    pooling parameters.

    Note:
        With ceil_mode=True, a sliding window may go off-bounds provided it starts within the
        left padding or the input. Sliding windows that would start in the right padded
        region are ignored.

    Args:
        kernel_size: Size of the sliding window; must be > 0.
        stride: Stride of the sliding window; must be > 0. Defaults to :attr:`kernel_size`.
        padding: Implicit negative infinity padding added on both sides; must be >= 0 and
            <= kernel_size / 2.
        dilation: Stride between the elements within a sliding window; must be > 0.
        return_indices: If ``True``, the argmax is returned together with the max values.
            Useful for :class:`torch.nn.MaxUnpool1d` later
        ceil_mode: If ``True``, `ceil` is used instead of `floor` when computing the output
            shape, which ensures that every element of the input tensor is covered by a
            sliding window.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                    \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Examples::

        >>> # pool of size=3, stride=2
        >>> m = nn.MaxPool1d(3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t
    dilation: _size_1_t

    def forward(self, input: Tensor):
        # Delegates to the functional op with this module's stored settings.
        return F.max_pool1d(
            input,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
class MaxPool2d(_MaxPoolNd):
    r"""Applies 2D max pooling to an input signal composed of several input planes.

    For an input of size :math:`(N, C, H, W)`, an output of size
    :math:`(N, C, H_{out}, W_{out})` and a :attr:`kernel_size` of :math:`(kH, kW)`,
    the simplest case is described exactly by:

    .. math::
        \begin{aligned}
            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n)
        \end{aligned}

    When :attr:`padding` is non-zero, the input is implicitly padded with negative infinity
    on both sides by :attr:`padding` points. :attr:`dilation` controls the spacing between
    the kernel points; it is harder to describe, but this `link`_ has a nice visualization
    of what :attr:`dilation` does.

    Note:
        With ceil_mode=True, a sliding window may go off-bounds provided it starts within the
        left padding or the input. Sliding windows that would start in the right padded
        region are ignored.

    Each of :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` may be:

        - a single ``int`` -- the same value is then used for the height and width dimension
        - a ``tuple`` of two ints -- the first `int` is then used for the height dimension,
          and the second `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, the max indices are returned along with the outputs.
                        Useful for :class:`torch.nn.MaxUnpool2d` later
        ceil_mode: when True, `ceil` is used instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
                    \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
                    \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.MaxPool2d((3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    dilation: _size_2_t

    def forward(self, input: Tensor):
        # Delegates to the functional op with this module's stored settings.
        return F.max_pool2d(
            input,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
class MaxPool3d(_MaxPoolNd):
    r"""Applies 3D max pooling to an input signal composed of several input planes.

    For an input of size :math:`(N, C, D, H, W)`, an output of size
    :math:`(N, C, D_{out}, H_{out}, W_{out})` and a :attr:`kernel_size` of :math:`(kD, kH, kW)`,
    the simplest case is described exactly by:

    .. math::
        \begin{aligned}
            \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                              & \text{input}(N_i, C_j, \text{stride[0]} \times d + k,
                                                             \text{stride[1]} \times h + m, \text{stride[2]} \times w + n)
        \end{aligned}

    When :attr:`padding` is non-zero, the input is implicitly padded with negative infinity
    on both sides by :attr:`padding` points. :attr:`dilation` controls the spacing between
    the kernel points; it is harder to describe, but this `link`_ has a nice visualization
    of what :attr:`dilation` does.

    Note:
        With ceil_mode=True, a sliding window may go off-bounds provided it starts within the
        left padding or the input. Sliding windows that would start in the right padded
        region are ignored.

    Each of :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` may be:

        - a single ``int`` -- the same value is then used for the depth, height and width dimension
        - a ``tuple`` of three ints -- the first `int` is then used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on all three sides
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, the max indices are returned along with the outputs.
                        Useful for :class:`torch.nn.MaxUnpool3d` later
        ceil_mode: when True, `ceil` is used instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
                (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
                (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
                (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.MaxPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = torch.randn(20, 16, 50, 44, 31)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """  # noqa: E501

    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t
    dilation: _size_3_t

    def forward(self, input: Tensor):
        # Delegates to the functional op with this module's stored settings.
        return F.max_pool3d(
            input,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
class _MaxUnpoolNd(Module):
    """Base class of the ``MaxUnpool*`` modules; contributes only ``extra_repr``.

    Subclasses are expected to set ``kernel_size``, ``stride`` and ``padding``
    on the instance, since ``extra_repr`` reads all three.
    """

    def extra_repr(self) -> str:
        """Return the hyper-parameter summary shown in the module's ``repr``."""
        return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}'
class MaxUnpool1d(_MaxUnpoolNd):
    r"""Computes a partial inverse of :class:`MaxPool1d`.

    Because the non-maximal values are lost, :class:`MaxPool1d` is not fully invertible.

    :class:`MaxUnpool1d` takes the output of :class:`MaxPool1d` together with the
    indices of the maximal values, and computes a partial inverse in which every
    non-maximal value is set to zero.

    Note:
        This operation may behave nondeterministically when the input indices contain repeated values.
        See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.

    .. note:: :class:`MaxPool1d` can map several input sizes to the same output
              sizes, so the inversion can be ambiguous.
              To resolve this, the needed output size can be passed
              as an additional argument :attr:`output_size` in the forward call.
              See the Inputs and Example below.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            It is set to :attr:`kernel_size` by default.
        padding (int or tuple): Padding that was added to the input

    Inputs:
        - `input`: the input Tensor to invert
        - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d`
        - `output_size` (optional): the targeted output size

    Shape:
        - Input: :math:`(N, C, H_{in})` or :math:`(C, H_{in})`.
        - Output: :math:`(N, C, H_{out})` or :math:`(C, H_{out})`, where

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          or as given by :attr:`output_size` in the call operator

    Example::

        >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
        >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True)
        >>> unpool = nn.MaxUnpool1d(2, stride=2)
        >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices)
        tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])

        >>> # Example showcasing the use of output_size
        >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices, output_size=input.size())
        tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8., 0.]]])

        >>> unpool(output, indices)
        tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])
    """

    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t

    def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0) -> None:
        super().__init__()
        # An omitted stride defaults to the window size.
        if stride is None:
            stride = kernel_size
        # Each setting is normalized through ``_single`` before being stored.
        self.kernel_size = _single(kernel_size)
        self.stride = _single(stride)
        self.padding = _single(padding)

    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        # Delegates to the functional op with this module's stored settings.
        return F.max_unpool1d(
            input,
            indices,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            output_size=output_size,
        )
class MaxUnpool2d(_MaxUnpoolNd):
r"""Computes a partial inverse of :class:`MaxPool2d`.
:class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.
:class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d`
including the indices of the maximal values and computes a partial inverse
in which all non-maximal values are set to zero.
Note:
This operation may behave nondeterministically when the input indices has repeat values.
See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.
.. note:: :class:`MaxPool2d` can map several input sizes to the same output
sizes. Hence, the inversion process can get ambiguous.
Loading ...