import torch
from torch.fx.graph import (
Node,
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
)
from ..quantization_mappings import (
get_static_quant_module_class,
get_dynamic_quant_module_class,
get_quantized_operator,
)
from ..utils import (
get_swapped_custom_module_class,
activation_is_statically_quantized,
weight_is_statically_quantized,
weight_dtype,
get_qconfig_dtypes,
)
from .pattern_utils import (
register_quant_pattern,
mark_input_output_not_observed,
)
from .utils import (
_parent_name,
quantize_node,
get_per_tensor_qparams,
get_linear_prepack_op_for_dtype,
create_qparam_nodes,
get_qconv_prepack_op,
get_qconv_op,
)
from .quantization_types import QuantizerCls
from abc import ABC, abstractmethod
import operator
import warnings
from typing import Any, Callable, Dict, Optional
# -------------------------
# Pattern Registrations
# -------------------------
# 1. Post Training Static Quantization and Quantization Aware Training Patterns
# Base Pattern Handler
class QuantizeHandler(ABC):
""" Base handler class for the quantizer patterns
"""
def __init__(self, quantizer: QuantizerCls, node: Node):
""" Records pattern information in __init__, which will be used
in convert
"""
        # these attributes describe the inputs of the matched node: some ops
        # (e.g. add/mul) are quantized differently depending on whether all
        # inputs are tensors (Nodes) or some are scalars. num_node_args
        # defaults to the total number of positional args; handlers such as
        # Add/Mul overwrite it with the number of Node arguments
self.num_node_args = len(node.args)
self.all_node_args = True
@abstractmethod
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
debug: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
""" Convert the given node to a quantized node and insert
it to the quantized graph
"""
return NotImplemented
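# How patterns are matched (illustrative sketch): a pattern is either a
# single op or a tuple that is matched starting from the last op and walking
# backwards through its first argument, e.g. (torch.nn.ReLU, operator.add)
# matches graphs of the form
#     t = x + y       # call_function: operator.add
#     out = relu(t)   # call_module:   torch.nn.ReLU
# so the handler's __init__ receives the relu node and follows node.args[0]
# to reach the add node.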
@register_quant_pattern(operator.add)
@register_quant_pattern(torch.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
@register_quant_pattern((torch.nn.functional.relu, torch.add))
class Add(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0] # type: ignore
assert node.op == 'call_function' and node.target in [operator.add, torch.add]
self.add_node = node
self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)])
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
debug: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
if self.num_node_args == 1:
# add scalar
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
if isinstance(self.add_node.args[0], Node):
quantized_index = 0
else:
quantized_index = 1
return quantizer.quantized_graph.create_node(
'call_function', op,
load_arg(quantized=[quantized_index])(self.add_node.args), self.add_node.kwargs)
else:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
kwargs = {**self.add_node.kwargs}
add_args = (*load_arg(quantized=True)(self.add_node.args), scale_arg, zero_point_arg)
            return quantizer.quantized_graph.create_node(
                'call_function', op, add_args, kwargs)
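# Illustrative rewrite performed by Add.convert (sketch, not executed): a
# float graph computing
#     z = torch.nn.functional.relu(torch.add(x, y))
# matches the (torch.nn.functional.relu, torch.add) pattern and becomes a
# single fused call parameterized by the observed output qparams:
#     z = torch.ops.quantized.add_relu(x_q, y_q, scale, zero_point)
# in the scalar case only the Node argument is quantized:
#     torch.add(x, 2)  ->  torch.ops.quantized.add(x_q, 2)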
# TODO: merge with Add
@register_quant_pattern(operator.mul)
@register_quant_pattern(torch.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
class Mul(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0] # type: ignore
assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
self.mul_node = node
self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)])
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
debug: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
if self.num_node_args == 1:
# mul scalar
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
if isinstance(self.mul_node.args[0], Node):
quantized_index = 0
else:
quantized_index = 1
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=[quantized_index])(self.mul_node.args), self.mul_node.kwargs)
else:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
kwargs = {**self.mul_node.kwargs}
args = (*load_arg(quantized=True)(self.mul_node.args), scale_arg, zero_point_arg)
return quantizer.quantized_graph.create_node('call_function', op, args, kwargs)
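# The Mul rewrite mirrors Add above (sketch, not executed):
#     torch.mul(x, y)  ->  torch.ops.quantized.mul(x_q, y_q, scale, zero_point)
#     torch.mul(x, 3)  ->  torch.ops.quantized.mul(x_q, 3)  # scalar passes through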
@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
debug: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
if not self.all_node_args:
return NotImplemented
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
kwargs = {**load_arg(quantized=False)(node.kwargs), 'scale': scale_arg, 'zero_point': zero_point_arg}
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
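# Illustrative rewrite for Cat (sketch, not executed):
#     z = torch.cat([a, b], dim=1)
# becomes
#     z = torch.ops.quantized.cat([a_q, b_q], dim=1, scale=s, zero_point=zp)
# load_arg(quantized=[0]) quantizes only argument 0 (the list of input
# tensors); dim and the other kwargs are passed through unquantized.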
# handle conv, possibly followed by relu
# NB: matching order is reversed, that is, we match from the bottom of this list to the top
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv1d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.functional.conv3d)
# TODO: add qat.Conv1d and qat.Conv3d
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0] # type: ignore
self.conv_node = node
if node.op == "call_module":
self.conv = quantizer.modules[self.conv_node.target]
elif node.op == "call_function":
self.conv = node.target
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
debug: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        # Supported combinations are:
        #   quant_type | activation (compute_type) | weight
        #   static     | quint8                    | qint8
        # each tuple is (activation_dtype, weight_dtype, compute_dtype)
        supported_dtypes = [
            (torch.quint8, torch.qint8, None),
        ]
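        # for reference: with the default static qconfig, e.g.
        # torch.quantization.get_default_qconfig('fbgemm'),
        # get_qconfig_dtypes(qconfig) returns (torch.quint8, torch.qint8, None),
        # which matches the supported combination above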
# TODO: debug option for conv module
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
        if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not supported by Conv; "
                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
            if self.relu_node:
                # keep the float conv + relu in the graph, unquantized
                conv_out = quantizer.quantized_graph.node_copy(self.conv_node, load_arg(quantized=False))
                relu_args = [conv_out]
                relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
                relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
                return quantizer.quantized_graph.create_node(
                    "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
            else:
                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
activation_statically_quantized = activation_is_statically_quantized(qconfig)
if self.conv_node.op == 'call_module':
# note that relu should already be fused into conv module in the fusion step
            assert self.relu_node is None, 'conv module and relu fusion was not executed, ' \
                'please make sure to run fusion before prepare'
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
# 1. attach activation post process to module
self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
qconv_cls = get_static_quant_module_class(
type(self.conv), additional_static_quant_mapping)
quantized = qconv_cls.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
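            # e.g. with the default mappings a float torch.nn.Conv2d is
            # replaced by torch.nn.quantized.Conv2d via nnq.Conv2d.from_float,
            # which reads the attached observer's qparams and prepacks the
            # quantized weight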
return quantizer.quantized_graph.create_node(
'call_module',
self.conv_node.target,
(load_arg(quantized=True)(self.conv_node.args[0]),),
{})
else: # call_function
assert self.conv_node.op == "call_function"
            if debug:
                # in debug mode the conv runs as a float op on unquantized
                # inputs; the output is re-quantized below if needed
                args = load_arg(quantized=False)(self.conv_node.args)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
op_out = quantizer.quantized_graph.create_node(
"call_function", self.conv, args, kwargs)
if self.relu_node:
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
op_out = quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
if activation_statically_quantized:
root_module = quantizer.modules['']
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
act_post_process_node = self.relu_node if self.relu_node else self.conv_node
return quantize_node(
quantizer, op_out, quantizer.activation_post_process_map[act_post_process_name],
act_post_process_node, is_input=False)
else:
# output for dynamically quantized conv op is not quantized
return op_out
else:
                assert len(self.conv_node.args) >= 7, \
                    "only conv calls with all arguments specified are supported " \
                    "right now with the debug=False option"
# pack weight
weight = load_arg(quantized=True)(self.conv_node.args[1])
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
prepack_args = tuple([weight] + list(other_args))
prepack_op = get_qconv_prepack_op(self.conv)
packed_weight = quantizer.quantized_graph.create_node(
"call_function", prepack_op, prepack_args, {})
assert activation_statically_quantized, \
"currently only static quantization is supported for conv"
# construct conv input
if activation_statically_quantized:
qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
activation_post_process = quantizer.activation_post_process_map[act_post_process_name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
op = quantizer.quantized_graph.create_node(
'call_function', qconv_op, qconv_args, kwargs)
# Store the name of the fused op to get the path of node after fusion as well.
                    # TODO: may need to change the key to Node and regenerate the map in each transformation,