"""
This module implements observers which are used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""
import re
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from functools import partial
from typing import Any, List, Tuple, Optional, Dict
import torch
import torch.nn as nn
from torch.ao.quantization.utils import (
check_min_max_valid, calculate_qmin_qmax, is_per_tensor, is_per_channel, validate_qmin_qmax)
# Names exported by ``from <this module> import *``.
# The list is ordered with the lowercase names first (alphabetical), followed
# by the CamelCase observer class names (alphabetical).
__all__ = [
"default_affine_fixed_qparams_observer",
"default_debug_observer",
"default_dynamic_quant_observer",
"default_fixed_qparams_range_0to1_observer",
"default_fixed_qparams_range_neg1to1_observer",
"default_float_qparams_observer",
"default_float_qparams_observer_4bit",
"default_histogram_observer",
"default_observer",
"default_per_channel_weight_observer",
"default_placeholder_observer",
"default_reuse_input_observer",
"default_symmetric_fixed_qparams_observer",
"default_weight_observer",
"get_observer_state_dict",
"load_observer_state_dict",
"per_channel_weight_observer_range_neg_127_to_127",
"weight_observer_range_neg_127_to_127",
"FixedQParamsObserver",
"HistogramObserver",
"MinMaxObserver",
"MovingAverageMinMaxObserver",
"MovingAveragePerChannelMinMaxObserver",
"NoopObserver",
"ObserverBase",
"PerChannelMinMaxObserver",
"PlaceholderObserver",
"RecordingObserver",
"ReuseInputObserver",
"UniformQuantizationObserverBase",
]
class _PartialWrapper:
    """Callable factory around a pre-bound constructor.

    ``p`` is the wrapped callable (a ``functools.partial`` when built through
    ``_with_args``). ``callable_args`` maps keyword names to zero-argument
    callables; each one is invoked at call time to produce the value for its
    keyword, unless the caller already supplied that keyword explicitly.
    """

    def __init__(self, p):
        self.p = p
        self.callable_args = {}

    def __call__(self, *args, **keywords):
        # Evaluate every deferred argument the caller did not override, then
        # forward everything to the wrapped callable. Explicit keywords win so
        # a deferred argument can always be overwritten at call time.
        deferred = {
            name: factory()
            for name, factory in self.callable_args.items()
            if name not in keywords
        }
        return self.p(*args, **{**keywords, **deferred})

    def __repr__(self):
        wrapped_repr = self.p.__repr__()
        deferred_repr = self.callable_args.__repr__()
        return wrapped_repr + deferred_repr

    def with_args(self, **kwargs):
        """Return a new wrapper with ``kwargs`` additionally pre-bound."""
        return _with_args(self, **kwargs)

    def with_callable_args(self, **kwargs):
        """Return a new wrapper around the same callable with extra deferred args.

        Entries in ``kwargs`` take precedence over deferred arguments of the
        same name already registered on this wrapper; this wrapper itself is
        left unmodified.
        """
        clone = _PartialWrapper(p=self.p)
        merged = dict(self.callable_args)
        merged.update(kwargs)
        clone.callable_args = merged
        return clone
def _with_args(cls_or_self, **kwargs):
    r"""Build a reusable factory with constructor keyword arguments pre-bound.

    The keyword arguments are frozen into a ``functools.partial`` around
    ``cls_or_self`` and that partial is wrapped in a ``_PartialWrapper``.
    Useful when several distinct instances must be created with the same
    constructor arguments. Can be used in conjunction with
    ``_with_callable_args``.

    Example::

        >>> # xdoctest: +SKIP("Undefined vars")
        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """
    bound_ctor = partial(cls_or_self, **kwargs)
    return _PartialWrapper(bound_ctor)
def _with_callable_args(cls_or_self, **kwargs):
    r"""Build a reusable factory whose keyword arguments are computed lazily.

    Each value in ``kwargs`` is a callable that is only invoked when the
    returned factory is itself called, so the argument is calculated at
    construction time of every instance rather than once up front. Can be
    used in conjunction with ``_with_args``.

    Example::

        >>> # xdoctest: +SKIP("Undefined vars")
        >>> Foo.with_callable_args = classmethod(_with_callable_args)
        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan")
        >>> foo_instance1 = foo_builder()
        >>> # wait 50
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time)
        False
    """
    # Start from a wrapper with nothing pre-bound, then attach the deferred args.
    bare_factory = _PartialWrapper(partial(cls_or_self))
    return bare_factory.with_callable_args(**kwargs)
# Abstract base built through an explicit ABCMeta call, a spelling that is
# compatible with Python 2 *and* 3.
ABC: Any = ABCMeta("ABC", (object,), {})
class ObserverBase(ABC, nn.Module):
    r"""Abstract root of all observer modules.

    Every observer implementation should derive from this class and follow the
    same API: ``forward`` updates the statistics of the observed Tensor, and
    ``calculate_qparams`` computes the quantization parameters from the
    statistics collected so far.

    Args:
        dtype: dtype argument to the `quantize` node needed to implement the
               reference model spec.
    """

    # Factory helpers: build a callable that constructs observers with some
    # constructor arguments fixed up front or computed at construction time.
    with_args = classmethod(_with_args)
    with_callable_args = classmethod(_with_callable_args)

    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    @abstractmethod
    def forward(self, x):
        """Observe ``x``; concrete observers update their statistics here."""
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        """Compute quantization parameters from the collected statistics."""
        pass
class UniformQuantizationObserverBase(ObserverBase):
r"""Common base for all observers using uniform quantization to calculate
scale and zero_point.
Args:
dtype: dtype argument to the `quantize` node needed to implement the
reference model spec.
qscheme: Quantization scheme to be used.
reduce_range: Reduces the range of the quantized data type by 1 bit.
This is sometimes required to avoid instruction overflow.
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
.. warning::
:attr:`dtype` can only take ``torch.qint8``, ``torch.quint8``, ``torch.quint4x2`` or ``torch.qint32``.
.. warning::
:attr:`qscheme` can only take one of the following options:
- ``torch.per_tensor_affine``
- ``torch.per_tensor_symmetric``
- ``torch.per_channel_affine``
- ``torch.per_channel_symmetric``
- ``torch.per_channel_affine_float_qparams``
"""
# Note: the version is shared by all observer types
#
# Version 1/None
# self
#
# Version 2 (base class only, does not include child class buffers)
# self
# |--- eps : Tensor
#
# Version 3
# for HistogramObserver only, changed the shape of uninitialized
# min_val and max_val buffers from torch.Size([0]) to torch.Size([])
# for PerChannelObservers, changed the name of the buffers from min_vals
# to min_val and from max_vals to max_val.
_version = 3
eps: torch.Tensor
def __init__(
    self,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
    quant_min=None,
    quant_max=None,
    factory_kwargs=None,
    eps=torch.finfo(torch.float32).eps,
) -> None:
    """Validate the quantization configuration and set up the observer state.

    Args:
        dtype: Quantized dtype; must be one of ``torch.qint8``,
            ``torch.quint8``, ``torch.quint4x2`` or ``torch.qint32``.
        qscheme: Quantization scheme; must be one of the per-tensor /
            per-channel affine or symmetric schemes, or
            ``torch.per_channel_affine_float_qparams``.
        reduce_range: Reduces the range of the quantized data type by 1 bit.
            A truthy value emits a deprecation-style warning.
        quant_min: Minimum quantization value of a customized range.
        quant_max: Maximum quantization value of a customized range.
        factory_kwargs: Tensor factory keyword arguments used when creating
            the ``eps`` buffer.
        eps: Epsilon value stored in the ``eps`` buffer.

    Raises:
        AssertionError: If ``qscheme`` or ``dtype`` is not supported.
    """
    factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
    super().__init__(dtype=dtype)
    self.qscheme = qscheme
    if reduce_range:
        # Adjacent string literals are used instead of a backslash
        # continuation so that source indentation can never leak into the
        # message text.
        warnings.warn(
            "Please use quant_min and quant_max to specify the range for observers. "
            "reduce_range will be deprecated in a future release of PyTorch."
        )
    self.reduce_range = reduce_range
    # eps is kept as a one-element buffer so it is part of the state dict.
    self.register_buffer(
        "eps", torch.tensor([eps], **factory_kwargs)
    )
    assert self.qscheme in (
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
        torch.per_channel_affine,
        torch.per_channel_symmetric,
        torch.per_channel_affine_float_qparams,
    ), (
        "Default Observer only works for per_tensor_affine, "
        "per_tensor_symmetric, per_channel_affine, "
        "per_channel_symmetric and per_channel_affine_float_qparams quantization scheme"
    )
    assert self.dtype in (
        torch.qint8,
        torch.quint8,
        torch.quint4x2,
        torch.qint32,
    ), "Default Observer only works for qint8, quint8, quint4x2 and qint32 data type"
    # A customized range is only honoured when BOTH bounds are supplied; a
    # lone quant_min or quant_max is treated as "no customized range".
    self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
    if self.has_customized_qrange:
        validate_qmin_qmax(quant_min, quant_max)
    self.quant_min, self.quant_max = calculate_qmin_qmax(
        quant_min,
        quant_max,
        self.has_customized_qrange,
        self.dtype,
        self.reduce_range,
    )
def _load_from_state_dict(
    self,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Load observer state, back-filling ``eps`` for pre-version-2 checkpoints.

    Checkpoints whose metadata carries no version, or version 1, were saved
    before ``eps`` became a buffer. For those, a default ``eps`` entry is
    inserted into ``state_dict`` (mutating it in place) before the load is
    delegated to the parent implementation.
    """
    saved_version = local_metadata.get("version", None)
    predates_eps_buffer = saved_version is None or saved_version == 1
    if predates_eps_buffer:
        # eps was moved to a buffer in version 2
        default_eps = torch.finfo(torch.float32).eps
        state_dict[prefix + "eps"] = torch.tensor([default_eps])
    super()._load_from_state_dict(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    )
