from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable
from torch.ao.quantization.quant_type import QuantType
import torch
import copy
import warnings
from torch.fx import (
GraphModule,
)
from torch.fx.graph import (
Graph,
Node,
Argument,
)
from ..utils import (
activation_is_statically_quantized,
weight_is_quantized,
get_qparam_dict,
_parent_name,
get_swapped_custom_module_class,
)
from ..qconfig import (
QConfigAny,
qconfig_equals
)
from ..qconfig_mapping import QConfigMapping
from .qconfig_mapping_utils import (
_generate_node_name_to_qconfig,
_compare_prepare_convert_qconfig_mappings,
_update_qconfig_for_fusion,
_is_qconfig_supported_by_dtype_configs,
_update_qconfig_for_qat,
)
from torch.ao.quantization.backend_config.utils import (
get_root_module_to_quantized_reference_module,
get_pattern_to_dtype_configs,
get_fused_module_classes,
get_qat_module_classes,
)
from torch.ao.quantization.backend_config import (
BackendConfig,
get_native_backend_config,
)
from torch.ao.quantization.observer import _is_activation_post_process
from .graph_module import (
_is_observed_module,
_is_observed_standalone_module,
)
from ._equalize import update_obs_for_equalization, convert_eq_obs
from torch.nn.utils.parametrize import type_before_parametrizations
from .utils import (
_get_module,
_is_custom_module_lstm,
get_custom_module_class_keys,
create_getattr_from_value,
collect_producer_nodes,
graph_module_from_producer_nodes,
node_arg_is_weight,
)
from torch.ao.quantization.utils import (
is_per_channel,
to_underlying_dtype,
)
from torch.ao.quantization.quantize import (
_remove_qconfig,
)
from torch.ao.quantization.stubs import DeQuantStub
from .custom_config import (
ConvertCustomConfig,
PrepareCustomConfig,
)
from .lower_to_fbgemm import lower_to_fbgemm
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib # noqa: F401
import operator
__all__ = [
"convert",
"convert_custom_module",
"convert_standalone_module",
"convert_weighted_module",
]
def _replace_observer_with_quantize_dequantize_node_decomposed(
model: torch.nn.Module,
graph: Graph,
node: Node,
modules: Dict[str, torch.nn.Module],
node_name_to_scope: Dict[str, Tuple[str, type]],
node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
""" Replace activation_post_process module call node with quantize and
dequantize node working with decomposed Tensor
Before:
... -> observer_0(x) -> ...
After:
... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...
or quantize_per_channel and dequantize_per_channel
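
    For example, the static per tensor case roughly expands to (a sketch;
    scale and zero_point come from the observer and are registered as
    attributes on the model, the remaining qparams are graph literals):
        x_q = torch.ops.quantized_decomposed.quantize_per_tensor(
            x, scale, zero_point, quant_min, quant_max, dtype)
        x_dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
            x_q, scale, zero_point, quant_min, quant_max, dtype)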
"""
assert modules is not None
assert isinstance(node.target, str)
module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
activation_post_process = modules[node.target]
    # skip replacing the observer with quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
skip_replacement = all([
_has_none_qconfig(n, node_name_to_qconfig) for n in
list(node.args) + list(node.users.keys())])
if skip_replacement or not _is_conversion_supported(activation_post_process):
        # we didn't find a corresponding quantize op and info for the
        # activation_post_process, so we just remove the observer
with graph.inserting_before(node):
node.replace_all_uses_with(node.args[0])
graph.erase_node(node)
return
    # otherwise, we can convert the activation_post_process module call into quantize/dequantize nodes
# 1. extract the information from activation_post_process module for generating
# the quantize and dequantize operator
dtype = activation_post_process.dtype # type: ignore[attr-defined]
is_dynamic = False
if hasattr(activation_post_process, "is_dynamic"):
is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment]
if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
(not is_dynamic):
        # TODO: probably should clean up this condition check, it's hard
        # to reason about this if and the following elif
        # uint8/int8/int32 static quantization branch
# 1. extract information for inserting q/dq node from activation_post_process
node_type = "call_function"
        quantize_op: Optional[Callable] = None
scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator]
if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type]
quantize_op = torch.ops.quantized_decomposed.quantize_per_channel
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel
quant_min = activation_post_process.quant_min
quant_max = activation_post_process.quant_max
dtype_ = to_underlying_dtype(dtype)
qparams = {
"_scale_": scale,
"_zero_point_": zero_point,
"_axis_": ch_axis,
"_quant_min_": quant_min,
"_quant_max_": quant_max,
"_dtype_": dtype_
}
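            # with the qparams above, the emitted quantize call is roughly
            # (a sketch; scale and zero_point become attributes on the model,
            # the rest are graph literals):
            #   quantize_per_channel(x, scale, zero_point, ch_axis,
            #                        quant_min, quant_max, dtype_)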
else:
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor
scale = float(scale)
zero_point = int(zero_point)
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
dtype_ = to_underlying_dtype(dtype)
qparams = {
"_scale_": scale,
"_zero_point_": zero_point,
"_quant_min_": quant_min,
"_quant_max_": quant_max,
"_dtype_": dtype_
}
        # 2. replace the activation_post_process node with quantize and dequantize nodes
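        # the block below builds roughly the following fragment (a sketch; the
        # getattr targets for scale / zero_point are generated attribute names,
        # the remaining qparams are embedded as literals, and the dequantize
        # call reuses the same qparam arguments as the quantize call):
        #   scale_attr      = getattr(model, <generated scale attr>)
        #   zero_point_attr = getattr(model, <generated zero_point attr>)
        #   q  = quantize_op(x, scale_attr, zero_point_attr, *other_qparams)
        #   dq = dequantize_op(q, scale_attr, zero_point_attr, *other_qparams)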
with graph.inserting_before(node):
input_node = node.args[0]
quantize_op_inputs = [input_node]
for key, value_or_node in qparams.items():
# TODO: we can add the information of whether a value needs to
# be registered as an attribute in qparams dict itself
if key in ['_scale_', '_zero_point_']:
# For scale and zero_point values we register them as buffers in the root module.
# TODO: maybe need more complex attr name here
qparam_node = create_getattr_from_value(
model, graph, module_path + prefix + key, value_or_node)
quantize_op_inputs.append(qparam_node)
else:
# for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
quantize_op_inputs.append(value_or_node)
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
dequantized_node = graph.call_function(
dequantize_op,
tuple(dq_inputs),
{}
)
node.replace_all_uses_with(dequantized_node)
graph.erase_node(node)
elif is_dynamic:
        # uint8/int8 dynamic quantization branch (fp16 dynamic is not supported
        # in the decomposed reference flow, see the assert below)
# 1. extract information for inserting q/dq node from activation_post_process
node_type = "call_function"
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
        # we only use choose_qparams in the decomposed path for now,
        # but we should probably align the non-decomposed path with this as well,
        # and that can be done after we remove the reduce_range flag
# 1. extract qparams from activation_post_process module
dtype_ = to_underlying_dtype(dtype)
assert dtype_ in [torch.uint8, torch.int8], \
"only uint8 and int8 are supported in reference flow for " \
"dynamic quantization right now"
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
        # note: scale and zero_point are missing for the quantize_per_tensor op;
        # we'll need to get them from the choose_qparams op, which is added
        # after this step
qparams = {
"_quant_min_": quant_min,
"_quant_max_": quant_max,
"_dtype_": dtype_
}
# 2. insert choose_qparams op and update the qparams list
with graph.inserting_before(node):
input_node = node.args[0]
            choose_qparams_op_inputs = [input_node]
            for value in qparams.values():
# we have quant_min, quant_max and dtype, all should be stored
# as literals
choose_qparams_op_inputs.append(value)
choose_qparams_node = graph.create_node(
"call_function",
torch.ops.quantized_decomposed.choose_qparams.tensor,
tuple(choose_qparams_op_inputs),
{}
)
            # choose_qparams returns (scale, zero_point)
scale_node = graph.create_node(
"call_function",
operator.getitem,
(choose_qparams_node, 0),
{}
)
zero_point_node = graph.create_node(
"call_function",
operator.getitem,
(choose_qparams_node, 1),
{}
)
quant_min = qparams["_quant_min_"]
quant_max = qparams["_quant_max_"]
dtype = qparams["_dtype_"]
qparams = {
"_scale_": scale_node,
"_zero_point_": zero_point_node,
"_quant_min_": quant_min,
"_quant_max_": quant_max,
"_dtype_": dtype
}
        # 3. replace the activation_post_process node with quantize and dequantize nodes
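        # after this step the dynamic subgraph looks roughly like (a sketch; the
        # .tensor variants are used because scale / zero_point are Tensors
        # produced at runtime by choose_qparams):
        #   (scale, zero_point) = choose_qparams.tensor(x, quant_min, quant_max, dtype_)
        #   q  = quantize_per_tensor.tensor(x, scale, zero_point, quant_min, quant_max, dtype_)
        #   dq = dequantize_per_tensor.tensor(q, scale, zero_point, quant_min, quant_max, dtype_)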
with graph.inserting_before(node):
input_node = node.args[0]
quantize_op_inputs = [input_node]
for key, value_or_node in qparams.items():
# TODO: we can add the information of whether a value needs to
# be registered as an attribute in qparams dict itself
if key in ['_scale_', '_zero_point_']:
                    # scale and zero_point are already nodes in the graph here,
                    # since they are computed dynamically from the input by the
                    # choose_qparams op
qparam_node = value_or_node
quantize_op_inputs.append(qparam_node)
else:
# for qparams that are not scale/zero_point (like axis, dtype) we
# store them as literals in the graph.
quantize_op_inputs.append(value_or_node)
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            # we need to use the tensor variant of this op, since scale and zero_point
            # from choose_qparams are Tensors instead of float/int; this prevents
            # these nodes from being traced away by downstream systems
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
dequantized_node = graph.call_function(
dequantize_op,
tuple(dq_inputs),
{}
)
node.replace_all_uses_with(dequantized_node)
graph.erase_node(node)
elif dtype == torch.float16:
raise NotImplementedError("decomposed to float16 op not implemented yet")
    # we should not reach here since the checks at the beginning make sure the
    # activation_post_process is supported
def _replace_observer_with_quantize_dequantize_node(
model: torch.nn.Module,
graph: Graph,
node: Node,
modules: Dict[str, torch.nn.Module],
node_name_to_scope: Dict[str, Tuple[str, type]],
node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
""" Replace activation_post_process module call node with quantize and
dequantize node
Before:
... -> observer_0(x) -> ...
After:
... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
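
    For example, the static per tensor case roughly becomes (a sketch; scale and
    zero_point come from the observer and are registered as attributes on the
    model, dtype is e.g. torch.quint8):
        x_q = torch.quantize_per_tensor(x, scale, zero_point, dtype)
        x_dq = x_q.dequantize()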
"""
assert modules is not None
assert isinstance(node.target, str)
module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
activation_post_process = modules[node.target]
    # skip replacing the observer with quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
skip_replacement = all([
_has_none_qconfig(n, node_name_to_qconfig) for n in
list(node.args) + list(node.users.keys())])
if skip_replacement or not _is_conversion_supported(activation_post_process):
        # we didn't find a corresponding quantize op and info for the
        # activation_post_process, so we just remove the observer
with graph.inserting_before(node):
node.replace_all_uses_with(node.args[0])
graph.erase_node(node)
return
    # otherwise, we can convert the activation_post_process module call into quantize/dequantize nodes
dtype = activation_post_process.dtype # type: ignore[attr-defined]
is_dynamic = False
if hasattr(activation_post_process, "is_dynamic"):
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
(not is_dynamic):
        # TODO: probably should clean up this condition check, it's hard
        # to reason about this if and the following elif
        # uint8/int8/int32 static quantization branch
# 1. extract the information from activation_post_process module for generating
# the quantize and dequantize operator