import torch
from torch.fx import ( # type: ignore
GraphModule,
Proxy,
map_arg
)
from torch.fx.graph import (
Graph,
Node,
)
from torch.fx.node import Argument
from torch.quantization import (
propagate_qconfig_,
convert,
)
from ..quantization_mappings import (
get_default_qat_module_mappings,
)
from ..quantize import (
_remove_qconfig,
is_activation_post_process
)
from ..utils import (
get_combined_dict,
get_swapped_custom_module_class,
activation_is_statically_quantized,
)
from .pattern_utils import (
is_match,
get_default_quant_patterns,
get_default_output_activation_post_process_map,
input_output_observed,
Pattern,
)
from .observed_module import (
mark_observed_module,
is_observed_module,
mark_observed_standalone_module,
is_observed_standalone_module,
)
from .quantization_patterns import *
from .utils import (
_parent_name,
quantize_node,
get_custom_module_class_keys,
get_new_attr_name_with_prefix,
collect_producer_nodes,
graph_module_from_producer_nodes,
assert_and_get_unique_device,
)
from .qconfig_utils import *
from typing import Optional, Dict, Any, List, Tuple, Set, Callable
# Define helper types
# MatchResult is a 5-tuple holding, in order: a Node, a list of Nodes, an
# optional Pattern, a QuantizeHandler and a QConfigAny.
# NOTE(review): presumably (root node, matched nodes, matched pattern,
# handler instance, qconfig for the match) -- confirm where it is built.
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]
# ------------------------
# Helper Functions
# ------------------------
def insert_observer(
        node: Node, observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Attach ``observer`` to ``model`` and route ``node`` through it.

    The observer is registered on ``model`` under a fresh attribute name,
    recorded in ``activation_post_process_map`` under ``node.name``, and a
    ``call_module`` node invoking it on ``load_arg(node)`` is created in
    ``observed_graph``. That new graph node replaces the entry for
    ``node.name`` in ``env``, and ``node.name`` is added to
    ``observed_node_names_set``.

    Args:
        node: the node whose value should be observed
        observer: observer/fake_quantize module instance
        model: module that receives the observer as an attribute
        activation_post_process_map: node name -> observer (mutated)
        env: node name -> node in the observed graph (mutated)
        observed_graph: graph being built (mutated)
        load_arg: callable mapping ``node`` to its value in ``observed_graph``
        observed_node_names_set: names of observed nodes (mutated)
    """
    # Keep the observer on the same device as the model, when one is reported.
    device = assert_and_get_unique_device(model)
    if device:
        observer.to(device)

    # Register the observer under an attribute name derived from the node name.
    name_factory = get_new_attr_name_with_prefix(
        node.name + '_activation_post_process_')
    attr_name = name_factory(model)
    setattr(model, attr_name, observer)

    # Remember which observer belongs to this node.
    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer

    # Emit the observer call and make it the current value for this node.
    observer_call = observed_graph.create_node(
        'call_module', attr_name, (load_arg(node),), {})
    env[node.name] = observer_call
    observed_node_names_set.add(node.name)
def maybe_insert_observer_for_special_module(
        quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module],
        prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]:
    """Swap in the observed version of a custom or standalone module.

    For a ``CustomModuleQuantizeHandler`` the module at ``node.target`` is
    replaced on its parent by the result of ``from_float`` on the class
    returned by ``get_swapped_custom_module_class``.

    For a ``StandaloneModuleQuantizeHandler`` the module is prepared with
    ``_prepare_standalone_module_fx``, marked as an observed standalone
    module, set on its parent and stored back into ``modules``.

    Any other handler leaves everything untouched.

    Returns:
        standalone_module_input_idxs: the indexes of the inputs that need to
        be observed by the parent module (standalone modules only), otherwise
        ``None``.
    """
    assert modules is not None

    if isinstance(quantize_handler, CustomModuleQuantizeHandler):
        float_module = modules[node.target]  # type: ignore
        class_mapping = prepare_custom_config_dict.get(
            "float_to_observed_custom_module_class", {})
        observed_class = get_swapped_custom_module_class(
            float_module, class_mapping, qconfig)
        observed_module = observed_class.from_float(float_module)
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, observed_module)
        return None

    if not isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
        return None

    # observe standalone module
    float_module = modules[node.target]  # type: ignore
    name_entries = prepare_custom_config_dict.get("standalone_module_name", [])
    class_entries = prepare_custom_config_dict.get("standalone_module_class", [])
    # Each entry is indexed as (key, qconfig_dict, prepare_custom_config_dict).
    configs_by_class = {
        entry[0]: (entry[1], entry[2]) for entry in class_entries}
    configs_by_name = {
        entry[0]: (entry[1], entry[2]) for entry in name_entries}
    # A per-name config takes precedence over a per-class config.
    config = configs_by_class.get(type(float_module), (None, None))
    config = configs_by_name.get(node.target, config)

    if config[0] is None:
        sm_qconfig_dict = {"": qconfig}
    else:
        sm_qconfig_dict = config[0]
    if config[1] is None:
        sm_prepare_config_dict = {}
    else:
        sm_prepare_config_dict = config[1]

    prepare_standalone = \
        torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
    observed_module = prepare_standalone(
        float_module, sm_qconfig_dict, sm_prepare_config_dict)
    input_idxs = \
        observed_module._standalone_module_input_quantized_idxs.int().tolist()
    observed_module = mark_observed_standalone_module(observed_module)

    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, observed_module)
    modules[node.target] = observed_module  # type: ignore
    return input_idxs
def insert_observer_for_output_of_the_node(
        node: Node,
        quantize_handler: QuantizeHandler,
        qconfig: Any,
        modules: Dict[str, torch.nn.Module],
        model: torch.nn.Module,
        pattern: Any,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any],
        observed_graph: Graph,
        load_arg: Callable,
        observed_node_names_set: Set[str],
        matched_nodes: Optional[List[Node]],
        standalone_module_input_idxs: Optional[List[int]]):
    """ Insert observer/fake_quantize module for output of the observed
    module if needed

    Nothing is done unless ``activation_is_statically_quantized(qconfig)``.
    Otherwise the first matching case below applies:

    * ``FixedQParamsOpQuantizeHandler`` while ``model.training``: insert the
      module built by the constructor registered for ``pattern`` in
      ``get_default_output_activation_post_process_map()``.
    * ``FixedQParamsOpQuantizeHandler`` while not training, or ``CopyNode``:
      no observer is inserted; ``node`` is marked observed when
      ``node.args[0]`` is observed.
    * ``Add``/``Mul`` with ``num_node_args == 1``: ``node`` is marked
      observed when either of the first two args of ``matched_nodes[-1]``
      is an observed ``Node``.
    * ``StandaloneModuleQuantizeHandler``: ``node`` is marked observed when
      index 0 is in the module's ``_standalone_module_output_quantized_idxs``.
    * ``quantize_handler.all_node_args`` and
      ``input_output_observed(quantize_handler)``: insert a new
      ``qconfig.activation()`` observer.

    Finally, when ``standalone_module_input_idxs`` is given, an observer is
    inserted for each listed index whose argument is not yet observed.

    NOTE(review): the original indentation of this function was lost; the
    nesting below (in particular, the standalone-input loop living inside
    the static-quantization check) is reconstructed and should be confirmed.
    """
    # don't need to insert observer for output if activation does not
    # need to be statically quantized
    assert modules is not None
    if activation_is_statically_quantized(qconfig):
        if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
                and model.training:
            # we only insert fake quantize module in qat
            assert pattern is not None
            activation_post_process_ctr = \
                get_default_output_activation_post_process_map().get(
                    pattern, None)
            assert activation_post_process_ctr is not None, \
                "activation_post_process constructor not provided " + \
                "for pattern:" + str(pattern)
            insert_observer(
                node, activation_post_process_ctr(),
                model, activation_post_process_map, env, observed_graph,
                load_arg, observed_node_names_set)
        elif (isinstance(quantize_handler,
                         FixedQParamsOpQuantizeHandler) and
              not model.training) or \
                isinstance(quantize_handler, CopyNode):
            # inserting observers for output of observed module, or
            # mark the output as observed
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'CopyNode of type ' + node.op + ' is not handled'

            # A Node is observed when its name is in the observed set; a
            # list is observed when all of its items are. Any other input
            # falls through and yields None, which is treated as False.
            def is_observed(input_arg):
                if isinstance(input_arg, Node):
                    return input_arg.name in observed_node_names_set
                elif isinstance(input_arg, list):
                    return all(map(is_observed, input_arg))

            # propagate observed property from input
            if is_observed(node.args[0]):
                observed_node_names_set.add(node.name)
        elif ((isinstance(quantize_handler, Add) or
               isinstance(quantize_handler, Mul)) and
              quantize_handler.num_node_args == 1):
            assert matched_nodes is not None
            input_node = matched_nodes[-1]  # first node in the sequence

            def input_is_observed(arg):
                return (isinstance(arg, Node) and
                        arg.name in observed_node_names_set)

            # This is checking if one of the argument of add/mul
            # is an observed node
            # If both of the inputs are number,
            # we will not consider the output to be observed
            if (input_is_observed(input_node.args[0]) or
                    input_is_observed(input_node.args[1])):
                observed_node_names_set.add(node.name)
        elif isinstance(quantize_handler,
                        StandaloneModuleQuantizeHandler):
            assert node.op == "call_module"
            assert isinstance(node.target, str)
            sm_out_qidxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # type: ignore
            # Only output index 0 is consulted here.
            output_is_quantized = 0 in sm_out_qidxs

            if output_is_quantized:
                observed_node_names_set.add(node.name)
        elif (quantize_handler.all_node_args and
              input_output_observed(quantize_handler)):
            # observer for outputs
            new_observer = qconfig.activation()
            insert_observer(
                node, new_observer, model,
                activation_post_process_map, env, observed_graph,
                load_arg, observed_node_names_set)

        # insert observer for input of standalone module
        # NOTE(review): the check looks at node.args[idx], but the observer
        # is inserted for `node` itself rather than for that argument --
        # confirm this is intended.
        if standalone_module_input_idxs is not None:
            for idx in standalone_module_input_idxs:
                if node.args[idx].name not in observed_node_names_set:  # type: ignore
                    new_observer = qconfig.activation()
                    insert_observer(
                        node, new_observer, model,
                        activation_post_process_map, env, observed_graph,
                        load_arg, observed_node_names_set)
def insert_observer_for_input_arg_of_observed_node(
        node: Node, observed_node_names_set: Set[str],
        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]],
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[str, str], observed_graph: Graph,
        load_arg: Callable):
    """Observe ``node`` if it is listed in ``quants`` and not yet observed.

    ``quants`` maps a node name to a ``(handler, observer constructor)``
    pair. When ``node.name`` has an entry, is not already in
    ``observed_node_names_set`` and the constructor is not ``None``, a new
    observer is built from it and inserted via ``insert_observer``.

    NOTE(review): ``env`` is annotated ``Dict[str, str]`` here, while
    ``insert_observer`` stores graph nodes in it -- confirm the annotation.
    """
    already_observed = node.name in observed_node_names_set
    if already_observed or node.name not in quants:
        return

    _handler, observer_ctr = quants[node.name]
    if observer_ctr is None:
        return

    insert_observer(
        node, observer_ctr(),
        model, activation_post_process_map,
        env, observed_graph, load_arg, observed_node_names_set)
# Maps a functional op to the positional-argument indices that hold its weight
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv1d : [1],
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.conv3d : [1],
    torch.nn.functional.linear : [1],
}


def node_arg_is_weight(node: Node, arg: Any) -> bool:
    """Return True if ``arg`` is the weight argument of ``node``.

    Only ``call_function`` nodes whose target is a key of
    ``WEIGHT_INDEX_DICT`` are considered, and only positional arguments are
    inspected. ``arg`` is matched by identity, not equality.
    """
    if not isinstance(node, Node) or node.op != 'call_function':
        return False
    if node.target not in WEIGHT_INDEX_DICT:
        return False
    weight_positions = WEIGHT_INDEX_DICT[node.target]  # type: ignore
    return any(
        candidate is arg and position in weight_positions
        for position, candidate in enumerate(node.args))
CONV_OPS_WITH_BIAS = {
    torch.nn.functional.conv1d,
    torch.nn.functional.conv2d,
    torch.nn.functional.conv3d,
}
CONV_BIAS_ARG_INDEX = 2
# torch.nn.functional.linear(input, weight, bias=None) also takes its bias as
# the third positional argument.
LINEAR_BIAS_ARG_INDEX = 2


def node_arg_is_bias(node: Node, arg: Any) -> bool:
    """Return True if ``arg`` is the bias argument of ``node``.

    Only ``call_function`` nodes targeting one of ``CONV_OPS_WITH_BIAS`` or
    ``torch.nn.functional.linear`` are considered. The bias is recognised
    both when passed positionally (third argument) and when passed as the
    ``bias=`` keyword, for every supported op. ``arg`` is matched by
    identity, not equality.

    Args:
        node: candidate consumer node
        arg: the argument object to test against ``node``'s bias

    Returns:
        True when ``arg`` is the object passed as ``node``'s bias.
    """
    if not (isinstance(node, Node) and node.op == 'call_function'):
        return False

    if node.target in CONV_OPS_WITH_BIAS:
        bias_index = CONV_BIAS_ARG_INDEX
    elif node.target is torch.nn.functional.linear:
        bias_index = LINEAR_BIAS_ARG_INDEX
    else:
        return False

    # Positional form: op(input, weight, bias, ...)
    if len(node.args) > bias_index and node.args[bias_index] is arg:
        return True
    # Keyword form: op(input, weight, bias=...)
    return 'bias' in node.kwargs and node.kwargs['bias'] is arg
# weight prepacking ops
# Set of quantized operators that prepack weights: linear, fp16 linear and
# conv2d.
# NOTE(review): no conv1d/conv3d prepack op is listed here -- confirm whether
# that is intentional.
WEIGHT_PREPACK_OPS = {
    torch._ops.ops.quantized.linear_prepack,
    torch._ops.ops.quantized.linear_prepack_fp16,
    torch._ops.ops.quantized.conv2d_prepack,
}
class Quantizer:
    """Holds the per-model state used while preparing a model for quantization."""

    def __init__(self):
        """Create a quantizer with all of its lookup tables unset or empty."""
        # Matched node -> activation_post_process.
        # Has to be populated before convert.
        self.activation_post_process_map: Optional[
            Dict[str, torch.quantization.observer.ObserverBase]] = None

        # Node name -> qconfig to use for that node.
        # Populated for a model during _generate_qconfig_map.
        self.qconfig_map: Optional[Dict[str, QConfigAny]] = None

        # Fully qualified module name -> module instance, e.g.
        # {
        #   '': Model(...),
        #   'linear': Linear(...),
        #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
        # }
        self.modules: Optional[Dict[str, torch.nn.Module]] = None

        # Tuple of nodes in reverse order -> uninitialized QuantizeHandler
        # subclass, e.g.
        # {
        #   # match a single node
        #   (<class 'torch.nn.modules.conv.Conv3d'>:
        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
        #   # match multiple nodes in reverse order
        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
        #     <class 'torch.quantization.fx.quantize.Add'>),
        # }
        self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None

        self.prepare_custom_config_dict: Dict[str, Any] = {}

        # Node name -> scope of the module which contains the node.
        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}

    def _qat_swap_modules(
            self, root: torch.nn.Module,
            additional_qat_module_mapping: Dict[Callable, Callable]) -> None:
        """Swap modules of ``root`` in place using the QAT module mappings.

        The mapping passed to ``convert`` is the result of
        ``get_combined_dict`` applied to the default QAT module mappings and
        ``additional_qat_module_mapping``. qconfigs are kept on the modules
        (``remove_qconfig=False``).
        """
        qat_mappings = get_combined_dict(
            get_default_qat_module_mappings(), additional_qat_module_mapping)
        convert(
            root, mapping=qat_mappings, inplace=True, remove_qconfig=False)
Loading ...