import copy
import torch
import warnings
from torch.fx import (
GraphModule,
)
from torch.fx.graph import (
Graph,
Node,
)
from torch.fx.node import Argument
from ..quantize import (
propagate_qconfig_,
)
from ..observer import (
ObserverBase,
_is_activation_post_process
)
from ..qconfig import (
_is_reuse_input_qconfig,
QConfigAny,
)
from ..qconfig_mapping import (
QConfigMapping,
)
from .qconfig_mapping_utils import (
_generate_node_name_to_qconfig,
_update_qconfig_for_fusion,
_get_flattened_qconfig_dict,
_update_qconfig_for_qat,
)
from .quantize_handler import (
_default_root_node_getter,
_get_pattern_to_quantize_handlers,
QuantizeHandler,
)
from torch.ao.quantization.utils import (
Pattern,
NodePattern,
)
from ._equalize import (
is_equalization_observer,
node_supports_equalization,
)
from .pattern_utils import (
_sorted_patterns_dict,
)
from .match_utils import (
_MatchResultWithQConfig,
_find_matches,
)
from .utils import (
_insert_dequant_stubs_for_custom_module_lstm_output,
_is_custom_module_lstm,
_maybe_get_custom_module_lstm_from_node_arg,
_qconfig_satisfies_dtype_config_constraints,
get_custom_module_class_keys,
all_node_args_have_no_tensors,
assert_and_get_unique_device,
get_non_observable_arg_indexes_and_types,
get_new_attr_name_with_prefix,
node_arg_is_weight,
node_arg_is_bias,
NON_QUANTIZABLE_WEIGHT_OPS,
ObservedGraphModuleAttrs,
)
from torch.ao.quantization import (
PlaceholderObserver
)
from torch.ao.quantization.quantize import (
convert
)
from ..utils import (
_parent_name,
get_qconfig_dtypes,
get_swapped_custom_module_class,
activation_is_statically_quantized,
)
from ..backend_config.utils import (
get_pattern_to_dtype_configs,
get_module_to_qat_module,
get_fusion_pattern_to_root_node_getter,
)
from ..backend_config import (
BackendConfig,
DTypeConfig,
get_native_backend_config,
)
from .custom_config import (
PrepareCustomConfig,
StandaloneModuleConfigEntry,
)
from torch._subclasses import FakeTensor
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union, Callable
__all__ = [
"insert_observers_for_model",
"prepare",
"propagate_dtypes_for_known_nodes",
]
# list of dtypes to not add observers to
_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
# note: the following default target dtype info dicts are temporary and
# should be moved to the new programmable API class soon
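# Each dict below maps a node's "target_dtype_info" keys to observer constructors: one for
# the node's input activation and one for its output activation. Both defaults come from
# placeholder qconfigs (fp32 and quint8 respectively), i.e. observers that only record
# dtype information and do not compute quantization parameters.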
_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
"input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
"output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation
}
_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
"input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
"output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation
}
def _is_activation_post_process_node(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool:
return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
_is_activation_post_process(named_modules[str(node.target)])
def _get_dtype_and_is_dynamic(obs_or_fq_ctr: Optional[Callable]) -> Tuple[Optional[torch.dtype], bool]:
""" Given a constructor for observer or fake quant module, returns
a Tuple of dtype and is_dynamic
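
    Illustrative example (hedged; assumes MinMaxObserver's defaults, i.e. a static
    quint8 observer):

        _get_dtype_and_is_dynamic(MinMaxObserver)  # -> (torch.quint8, False)
        _get_dtype_and_is_dynamic(None)            # -> (None, False)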
"""
# TODO: instead of instantiating the instance, we can use inspect to get the default args
if obs_or_fq_ctr is None:
return None, False
else:
obs_or_fq = obs_or_fq_ctr()
return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False)
def _is_input_arg_dtype_supported_by_backend(
arg: Argument,
node: Node,
qconfig: QConfigAny,
dtype_config: DTypeConfig,
backend_config: BackendConfig,
) -> bool:
""" Check if the configured qconfig for the argument
is supported by the backend or not
"""
if isinstance(arg, (list, tuple)):
return all(_is_input_arg_dtype_supported_by_backend(
a, node, qconfig,
dtype_config, backend_config) for a in arg)
if not isinstance(arg, Node):
return True
# TODO: support check for standalone module
is_weight = node_arg_is_weight(node, arg, backend_config)
is_bias = node_arg_is_bias(node, arg, backend_config)
is_activation = not is_weight and not is_bias
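    # Each argument plays exactly one of three roles (activation, weight, or bias); the
    # backend's DTypeConfig constrains each role separately below, and a None dtype in
    # the DTypeConfig means the backend places no requirement on that role.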
if is_activation:
input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
# TODO(future PR): remove the cast to bool below after figuring
# out why backend_config has is_dynamic set to None in some cases.
return (dtype_config.input_dtype is None) or (
dtype_config.input_dtype == qconfig_dtype and
bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) and
_qconfig_satisfies_dtype_config_constraints(qconfig, dtype_config.input_dtype_with_constraints)
)
elif is_weight:
# TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None)
qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq_ctr)
backend_config_weight_dtype = dtype_config.weight_dtype
dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False)
return backend_config_weight_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
else: # bias
# TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None)
qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq_ctr)
backend_config_bias_dtype = dtype_config.bias_dtype
return backend_config_bias_dtype is None or qconfig_bias_dtype == backend_config_bias_dtype
def _is_output_dtype_supported_by_backend(
node: Node,
qconfig: QConfigAny,
dtype_config: DTypeConfig,
) -> bool:
""" Check if the configured qconfig for the output
is supported by the backend or not
"""
# TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
backend_config_output_dtype = dtype_config.output_dtype
# TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
# from input activation check can be reused here
output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr")
qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
# TODO: this is a hack because we can only specify one activation_obs_or_fq for
# qconfig (qconfig.activation), and we are only supporting dynamically quantized
# linear op which has fp32 output dtype, this should be removed if we generalize
# the structure of qconfig in the future
if qconfig_output_is_dynamic:
qconfig_output_dtype = torch.float32
dtype_matches = qconfig_output_dtype == backend_config_output_dtype
qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
qconfig, dtype_config.output_dtype_with_constraints)
return backend_config_output_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
def _is_observer_in_same_graph(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool:
    """ Check whether the observer for `node` lives in the same graph.
    When the node's target output dtype is quint8 and its input is a 'placeholder',
    the input is assumed to be quantized already, so it is observed
    in a different place rather than not observed at all.
    """
node_output_dtype = _get_arg_target_dtype_as_output(node, named_modules)
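    # If this node reads directly from a graph input ("placeholder") and its target output
    # dtype is quint8, assume the input is already quantized and observed outside this
    # graph, so the observer does not live here.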
if len(node.args) > 0 and isinstance(node.args[0], Node):
if node_output_dtype == torch.quint8 and node.args[0].op == 'placeholder':
return False
return True
def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
pattern: Optional[Pattern],
matched_node_pattern: Optional[List[Node]],
qconfig: QConfigAny,
backend_config: BackendConfig,
) -> bool:
""" Check if the dtype configuration of a pattern is supported by
the backend or not, and whether the qconfig satisfies constraints
specified in the corresponding dtype config.
"""
if backend_config is None or pattern is None:
return True
assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
root_node = root_node_getter(matched_node_pattern)
input_node = root_node
output_node = matched_node_pattern[0]
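    # The root node receives the pattern's external inputs, while matched_node_pattern[0]
    # is the last node of the fused pattern and produces its output; each DTypeConfig is
    # checked against both, and the pattern is accepted as soon as any config matches.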
for dtype_config in dtype_configs:
        # check if arg dtypes are supported
supported = True
for arg in list(input_node.args) + list(input_node.kwargs.values()):
supported = supported and _is_input_arg_dtype_supported_by_backend(
arg, input_node, qconfig, dtype_config, backend_config)
# check if output dtype is supported
supported = supported and _is_output_dtype_supported_by_backend(
output_node, qconfig, dtype_config)
if supported:
return True
return False
def _get_standalone_module_configs(
node: Node,
named_modules: Dict[str, torch.nn.Module],
prepare_custom_config: PrepareCustomConfig,
parent_qconfig: QConfigAny,
parent_backend_config: Optional[BackendConfig],
) -> Tuple[QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]]:
"""
    Returns the standalone module qconfig_mapping, example_inputs,
    prepare_custom_config and backend_config for `node`, assuming that
    the module pointed to by `node` is a standalone module.
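
    Hedged usage sketch (names below are illustrative): a standalone module is
    typically registered on the parent's PrepareCustomConfig, e.g.

        PrepareCustomConfig().set_standalone_module_name(
            "submodule", qconfig_mapping, (example_input,),
            PrepareCustomConfig(), backend_config)

    Fields left as None in the resulting StandaloneModuleConfigEntry fall back to
    the parent's qconfig (wrapped in a global QConfigMapping), an empty
    PrepareCustomConfig, and the parent's BackendConfig, respectively.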
"""
module_name = str(node.target)
module_type = type(named_modules[module_name]) # type: ignore[index]
# name config has precedence over type config
config_entry = StandaloneModuleConfigEntry(None, (), None, None)
config_entry = prepare_custom_config.standalone_module_classes.get(module_type, config_entry)
config_entry = prepare_custom_config.standalone_module_names.get(module_name, config_entry)
    # fall back to the parent module's qconfig if the user didn't specify a qconfig mapping
qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(parent_qconfig)
example_inputs = config_entry.example_inputs
prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
backend_config = config_entry.backend_config or parent_backend_config
return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
def _qat_swap_modules(
root: torch.nn.Module,
module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]]) -> None:
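    # Reuse eager-mode `convert` purely as a module-swapping utility: with
    # `module_to_qat_module` as the mapping it replaces float modules with their QAT
    # counterparts in place, and remove_qconfig=False keeps each module's qconfig attached.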
convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
if isinstance(matched_node_pattern, Node):
s.add(matched_node_pattern.name)
elif isinstance(matched_node_pattern, (list, tuple)):
for maybe_node in matched_node_pattern:
_add_matched_node_name_to_set(maybe_node, s)
def _insert_observer(
node: Node,
observer: ObserverBase,
model: torch.nn.Module,
named_modules: Dict[str, torch.nn.Module],
graph: Graph,
) -> Node:
"""
Attaches `observer` to `model`, and creates a node which calls
`observer` on the output of `node`.
"""
model_device = assert_and_get_unique_device(model)
if model_device:
observer.to(model_device)
# add observer module as attribute
if is_equalization_observer(observer):
prefix = node.name + '_equalization_process_'
else:
prefix = 'activation_post_process_'
get_new_observer_name = get_new_attr_name_with_prefix(prefix)
observer_name = get_new_observer_name(model)
setattr(model, observer_name, observer)
named_modules[observer_name] = observer
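    # Create the observer call right after `node` in the graph. Note that this helper does
    # not rewire any users of `node`; the caller is responsible for redirecting the uses
    # that should consume the observed value.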
with graph.inserting_after(node):
new_obs = graph.create_node(
'call_module', observer_name, (node,), {})
return new_obs
def _set_target_dtype_info_for_matched_node_pattern(
matched_node_pattern: NodePattern,
last_node: Node,
qconfig: QConfigAny,
backend_config: BackendConfig,
named_modules: Dict[str, torch.nn.Module],
cache_for_no_tensor_check: Dict[Node, bool],
processed_nodes: Set[Node],
) -> None:
""" Sets the target_dtype_info for each node in matched_node_pattern
Note: processed_nodes is used to ensure we only process each node once
"""
if isinstance(matched_node_pattern, (list, tuple)):
for node_pattern in matched_node_pattern:
_set_target_dtype_info_for_matched_node_pattern(
node_pattern,
last_node,
                qconfig,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes)