"""
Utils shared by different modes of quantization (eager/graph)
"""
import functools
import warnings
from collections import OrderedDict
from inspect import getfullargspec, signature
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
from torch.ao.quantization.quant_type import QuantType
from torch.fx import Node
from torch.nn.utils.parametrize import is_parametrized
# Type alias for node patterns used by FX graph mode quantization. The explicit
# forms are a pair of Nodes or a Node paired with a nested pair of Nodes;
# because ``Any`` is also part of the Union, type checkers accept any value.
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
# Set the alias's ``__module__`` explicitly, presumably so it is attributed to
# this module (rather than ``typing``) in docs/reprs — confirm.
NodePattern.__module__ = "torch.ao.quantization.utils"
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# Define separately to prevent circular imports.
# TODO(future PR): improve this.
# make this public once fixed (can't be public as is because setting the module directly
# doesn't work)
QuantizerCls = Any
# Type for fusion patterns, it can be more complicated than the following actually,
# see pattern.md for docs
# TODO: not sure if typing supports recursive data types
# ``Any`` is included in the Union, so deeper nesting than spelled out here
# still type-checks.
Pattern = Union[
    Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
]
# Same explicit module attribution as NodePattern above.
Pattern.__module__ = "torch.ao.quantization.utils"
# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
    """Wildcard element for fusion patterns in FX Graph Mode Quantization.

    Placed inside a fusion pattern definition, it matches any node.
    """
# nn.Module classes that ``check_node`` matches for ``call_module`` nodes.
module_type_list = {
    # activations
    torch.nn.Hardsigmoid,
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
    # adaptive average pooling
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    # average pooling
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    # max pooling
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    # pass-through
    torch.nn.Identity,
}
# Callables that ``check_node`` matches for ``call_function`` nodes.
func_list = {
    # pooling
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    # activations
    torch.nn.functional.elu,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.mish,
    torch.nn.functional.relu,
    torch.nn.functional.sigmoid,
    torch.nn.functional.silu,
    torch.sigmoid,
    torch.tanh,
    # normalization and dropout
    torch.nn.functional.dropout,
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    # tensor manipulation and reduction
    torch.cat,
    torch.repeat_interleave,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.transpose,
    torch.unsqueeze,
}
# Entries that ``check_node`` compares against ``call_method`` node targets.
method_list = {
    # NOTE(review): torch.mean is a function object, not a method-name string
    # like every other entry here; confirm that a call_method target can ever
    # equal it before relying on this entry.
    torch.mean,
    # shape, layout and size queries
    'contiguous',
    'permute',
    'repeat',
    'repeat_interleave',
    'reshape',
    'resize_',
    'shape',
    'size',
    'squeeze',
    'squeeze_',
    'transpose',
    'unsqueeze',
    'unsqueeze_',
    'view',
    # detach
    'detach',
    'detach_',
    # activations (the trailing-underscore names are the in-place spellings)
    'hardsigmoid',
    'hardsigmoid_',
    'relu',
    'relu_',
    'sigmoid',
    'sigmoid_',
    'tanh',
    'tanh_',
}
# TODO: not used now, remove
def check_node(node, modules):
    """Classify an FX node against ``func_list``, ``method_list`` and ``module_type_list``.

    Args:
        node: the node to inspect; its ``op`` and ``target`` attributes are read.
        modules: mapping from ``str(node.target)`` to module instance; only
            consulted for ``call_module`` nodes.

    Returns:
        ``(is_call_function, is_call_method, is_call_module)`` as three bools.
        At most one can be True, since each requires a different ``node.op``.
    """
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    op = node.op
    target = node.target
    matches_function = op == "call_function" and target in func_list
    matches_method = op == "call_method" and target in method_list
    # Short-circuiting ensures ``modules`` is only indexed for call_module nodes.
    matches_module = (
        op == "call_module"
        and type(modules[str(target)]) in module_type_list
    )
    return matches_function, matches_method, matches_module
def get_combined_dict(default_dict, additional_dict):
    """Return ``default_dict`` overlaid with ``additional_dict``.

    The result is built from ``default_dict.copy()`` (so it keeps whatever type
    that method returns) and then updated in place, so keys from
    ``additional_dict`` win. Neither argument is modified.
    """
    merged = default_dict.copy()
    merged.update(additional_dict)
    return merged
def is_per_tensor(qscheme):
    """Return True if ``qscheme`` is ``torch.per_tensor_affine`` or ``torch.per_tensor_symmetric``."""
    return qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
def is_per_channel(qscheme):
    """Return True if ``qscheme`` is one of the per-channel schemes.

    Recognized schemes: ``per_channel_affine``,
    ``per_channel_affine_float_qparams`` and ``per_channel_symmetric``.
    """
    per_channel_schemes = (
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    )
    return qscheme in per_channel_schemes
def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """Resolve a dotted attribute path on ``obj``.

    Given ``fqn = "foo.bar.baz"``, returns ``obj.foo.bar.baz``. Any missing
    attribute along the path raises ``AttributeError``.
    """
    current = obj
    for attr_name in fqn.split("."):
        current = getattr(current, attr_name)
    return current
