import copy
import torch
import torch.nn as nn
from torch.ao.quantization import (
QConfigAny,
QuantType,
)
from torch.ao.quantization.backend_config import (
BackendConfig,
DTypeWithConstraints,
)
from torch.ao.quantization.fake_quantize import (
FakeQuantizeBase,
FixedQParamsFakeQuantize,
)
from torch.ao.quantization.observer import (
FixedQParamsObserver,
ObserverBase,
)
from torch.ao.quantization.qconfig import (
float16_static_qconfig,
float16_dynamic_qconfig,
qconfig_equals,
)
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
activation_is_statically_quantized,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.fx import GraphModule, map_arg
from torch.fx.graph import (
Graph,
Node,
)
from .custom_config import PrepareCustomConfig
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib # noqa: F401
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
from dataclasses import dataclass
from collections import namedtuple
import operator
import warnings
# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
"all_node_args_except_first",
"all_node_args_have_no_tensors",
"assert_and_get_unique_device",
"collect_producer_nodes",
"create_getattr_from_value",
"create_node_from_old_node_preserve_meta",
"EMPTY_ARG_DICT",
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_new_attr_name_with_prefix",
"get_non_observable_arg_indexes_and_types",
"get_qconv_prepack_op",
"get_skipped_module_name_and_classes",
"graph_module_from_producer_nodes",
"maybe_get_next_module",
"NodeInfo",
"node_arg_is_bias",
"node_arg_is_weight",
"NON_OBSERVABLE_ARG_DICT",
"NON_QUANTIZABLE_WEIGHT_OPS",
"return_arg_list",
"ObservedGraphModuleAttrs",
]
NON_QUANTIZABLE_WEIGHT_OPS = {
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.instance_norm,
}
@dataclass
class ObservedGraphModuleAttrs:
node_name_to_qconfig: Dict[str, QConfigAny]
node_name_to_scope: Dict[str, Tuple[str, type]]
prepare_custom_config: PrepareCustomConfig
equalization_node_name_to_qconfig: Dict[str, Any]
qconfig_mapping: QConfigMapping
is_qat: bool
observed_node_names: Set[str]
is_observed_standalone_module: bool = False
standalone_module_input_quantized_idxs: Optional[List[int]] = None
standalone_module_output_quantized_idxs: Optional[List[int]] = None
def node_arg_is_weight(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
"""Returns if node arg is weight"""
if isinstance(node, Node) and node.op == "call_function" and \
node.target in backend_config._pattern_complex_format_to_config:
weight_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("weight")
if weight_index is not None and weight_index < len(node.args) and node.args[weight_index] is arg:
return True
return node.kwargs.get("weight") is arg
return False
def node_arg_is_bias(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
"""Returns if node arg is bias"""
if isinstance(node, Node) and node.op == "call_function" and \
node.target in backend_config._pattern_complex_format_to_config:
bias_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("bias")
if bias_index is not None and bias_index < len(node.args) and node.args[bias_index] is arg:
return True
return node.kwargs.get("bias") is arg
return False
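# Example sketch (node below is an illustrative FX call_function node targeting
# torch.nn.functional.linear with args (x, w, b)); this assumes a backend_config
# whose config for F.linear maps "weight" to index 1 and "bias" to index 2, as the
# native backend config does:
#
#     from torch.ao.quantization.backend_config import get_native_backend_config
#     backend_config = get_native_backend_config()
#     node_arg_is_weight(node, node.args[1], backend_config)  # True
#     node_arg_is_bias(node, node.args[2], backend_config)    # True
#     node_arg_is_weight(node, node.args[0], backend_config)  # False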
def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
r""" Get all the unique custom module keys in the custom config dict
e.g.
Input:
{
QuantType.STATIC: {
CustomModule1: ObservedCustomModule
},
QuantType.DYNAMIC: {
CustomModule2: DynamicObservedCustomModule
},
QuantType.WEIGHT_ONLY: {
CustomModule3: WeightOnlyObservedCustomModule
},
}
Output:
# extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
[CustomModule1, CustomModule2, CustomModule3]
"""
# using set to dedup
    float_custom_module_classes: Set[Any] = set()
for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
float_custom_module_classes |= quant_mode_custom_module_classes
return list(float_custom_module_classes)
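# Example sketch (CustomModule1/2 and their observed counterparts are hypothetical
# user-defined classes; the returned order is not guaranteed since a set is used
# internally):
#
#     mapping = {
#         QuantType.STATIC: {CustomModule1: ObservedCustomModule},
#         QuantType.DYNAMIC: {CustomModule2: DynamicObservedCustomModule},
#     }
#     get_custom_module_class_keys(mapping)  # -> [CustomModule1, CustomModule2]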
def get_linear_prepack_op_for_dtype(dtype):
if dtype == torch.float16:
return torch.ops.quantized.linear_prepack_fp16
elif dtype == torch.qint8:
return torch.ops.quantized.linear_prepack
else:
raise Exception("can't get linear prepack op for dtype:", dtype)
def get_qconv_prepack_op(conv_op: Callable) -> Callable:
prepack_ops = {
torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack
}
prepack_op = prepack_ops.get(conv_op, None)
assert prepack_op, "Didn't find prepack op for {}".format(conv_op)
return prepack_op
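# Example sketch: both prepack helpers above are straight table lookups, e.g.
#
#     get_linear_prepack_op_for_dtype(torch.qint8)      # torch.ops.quantized.linear_prepack
#     get_qconv_prepack_op(torch.nn.functional.conv2d)  # torch.ops.quantized.conv2d_prepack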
# Returns a function that can get a new attribute name for a module with the
# given prefix, for example,
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer0`
# (or `_observer1`, `_observer2`, ... if earlier names are already taken)
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
prefix = prefix.replace(".", "_")
def get_new_attr_name(module: torch.nn.Module):
def get_attr_name(i: int):
return prefix + str(i)
i = 0
attr_name = get_attr_name(i)
while hasattr(module, attr_name):
i += 1
attr_name = get_attr_name(i)
return attr_name
return get_new_attr_name
def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r''' Starting from a target node, trace back until we hit an input
    (placeholder) or a getattr node. This is used to extract the chain of
    operators starting from getattr to the target node, for example
    def forward(self, x):
        observed = self.observer(self.weight)
        return F.linear(x, observed)
    collect_producer_nodes(observed) will either return a list of nodes that
    produce the observed node, or None if we can't extract a self-contained
    graph without free variables (inputs of the forward function).
'''
nodes = [node]
frontier = [node]
while frontier:
node = frontier.pop()
all_args = list(node.args) + list(node.kwargs.values())
for arg in all_args:
if not isinstance(arg, Node):
continue
if arg.op == 'placeholder':
# hit input, can't fold in this case
return None
nodes.append(arg)
if not (arg.op == 'call_function' and arg.target == getattr):
frontier.append(arg)
return nodes
def graph_module_from_producer_nodes(
root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
r''' Construct a graph module from extracted producer nodes
from `collect_producer_nodes` function
Args:
root: the root module for the original graph
producer_nodes: a list of nodes we use to construct the graph
Return:
A graph module constructed from the producer nodes
'''
    assert len(producer_nodes) > 0, 'list of producer nodes cannot be empty'
    # since we traced back from node to getattr
producer_nodes.reverse()
graph = Graph()
env: Dict[Any, Any] = {}
def load_arg(a):
return map_arg(a, lambda node: env[node])
for producer_node in producer_nodes:
env[producer_node] = graph.node_copy(producer_node, load_arg)
graph.output(load_arg(producer_nodes[-1]))
graph_module = GraphModule(root, graph)
return graph_module
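# Example sketch of how the two helpers above are typically combined to fold a
# self-contained weight-producing subgraph (root_gm and weight_arg_node are
# illustrative names):
#
#     producer_nodes = collect_producer_nodes(weight_arg_node)
#     if producer_nodes is not None:
#         # the extracted graph has no placeholders, so it can be evaluated directly
#         folded_weight = graph_module_from_producer_nodes(root_gm, producer_nodes)()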
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
"""
Returns the unique device for a module, or None if no device is found.
Throws an error if multiple devices are detected.
"""
devices = {p.device for p in module.parameters()} | \
{p.device for p in module.buffers()}
assert len(devices) <= 1, (
"prepare only works with cpu or single-device CUDA modules, "
"but got devices {}".format(devices)
)
device = next(iter(devices)) if len(devices) > 0 else None
return device
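# Example sketch:
#
#     assert_and_get_unique_device(torch.nn.Linear(4, 4))  # device(type='cpu')
#     assert_and_get_unique_device(torch.nn.Identity())    # None (no params or buffers)
#     # a module with parameters on both cpu and cuda raises an AssertionError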
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
"""
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer on the module.
"""
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
attr_name = get_new_attr_name(module)
device = assert_and_get_unique_device(module)
new_value = value.clone().detach() if isinstance(value, torch.Tensor) \
else torch.tensor(value, device=device)
module.register_buffer(attr_name, new_value)
# Create get_attr with value
attr_node = graph.create_node("get_attr", attr_name)
return attr_node
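# Example sketch (model is an illustrative GraphModule): register a scale constant
# as a buffer and reference it from the graph; the buffer gets the first unused
# name for the given prefix, here "_scale0":
#
#     scale_node = create_getattr_from_value(model, model.graph, "_scale", 0.25)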
def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
"""
If we know for sure that all of this node's args have no
tensors (are primitives), return True. If we either
find a tensor or are not sure, return False. Note: this
function is not exact.
"""
if cache and node in cache:
return cache[node]
    result = False  # default, may be overwritten below
if not isinstance(node, Node):
result = True
elif node.op == 'placeholder':
result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        if _is_activation_post_process(modules[node.target]):
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
        # any other call_module node keeps the default result of False
elif node.op == 'call_function' and node.target is operator.getitem:
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
elif node.op == 'get_attr':
result = False
elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
# x1 = x0.ndim
result = True
elif node.op == 'call_method' and node.target == 'size':
# x1 = x0.size(0)
result = True
else:
found_one_tensor = False
for arg in node.args:
if isinstance(arg, list):
for list_el in arg:
if isinstance(list_el, Node):
this_list_el_args_have_no_tensors = \
all_node_args_have_no_tensors(list_el, modules, cache)
found_one_tensor = found_one_tensor or \
(not this_list_el_args_have_no_tensors)
# If found_one_tensor is True, there is no point in
# recursing further as the end result will always
# be True.
# TODO(future PR): remove this entire function and
# change to dtype inference without recursion.
if found_one_tensor:
result = not found_one_tensor
if cache:
cache[node] = result
return result
elif isinstance(arg, int):
pass
else:
if isinstance(arg, Node):
this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache)
found_one_tensor = found_one_tensor or \
(not this_arg_args_have_no_tensors)
# If found_one_tensor is True, there is no point in
# recursing further as the end result will always
# be True.
# TODO(future PR): remove this entire function and
# change to dtype inference without recursion.
if found_one_tensor:
result = not found_one_tensor
if cache:
cache[node] = result
return result
else:
found_one_tensor = True
result = not found_one_tensor
if cache:
cache[node] = result
return result
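# Example sketch (size_node/input_node are illustrative nodes from an FX graph):
#
#     all_node_args_have_no_tensors(size_node, modules, {})   # True, x.size(0) is a plain int
#     all_node_args_have_no_tensors(input_node, modules, {})  # False, placeholders may be tensors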
def all_node_args_except_first(node: Node) -> List[int]:
"""
    Returns all node arg indices after the first
"""
return list(range(1, len(node.args)))
def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
"""
Constructs a function that takes a node as arg and returns the arg_indices
that are valid for node.args
"""
def arg_indices_func(node: Node) -> List[int]:
return [i for i in arg_indices if i < len(node.args)]
return arg_indices_func
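# Example sketch (node is an illustrative FX node with three positional args):
#
#     all_node_args_except_first(node)  # [1, 2]
#     return_arg_list([0, 2, 5])(node)  # [0, 2], index 5 is out of range and dropped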
NodeInfo = namedtuple("NodeInfo", "op target")