import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from typing import Any, List, Tuple, Optional, Dict, Union
from collections import OrderedDict
import torch
import torch.nn as nn
import re


def _with_args(cls_or_self, **kwargs):
r"""Wrapper that allows creation of class factories.
This can be useful when there is a need to create classes with the same
constructor arguments, but different instances.
Example::
>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False
"""
class _PartialWrapper(object):
def __init__(self, p):
self.p = p
def __call__(self, *args, **keywords):
return self.p(*args, **keywords)
def __repr__(self):
return self.p.__repr__()
with_args = _with_args
r = _PartialWrapper(partial(cls_or_self, **kwargs))
    return r


ABC: Any = ABCMeta(str("ABC"), (object,), {})  # compatible with Python 2 *and* 3


class ObserverBase(ABC, nn.Module):
r"""Base observer Module.
Any observer implementation should derive from this class.
    Concrete observers should follow the same API. In ``forward``, they update
    the statistics of the observed Tensor, and they should provide a
    ``calculate_qparams`` method that computes the quantization parameters given
    the collected statistics.
Args:
dtype: Quantized data type
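    Example (a minimal usage sketch; ``MinMaxObserver`` is a concrete subclass
    defined later in this module)::
        >>> obs = MinMaxObserver(dtype=torch.quint8)
        >>> _ = obs(torch.randn(2, 3))  # forward() records min/max statistics
        >>> scale, zero_point = obs.calculate_qparams()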
"""
def __init__(self, dtype):
super(ObserverBase, self).__init__()
self.dtype = dtype
@abstractmethod
def forward(self, x):
pass
@abstractmethod
def calculate_qparams(self, **kwargs):
pass
    with_args = classmethod(_with_args)


class _ObserverBase(ObserverBase):
r"""Internal common base for all qint/quint8 observers.
    This base class holds parameters that are common to the internal observers.
Users should use `~torch.quantization.observer.ObserverBase` as a base class
for custom observers.
Args:
dtype: Quantized data type.
qscheme: Quantization scheme to be used.
reduce_range: Reduces the range of the quantized data type by 1 bit.
This is sometimes required to avoid instruction overflow.
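            For example, with ``torch.quint8`` the quantized range shrinks from
            ``[0, 255]`` to ``[0, 127]``.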
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
.. warning::
:attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
.. warning::
:attr:`qscheme` can only take one of the following options:
- ``torch.per_tensor_affine``
- ``torch.per_tensor_symmetric``
- ``torch.per_channel_affine``
- ``torch.per_channel_symmetric``
"""
# Note: the version is shared by all observer types
#
# Version 1/None
# self
#
# Version 2 (base class only, does not include child class buffers)
# self
# |--- eps : Tensor
#
# Version 3
# for HistogramObserver only, changed the shape of uninitialized
# min_val and max_val buffers from torch.Size([0]) to torch.Size([])
    _version = 3
eps: torch.Tensor
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None):
super(_ObserverBase, self).__init__(dtype=dtype)
self.qscheme = qscheme
if reduce_range:
warnings.warn(
"Please use quant_min and quant_max to specify the range for observers. \
reduce_range will be deprecated in a future release of PyTorch."
)
self.reduce_range = reduce_range
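        # Machine epsilon for float32, registered as a buffer; used later as a lower
        # bound on computed scales so a degenerate min/max range never yields scale == 0.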
self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
assert self.qscheme in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
), "Default Observer only works for per_tensor_affine, \
per_tensor_symmetric, per_channel_affine, \
per_channel_symmetric and per_channel_float_qparams quantization scheme"
assert self.dtype in (
torch.qint8,
torch.quint8,
torch.quint4x2,
), "Default Observer only works for qint8, quint8 and quint4x2 data type"
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
self._validate_qmin_qmax(quant_min, quant_max)
self.quant_min = quant_min
self.quant_max = quant_max
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version == 1:
# eps was moved to a buffer in version 2
eps = torch.tensor([torch.finfo(torch.float32).eps])
state_dict[prefix + 'eps'] = eps
super(ObserverBase, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
@torch.jit.export
def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
r"""Validates that the user-specified quantization range is properly initialized
and within the given bound supported by the observer dtype.
        To accommodate lower-bit quantization with respect to the existing torch.qint8 and
        torch.quint8 datatypes, the user can choose a dynamic quantization range by passing
        in a tuple of initial qmin and qmax values. One use case is computing static estimates
        of the scale and zero point from these customized qmin and qmax values for aggressive
        lower-bit fake quantization. These estimates are compared against parameters learned
        through backpropagation. Related literature on learning scale and zero point via
        backpropagation:
        Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
        Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
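        For example, ``quant_min=-8, quant_max=7`` is a valid 4-bit signed range, while
        ``quant_min=1, quant_max=7`` is rejected because the range must include 0.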
"""
        # The values are only validated here; _calculate_qmin_qmax may later adjust them
        # based on whether the quantization range is reduced and the datatype
        # (signed/unsigned) used by the observer.
        assert quant_min <= 0 <= quant_max, "User-specified quantization range must include 0."
assert quant_min < quant_max, "qmin must be strictly less than qmax for user-specified quantization range."
@torch.jit.export
def _calculate_qmin_qmax(self) -> Tuple[int, int]:
r"""Calculates actual qmin and qmax based on the quantization range,
observer datatype and if range is reduced.
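        For example, ``dtype=torch.qint8`` yields ``(-128, 127)``, or ``(-64, 63)``
        when ``reduce_range`` is set.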
"""
if self.has_customized_qrange:
            # This initialization is done to resolve TorchScript compilation issues and to
            # allow type refinement to decouple initial_quant_min and initial_quant_max
            # from the quantization range. Their actual values are reset below.
initial_quant_min, initial_quant_max = 0, 255
            # Assigning self.quant_min and self.quant_max to local variables and checking them
            # against None refines their types from Optional[int] to int, as TorchScript requires.
custom_quant_min, custom_quant_max = self.quant_min, self.quant_max
if custom_quant_min is not None and custom_quant_max is not None:
initial_quant_min, initial_quant_max = custom_quant_min, custom_quant_max
qrange_len = initial_quant_max - initial_quant_min + 1
assert 0 < qrange_len <= 256, \
"quantization range should be positive and not exceed the maximum bit range (=256)."
if self.dtype == torch.qint8:
quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1
else:
quant_min, quant_max = 0, qrange_len - 1
if self.reduce_range:
quant_min, quant_max = quant_min // 2, quant_max // 2
else:
            # Fall back to the default 8-bit qmin and qmax if no customized range is used.
if self.dtype == torch.qint8:
if self.reduce_range:
quant_min, quant_max = -64, 63
else:
quant_min, quant_max = -128, 127
elif self.dtype == torch.quint8:
if self.reduce_range:
quant_min, quant_max = 0, 127
else:
quant_min, quant_max = 0, 255
            else:  # torch.quint4x2 (4-bit)
                quant_min, quant_max = 0, 15
return quant_min, quant_max
@torch.jit.export
def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Calculates the quantization parameters, given min and max
value tensors. Works for both per tensor and per channel cases
Args:
min_val: Minimum values per channel
max_val: Maximum values per channel
Returns:
scales: Scales tensor of shape (#channels,)
zero_points: Zero points tensor of shape (#channels,)
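        For example, per-tensor affine quantization of ``min_val=-1.0``, ``max_val=1.0``
        to ``torch.quint8`` yields ``scale = 2 / 255`` and ``zero_point = 128``.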
"""
if min_val.numel() == 0 or max_val.numel() == 0:
warnings.warn(
"must run observer before calling calculate_qparams.\
Returning default scale and zero point "
)
return torch.tensor([1.0]), torch.tensor([0])
if min_val.dim() == 0 or max_val.dim() == 0:
if min_val == float('inf') and max_val == float('-inf'):
warnings.warn(
"must run observer before calling calculate_qparams.\
Returning default scale and zero point "
)
return torch.tensor([1.0]), torch.tensor([0])
            assert min_val <= max_val, "min {} should be less than or equal to max {}".format(
min_val, max_val
)
else:
            assert torch.all(min_val <= max_val), "min {} should be less than or equal to max {}".format(
min_val, max_val
)
quant_min, quant_max = self._calculate_qmin_qmax()
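        # Extend the observed range to include 0 so that the real value 0.0 is always
        # exactly representable by a quantized value.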
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
device = min_val_neg.device
scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
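            # Symmetric quantization: center the range around zero; the zero point is
            # fixed (0 for qint8 by construction; the range midpoint for quint8).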
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = torch.max(scale, self.eps)
if self.dtype == torch.quint8:
if self.has_customized_qrange:
# When customized quantization range is used, down-rounded midpoint of the range is chosen.
zero_point = zero_point.new_full(zero_point.size(), (quant_min + quant_max) // 2)
else:
zero_point = zero_point.new_full(zero_point.size(), 128)
elif self.qscheme == torch.per_channel_affine_float_qparams:
scale = (max_val - min_val) / float(quant_max - quant_min)
scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
            # We use the quantize function
            #     x_q = round(x_f * inv_scale + zero_point);
            # setting zero_point to (-1 * min_val * inv_scale) gives
            #     x_q = round((x_f - min_val) * inv_scale)
zero_point = -1 * min_val / scale
else:
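            # Affine: spread the full observed range across [quant_min, quant_max]
            # and shift so that min_val_neg maps to quant_min.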
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.max(scale, self.eps)
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
# For scalar values, cast them to Tensors of size 1 to keep the shape
# consistent with default values in FakeQuantize.
if len(scale.shape) == 0:
# TODO: switch to scale.item() after adding JIT support
scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
if len(zero_point.shape) == 0:
# TODO: switch to zero_point.item() after adding JIT support
zero_point = torch.tensor([int(zero_point)], dtype=zero_point.dtype, device=device)
if self.qscheme == torch.per_channel_affine_float_qparams:
zero_point = torch.tensor([float(zero_point)], dtype=zero_point.dtype, device=device)
        return scale, zero_point


class MinMaxObserver(_ObserverBase):
r"""Observer module for computing the quantization parameters based on the
running min and max values.
This observer uses the tensor min/max statistics to compute the quantization
parameters. The module records the running minimum and maximum of incoming
    tensors, and uses these statistics to compute the quantization parameters.
Args:
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
    The running minimum/maximum :math:`x_\text{min/max}` is computed as:
.. math::
\begin{array}{ll}
x_\text{min} &= \begin{cases}
\min(X) & \text{if~}x_\text{min} = \text{None} \\
\min\left(x_\text{min}, \min(X)\right) & \text{otherwise}
\end{cases}\\
x_\text{max} &= \begin{cases}
\max(X) & \text{if~}x_\text{max} = \text{None} \\
\max\left(x_\text{max}, \max(X)\right) & \text{otherwise}
\end{cases}\\
\end{array}
where :math:`X` is the observed tensor.
The scale :math:`s` and zero point :math:`z` are then computed as:
.. math::
        \begin{aligned}
            \text{if Symmetric:}&\\
            &s = 2 \max(|x_\text{min}|, x_\text{max}) /
                \left( Q_\text{max} - Q_\text{min} \right) \\
            &z = \begin{cases}
                0 & \text{if dtype is qint8} \\
                128 & \text{otherwise}
            \end{cases}\\
            \text{Otherwise:}&\\
            &s = \left( x_\text{max} - x_\text{min} \right) /
                \left( Q_\text{max} - Q_\text{min} \right) \\
            &z = Q_\text{min} - \text{round}(x_\text{min} / s)
        \end{aligned}