def to_underlying_dtype(qdtype):
    """Map a quantized dtype to its underlying integer dtype.

    Mapping:
        ``torch.quint8``   -> ``torch.uint8``
        ``torch.qint8``    -> ``torch.int8``
        ``torch.qint32``   -> ``torch.int32``
        ``torch.quint4x2`` -> ``torch.uint8``
        ``torch.quint2x4`` -> ``torch.uint8``

    Raises:
        AssertionError: if ``qdtype`` is not in the mapping (``KeyError`` if
            assertions are disabled with ``-O``).
    """
    DTYPE_MAPPING = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
        torch.quint4x2: torch.uint8,
        torch.quint2x4: torch.uint8,
    }
    # qdtype is a torch.dtype, so ``"..." + qdtype`` would raise TypeError and
    # hide the intended assertion; format it instead.
    assert qdtype in DTYPE_MAPPING, f"Unsupported dtype: {qdtype}"
    return DTYPE_MAPPING[qdtype]
def get_qparam_dict(observer_or_fake_quant):
    """Collect the quantization parameters of an observer or fake-quant module.

    The returned dict always has ``"qscheme"`` (None if the module has no
    ``qscheme`` attribute) and ``"dtype"``. If ``qscheme`` is falsy, that is
    all it contains. Otherwise:

    * per-tensor schemes are reported as ``torch.per_tensor_affine``;
    * per-channel schemes keep their value, except ``per_channel_symmetric``,
      which becomes ``per_channel_affine``, and ``"axis"`` is set from
      ``ch_axis``;
    * ``"scale"`` and ``"zero_point"`` come from ``calculate_qparams()``.

    Raises:
        RuntimeError: if ``qscheme`` is neither per-tensor nor per-channel.
    """
    qscheme = getattr(observer_or_fake_quant, "qscheme", None)
    qparams = {"qscheme": qscheme, "dtype": observer_or_fake_quant.dtype}
    if not qscheme:
        return qparams

    # Symmetric schemes are reported as affine, since there is no symmetric
    # quantized Tensor.
    if is_per_tensor(qscheme):
        reported_qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        if qscheme == torch.per_channel_symmetric:
            reported_qscheme = torch.per_channel_affine
        else:
            reported_qscheme = qscheme
        qparams["axis"] = observer_or_fake_quant.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")

    qparams["qscheme"] = reported_qscheme
    qparams["scale"], qparams["zero_point"] = observer_or_fake_quant.calculate_qparams()
    return qparams
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """Look up the class that ``custom_module`` should be swapped to.

    Args:
        custom_module: an instance of either a float or an observed custom module.
        custom_module_class_mapping: dict keyed by quant type (the value
            returned by ``get_quant_type``); each value maps float classes to
            observed classes, or observed classes to quantized classes.
        qconfig: the qconfig configured for ``custom_module``.

    Returns:
        The observed/quantized class registered for ``type(custom_module)``.

    Raises:
        AssertionError: if no class is registered for ``type(custom_module)``
            under the qconfig's quant type.
    """
    class_mapping = custom_module_class_mapping.get(get_quant_type(qconfig), {})
    module_cls = type(custom_module)
    assert module_cls in class_mapping, (
        "did not find corresponding observed "
        f"module class for {module_cls} in mapping: {class_mapping}"
    )
    return class_mapping[module_cls]
def activation_dtype(qconfig):
    """Return the ``dtype`` of the object built by ``qconfig.activation()``.

    Raises AssertionError if ``qconfig`` is None.
    """
    assert qconfig is not None
    return qconfig.activation().dtype
def weight_dtype(qconfig):
    """Return the ``dtype`` of the object built by ``qconfig.weight()``.

    Raises AssertionError if ``qconfig`` is None.
    """
    assert qconfig is not None
    return qconfig.weight().dtype
def activation_is_statically_quantized(qconfig):
    """Decide whether the activation described by ``qconfig`` is statically quantized.

    True when the activation dtype is quint8, qint8, qint32 or float16 and the
    activation is not dynamically quantized.
    """
    if activation_dtype(qconfig) not in [torch.quint8, torch.qint8, torch.qint32, torch.float16]:
        return False
    return not activation_is_dynamically_quantized(qconfig)
def activation_is_dynamically_quantized(qconfig):
    """Decide whether the activation described by ``qconfig`` is dynamically quantized.

    Returns the ``activation_is_dynamic`` element of
    ``get_qconfig_dtypes(qconfig)``. The activation dtype itself is not
    inspected here.
    """
    # Only the dynamic flag is needed. Discard the dtypes rather than binding
    # them to names that shadow the module-level ``activation_dtype`` helper.
    _, _, activation_is_dynamic = get_qconfig_dtypes(qconfig)
    return activation_is_dynamic
def activation_is_int8_quantized(qconfig):
    """Return True if the activation dtype of ``qconfig`` is quint8 or qint8."""
    int8_dtypes = (torch.quint8, torch.qint8)
    return activation_dtype(qconfig) in int8_dtypes
