r""" Functional interface (quantized)."""
from typing import List, Optional
import warnings

import torch
from torch import Tensor
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
from torch.nn.modules.utils import _pair, _triple
from torch.nn.quantized.modules.utils import _pair_from_first
# Although some of the functions and docstrings are mirrored from the torch.nn,
# we want to have them here for future changes.
def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
               count_include_pad=True, divisor_override=None):
    r"""
    Average-pools a quantized tensor over :math:`kH \times kW` windows that
    advance in :math:`sH \times sW` steps. The output has as many feature
    planes as the input.

    .. note:: The input quantization parameters propagate to the output.

    See :class:`~torch.nn.quantized.AvgPool2d` for details and output shape.

    Args:
        input: quantized input tensor
            :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
        kernel_size: size of the pooling window, either one number or a
            tuple `(kH, kW)`
        stride: step of the pooling window, either one number or a tuple
            `(sH, sW)`. Default: :attr:`kernel_size`
        padding: implicit zero padding added on both sides of the input,
            either one number or a tuple `(padH, padW)`. Default: 0
        ceil_mode: when True, the output shape is computed with `ceil`
            rather than `floor`. Default: ``False``
        count_include_pad: when True, the zero padding takes part in the
            averaging. Default: ``True``
        divisor_override: when given, replaces the pooling-window size as
            the divisor. Default: None

    Raises:
        ValueError: if ``input`` is not a quantized tensor.
    """
    if input.is_quantized:
        # The quantized tensor is handed to the regular functional unchanged.
        return torch.nn.functional.avg_pool2d(
            input, kernel_size, stride, padding,
            ceil_mode, count_include_pad, divisor_override)
    raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!")