@torch.jit.export
def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
    r"""Validates that a user-specified quantization range is well formed.

    Two conditions are checked: the range must contain 0, and ``quant_min``
    must be strictly less than ``quant_max``. No check against the bounds of
    the observer dtype is performed here.

    To accommodate lower-bit quantization with respect to the existing torch.qint8 and
    torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
    in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
    values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
    fake quantization. These estimates are compared against parameters learned through backpropagation.
    The related literatures for scale and zero point via backpropagation are as follows:

    Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
    Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf

    Args:
        quant_min: Lower bound of the user-specified quantization range.
        quant_max: Upper bound of the user-specified quantization range.

    Raises:
        AssertionError: If 0 lies outside ``[quant_min, quant_max]`` or if
            ``quant_min`` is not strictly less than ``quant_max``.
    """
    # Checked first: 0 must lie inside the range. The strict-ordering check
    # below then only rejects the degenerate range quant_min == quant_max == 0.
    assert (
        quant_min <= 0 <= quant_max
    ), "User-specified quantization range must include 0."
    assert (
        quant_min < quant_max
    ), "qmin must be strictly less than qmax for user-specified quantization range."
@torch.jit.export
def _calculate_qparams(
self, min_val: torch.Tensor, max_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Calculates the quantization parameters, given min and max
value tensors. Works for both per tensor and per channel cases
Args:
min_val: Minimum values per channel
max_val: Maximum values per channel
Returns:
scales: Scales tensor of shape (#channels,)
zero_points: Zero points tensor of shape (#channels,)
"""
# Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable, however, and qscheme
# as far as I can tell is not allowed to be passed as a parameter in torchscript functions. This makes refactoring observer
# to use this utility a massive pain and very gross. For now I'm opting just to duplicate, as this code
# seems unlikely to change (last update over 1 year ago), and when torchscript is fully deprecated we can refactor.
# TODO(jakeszwe, jerryzh168)
if not check_min_max_valid(min_val, max_val):
return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
quant_min, quant_max = self.quant_min, self.quant_max
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
device = min_val_neg.device
scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
if (
self.qscheme == torch.per_tensor_symmetric
or self.qscheme == torch.per_channel_symmetric
):
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = torch.max(scale, self.eps)
if self.dtype == torch.quint8:
if self.has_customized_qrange:
# When customized quantization range is used, down-rounded midpoint of the range is chosen.
zero_point = zero_point.new_full(
zero_point.size(), (quant_min + quant_max) // 2
)
else:
zero_point = zero_point.new_full(zero_point.size(), 128)
elif self.qscheme == torch.per_channel_affine_float_qparams:
scale = (max_val - min_val) / float(quant_max - quant_min)
scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
# We use the quantize function
# xq = Round(Xf * inv_scale + zero_point),
Loading ...