def activation_is_int32_quantized(qconfig):
    """Return True if the activation dtype of ``qconfig`` is qint32."""
    act_dtype = activation_dtype(qconfig)
    return act_dtype == torch.qint32
def weight_is_quantized(qconfig):
    """Return True if the weight dtype of ``qconfig`` is quint8, qint8, float16 or quint4x2."""
    quantized_weight_dtypes = (torch.quint8, torch.qint8, torch.float16, torch.quint4x2)
    return weight_dtype(qconfig) in quantized_weight_dtypes
def weight_is_statically_quantized(qconfig):
    """Return True if the weight dtype of ``qconfig`` is quint8 or qint8."""
    static_weight_dtypes = (torch.quint8, torch.qint8)
    return weight_dtype(qconfig) in static_weight_dtypes
def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """Return True if ``qconfig`` describes int8 dynamic quantization.

    That means quint8 activations, qint8 weights and a dynamic activation.
    """
    act_dtype, w_dtype, act_is_dynamic = get_qconfig_dtypes(qconfig)
    # for now, the dtype checks below assume fbgemm or qnnpack
    return act_dtype is torch.quint8 and w_dtype is torch.qint8 and act_is_dynamic
def get_qconfig_dtypes(qconfig):
    r"""Return ``(activation_dtype, weight_dtype, activation_is_dynamic)`` for ``qconfig``.

    ``qconfig.activation()`` and ``qconfig.weight()`` are each called once. An
    activation without an ``is_dynamic`` attribute counts as not dynamic.
    Raises AssertionError if ``qconfig`` is None.
    """
    assert qconfig is not None
    act = qconfig.activation()
    wt = qconfig.weight()
    return act.dtype, wt.dtype, getattr(act, 'is_dynamic', False)
def get_quant_type(qconfig):
    """Classify ``qconfig`` as STATIC, DYNAMIC or WEIGHT_ONLY quantization.

    Builds ``qconfig.activation()`` and ``qconfig.weight()`` and decides from
    their dtypes and the activation's optional ``is_dynamic`` flag:

    * weight dtype in {quint8, qint8, quint4x2, qint32}: DYNAMIC if the
      activation is dynamic, else STATIC if the activation dtype is also in
      that set, else WEIGHT_ONLY.
    * weight dtype float16: DYNAMIC if the activation is dynamic, else STATIC
      if the activation dtype is float16.

    Raises:
        AssertionError: if ``qconfig`` is None.
        RuntimeError: for any other activation/weight dtype combination.
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    # An activation without an ``is_dynamic`` attribute is treated as static.
    act_is_dynamic = getattr(activation, "is_dynamic", False)
    static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]
    if weight.dtype in static_dtypes:
        if act_is_dynamic:
            return QuantType.DYNAMIC
        if activation.dtype in static_dtypes:
            return QuantType.STATIC
        return QuantType.WEIGHT_ONLY
    if weight.dtype == torch.float16:
        if act_is_dynamic:
            return QuantType.DYNAMIC
        if activation.dtype == torch.float16:
            return QuantType.STATIC
    # RuntimeError subclasses Exception, so existing ``except Exception``
    # callers still catch it; it also matches get_qparam_dict's error style.
    raise RuntimeError(
        "Unrecognized dtype combination in get_quant_type: "
        f"activation({activation.dtype}), weight({weight.dtype})"
    )
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """Check that observed min/max statistics are usable.

    Returns False and emits a warning when either tensor is empty, or when
    scalar statistics still hold the sentinel pair (``+inf``, ``-inf``).
    Otherwise asserts ``min_val <= max_val`` (elementwise for non-scalar
    tensors) and returns True.
    """
    not_observed_msg = (
        "must run observer before calling calculate_qparams. "
        "Returning default values."
    )
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(not_observed_msg)
        return False

    is_scalar = min_val.dim() == 0 or max_val.dim() == 0
    if not is_scalar:
        assert torch.all(min_val <= max_val), \
            f"min {min_val} should be less than max {max_val}"
        return True

    if min_val == float("inf") and max_val == float("-inf"):
        warnings.warn(not_observed_msg)
        return False
    assert min_val <= max_val, f"min {min_val} should be less than max {max_val}"
    return True
def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
reduce_range: bool) -> Tuple[int, int]:
r"""Calculates actual qmin and qmax based on the quantization range,
observer datatype and if range is reduced.
"""
# TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
if has_customized_qrange:
# This initialization here is to be resolve TorchScript compilation issues and allow
# using of refinement to decouple initial_qmin and initial_qmax from quantization range.
# The actual values of initial_qmin and initial_qmax will be reset below.
if dtype == torch.qint32:
initial_quant_min, initial_quant_max = 0, 2**31 - 1
else:
initial_quant_min, initial_quant_max = 0, 255
# The following assignment of self.qmin and self.qmax to the local variables and the if check refine the
# attribute from Optional valid integers for use, based on TorchScript's requirements.
custom_quant_min, custom_quant_max = quant_min, quant_max
if custom_quant_min is not None and custom_quant_max is not None:
Loading ...