def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
               count_include_pad=True, divisor_override=None):
    r"""
    Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step size
    :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of
    input planes.

    .. note:: The input quantization parameters propagate to the output.

    Args:
        input: quantized input tensor
            :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
        kernel_size: size of the pooling region. Can be a single number or a
            tuple `(kD, kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
            tuple `(sD, sH, sW)`. Default: :attr:`kernel_size`
        padding: implicit zero paddings on both sides of the input. Can be a
            single number or a tuple `(padD, padH, padW)`. Default: 0
        ceil_mode: when True, will use `ceil` instead of `floor` in the formula
            to compute the output shape. Default: ``False``
        count_include_pad: when True, will include the zero-padding in the
            averaging calculation. Default: ``True``
        divisor_override: if specified, it will be used as divisor, otherwise
            size of the pooling region will be used. Default: None

    Raises:
        ValueError: if ``input`` is not a quantized tensor.
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!")
    # The quantized tensor is forwarded as-is to the regular functional.
    return torch.nn.functional.avg_pool3d(input, kernel_size, stride, padding,
                                          ceil_mode, count_include_pad,
                                          divisor_override)
def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    r"""
    Runs 2D adaptive average pooling on a quantized input signal made up of
    several quantized input planes.

    .. note:: The input quantization parameters propagate to the output.

    See :class:`~torch.nn.quantized.AdaptiveAvgPool2d` for details and output shape.

    Args:
        input: quantized input tensor
        output_size: the target output size (single integer or
            double-integer tuple)

    Raises:
        ValueError: if ``input`` is not a quantized tensor.
    """
    if input.is_quantized:
        # Delegate to the regular functional with the quantized tensor.
        return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
    raise ValueError("Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!")
def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor:
    r"""
    Applies a 3D adaptive average pooling over a quantized input signal composed
    of several quantized input planes.

    .. note:: The input quantization parameters propagate to the output.

    See :class:`~torch.nn.quantized.AdaptiveAvgPool3d` for details and output shape.

    Args:
        input: quantized input tensor
        output_size: the target output size (single integer or
            triple-integer tuple)

    Raises:
        ValueError: if ``input`` is not a quantized tensor.
    """
    # NOTE: ``output_size`` is annotated with BroadcastingList3 (not 2) because
    # a 3D pooling output has three spatial extents.
    if not input.is_quantized:
        raise ValueError(
            "Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!")
    return torch.nn.functional.adaptive_avg_pool3d(input, output_size)
def conv1d(input, weight, bias,
           stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros',
           scale=1.0, zero_point=0,
           dtype=torch.quint8):
    r"""
    Applies a 1D convolution over a quantized 1D input composed of several input
    planes.

    See :class:`~torch.nn.quantized.Conv1d` for details and output shape.

    Args:
        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: the stride of the convolving kernel. Can be a single number or a
            tuple `(sW,)`. Default: 1
        padding: implicit paddings on both sides of the input. Can be a
            single number or a tuple `(padW,)`. Default: 0
        dilation: the spacing between kernel elements. Can be a single number or
            a tuple `(dW,)`. Default: 1
        groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
            number of groups. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type. Accepted for interface parity but not
            read by this function. Default: ``torch.quint8``

    Raises:
        NotImplementedError: if ``padding_mode`` is not ``'zeros'``, if
            ``input.dtype`` is not ``torch.quint8``, or if ``weight.dtype``
            is not ``torch.qint8``.
        ValueError: if ``input`` is not 3-dimensional.

    Examples::

        >>> from torch.nn.quantized import functional as qF
        >>> filters = torch.randn(33, 16, 3, dtype=torch.float)
        >>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
        >>> bias = torch.randn(33, dtype=torch.float)
        >>>
        >>> scale, zero_point = 1.0, 0
        >>> dtype_inputs = torch.quint8
        >>> dtype_filters = torch.qint8
        >>>
        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
        >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
    """  # noqa: E501
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.dtype != torch.quint8:
        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
    if weight.dtype != torch.qint8:
        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
    if input.ndim != 3:
        raise ValueError("Input shape must be `(N, C, L)`!")
    # NOTE(review): _pair_from_first comes from torch.nn.quantized.modules.utils;
    # presumably it expands the 1D setting into the pair form that the prepack
    # op expects -- confirm against the helper.
    stride = _pair_from_first(stride)
    padding = _pair_from_first(padding)
    dilation = _pair_from_first(dilation)

    # Weights are packed on every call before running the quantized op.
    packed_params = torch.ops.quantized.conv1d_prepack(
        weight, bias, stride, padding, dilation, groups)
    return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point)
def conv2d(input, weight, bias,
           stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros',
           scale=1.0, zero_point=0,
           dtype=torch.quint8):
    r"""
    Runs a 2D convolution over a quantized 2D input made up of several input
    planes.

    See :class:`~torch.nn.quantized.Conv2d` for details and output shape.

    Args:
        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: stride of the convolving kernel, either one number or a
            tuple `(sH, sW)`. Default: 1
        padding: implicit paddings on both sides of the input, either one
            number or a tuple `(padH, padW)`. Default: 0
        dilation: spacing between kernel elements, either one number or a
            tuple `(dH, dW)`. Default: 1
        groups: number of groups the input is split into; :math:`\text{in\_channels}`
            should be divisible by it. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type. Part of the signature but not read
            by this function. Default: ``torch.quint8``

    Raises:
        NotImplementedError: if ``padding_mode`` is not ``'zeros'``, if
            ``input.dtype`` is not ``torch.quint8``, or if ``weight.dtype``
            is not ``torch.qint8``.
        ValueError: if ``input`` is not 4-dimensional.

    Examples::

        >>> from torch.nn.quantized import functional as qF
        >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
        >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
        >>> bias = torch.randn(8, dtype=torch.float)
        >>>
        >>> scale, zero_point = 1.0, 0
        >>> dtype_inputs = torch.quint8
        >>> dtype_filters = torch.qint8
        >>>
        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
        >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
    """  # noqa: E501
    # Validation order is significant: it decides which error a caller sees.
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.dtype != torch.quint8:
        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
    if weight.dtype != torch.qint8:
        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
    if input.ndim != 4:
        raise ValueError("Input shape must be `(N, C, H, W)`!")

    # Normalize stride, padding and dilation (in that order) with _pair.
    geometry = tuple(_pair(setting) for setting in (stride, padding, dilation))

    # Weights are packed on every call before the quantized op runs.
    prepacked = torch.ops.quantized.conv2d_prepack(weight, bias, *geometry, groups)
    return torch.ops.quantized.conv2d(input, prepacked, scale, zero_point)
def conv3d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros', scale=1.0, zero_point=0, dtype=torch.quint8):
    r"""
    Runs a 3D convolution over a quantized 3D input made up of several input
    planes.

    See :class:`~torch.nn.quantized.Conv3d` for details and output shape.

    Args:
        input: quantized input tensor of shape
            :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
        weight: quantized filters of shape
            :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)`
        bias: **non-quantized** bias tensor of shape
            :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: stride of the convolving kernel, either one number or a
            tuple `(sD, sH, sW)`. Default: 1
        padding: implicit paddings on both sides of the input, either one
            number or a tuple `(padD, padH, padW)`. Default: 0
        dilation: spacing between kernel elements, either one number or a
            tuple `(dD, dH, dW)`. Default: 1
        groups: number of groups the input is split into;
            :math:`\text{in\_channels}` should be divisible by it. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for
            quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type. Part of the signature but not read
            by this function. Default: ``torch.quint8``

    Raises:
        NotImplementedError: if ``padding_mode`` is not ``'zeros'``, if
            ``input.dtype`` is not ``torch.quint8``, or if ``weight.dtype``
            is not ``torch.qint8``.
        ValueError: if ``input`` is not 5-dimensional.

    Examples::

        >>> from torch.nn.quantized import functional as qF
        >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
        >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)
        >>> bias = torch.randn(8, dtype=torch.float)
        >>>
        >>> scale, zero_point = 1.0, 0
        >>> dtype_inputs = torch.quint8
        >>> dtype_filters = torch.qint8
        >>>
        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
        >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
    """  # noqa: E501
    # Validation order is significant: it decides which error a caller sees.
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.dtype != torch.quint8:
        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
    if weight.dtype != torch.qint8:
        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
    if input.ndim != 5:
        raise ValueError("Input shape must be `(N, C, D, H, W)`!")

    # Normalize stride, padding and dilation (in that order) with _triple.
    geometry = tuple(_triple(setting) for setting in (stride, padding, dilation))

    # Weights are packed on every call before the quantized op runs.
    prepacked = torch.ops.quantized.conv3d_prepack(weight, bias, *geometry, groups)
    return torch.ops.quantized.conv3d(input, prepacked, scale, zero_point)
def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
    r"""Resamples a quantized input, down or up, to the given :attr:`size`
    or by the given :attr:`scale_factor`.

    See :func:`torch.nn.functional.interpolate` for implementation details.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    .. note:: The input quantization parameters propagate to the output.

    .. note:: Only 2D/3D input is supported for quantized inputs

    .. note:: Only the following modes are supported for the quantized inputs:

        - `bilinear`
        - `nearest`

    Args:
        input (Tensor): the quantized input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size.
            Has to match input size if it is a tuple.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'bilinear'``
        align_corners (bool, optional): Geometrically, the pixels of the
            input and output are treated as squares rather than points.
            With ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, which preserves the values
            at the corner pixels. With ``False``, they are aligned by the
            corner points of their corner pixels and out-of-boundary values
            use edge value padding, which makes the operation *independent*
            of input size when :attr:`scale_factor` is kept the same. Only
            has an effect when :attr:`mode` is ``'bilinear'``.
            The value (``None`` when omitted) is passed through to
            :func:`torch.nn.functional.interpolate`.

    Raises:
        ValueError: if ``input`` is not a quantized tensor.
    """
    if input.is_quantized:
        # Delegate to the regular functional with the quantized tensor.
        return torch.nn.functional.interpolate(
            input, size, scale_factor, mode, align_corners)
    raise ValueError("Input to 'quantized.interpolate' must be quantized!")
def linear(
input: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
scale: Optional[float] = None, zero_point: Optional[int] = None
) -> Tensor:
r"""
Applies a linear transformation to the incoming quantized data:
:math:`y = xA^T + b`.
See :class:`~torch.nn.quantized.Linear`
.. note::
Current implementation packs weights on every call, which has penalty on performance.
If you want to avoid the overhead, use :class:`~torch.nn.quantized.Linear`.
Args:
input (Tensor): Quantized input of type `torch.quint8`
Loading ...