"""
This module contains tooling to compare weights and activations
across models. Example usage::
import copy
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.ao.ns._numeric_suite_fx as ns
m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
# We convert a copy because we need the original prepared model
# to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
mq = quantize_fx.convert_fx(copy.deepcopy(mp))
#
# Comparing weights
#
# extract weight pairs
weight_comparison = ns.extract_weights('a', mp, 'b', mq)
# add SQNR for each comparison, inplace
ns.extend_logger_results_with_comparison(
weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
'sqnr')
# weight_comparison contains the weights from `mp` and `mq` stored
# in pairs, and can be used for further analysis.
#
# Comparing activations, with error propagation
#
# add loggers
mp_ns, mq_ns = ns.add_loggers(
'a', copy.deepcopy(mp),
'b', copy.deepcopy(mq),
ns.OutputLogger)
# send an example datum to capture intermediate activations
datum = torch.randn(1, 1, 1, 1)
mp_ns(datum)
mq_ns(datum)
# extract intermediate activations
act_comparison = ns.extract_logger_info(
mp_ns, mq_ns, ns.OutputLogger, 'b')
# add SQNR for each comparison, inplace
ns.extend_logger_results_with_comparison(
act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
'sqnr')
# act_comparison contains the activations from `mp_ns` and `mq_ns` stored
# in pairs, and can be used for further analysis.
#
# Comparing activations, without error propagation
#
# create shadow model
mp_shadows_mq = ns.add_shadow_loggers(
'a', copy.deepcopy(mp),
'b', copy.deepcopy(mq),
ns.OutputLogger)
# send an example datum to capture intermediate activations
datum = torch.randn(1, 1, 1, 1)
mp_shadows_mq(datum)
# extract intermediate activations
shadow_act_comparison = ns.extract_shadow_logger_info(
mp_shadows_mq, ns.OutputLogger, 'b')
# add SQNR for each comparison, inplace
ns.extend_logger_results_with_comparison(
shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
'sqnr')
# shadow_act_comparison contains the activations of models 'a' and 'b', as
# captured by `mp_shadows_mq`, stored in pairs, and can be used for further analysis.
"""
import collections
import torch
import torch.nn as nn
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.ns.fx.mappings import (
get_base_name_to_sets_of_related_ops,
)
from torch.ao.ns.fx.graph_matcher import (
get_matching_subgraph_pairs,
get_type_a_related_to_b,
)
from .fx.weight_utils import (
extract_weight_from_node,
)
from .fx.graph_passes import (
add_loggers_to_model,
create_a_shadows_b,
)
from .fx.utils import (
rekey_logger_info_on_node_name_of_model,
maybe_add_missing_fqns,
get_target_type_str,
)
from .fx.ns_types import (
NSSingleResultValuesType,
NSResultsType,
NSNodeTargetType,
)
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.fx.match_utils import _find_matches
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization import QConfigMapping
from torch.ao.ns.fx.n_shadows_utils import (
OutputProp,
_get_dedup_subgraphs,
SHADOW_WRAPPER_NODE_NAME_PREFIX,
group_results_by_subgraph,
create_results_comparison,
print_n_shadows_summary,
create_n_transformed_and_logged_copies_of_subgraph,
create_add_loggers_graph,
extract_weight_comparison,
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type
# Nested-tuple value `(tensor, (tensor, tensor))` which `OutputLogger.forward`
# stores in `stats_rnn`.
# NOTE(review): presumably the `(output, (h, c))` return of LSTM-like
# modules - confirm against the callers which attach loggers.
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class OutputLogger(nn.Module):
    """
    Base class for capturing intermediate values.

    ``forward`` always returns its input unchanged. While both ``enabled``
    and ``save_activations`` are true, it additionally records the detached
    input: tensors are appended to ``stats`` and two-element tuples whose
    second element has length two are appended to ``stats_rnn``.

    Args:
        ref_node_name: name of the node which was responsible for adding
            this logger
        prev_node_name: name of the node whose output this logger captures
        model_name: name of the model from which the node originated
        ref_name: reference name, used to match loggers from separate
            models to each other
        prev_node_target_type: type of the target of the node whose output
            this logger is logging
        ref_node_target_type: type of the target of the node which was
            responsible for adding this logger
        results_type: what kind of values are inside of ``stats``
        index_within_arg: index of this node within the arg of the
            input/output node
        index_of_arg: index of this node within the args of the
            input/output node
        fqn: fully qualified name
        qconfig_str: string representation of qconfig
    """
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
        qconfig_str: Optional[str] = '',
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        #    the output of op1:
        #
        #  x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        #   - logger1's prev_node_name is x1 and ref_node_name is op1
        #   - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated from
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn
        # if loggers are added before prepare_fx, we may not want to collect
        # the results of calibration, only the results after convert_fx;
        # this flag controls whether this logger collects data
        self.enabled = True
        # string representation of qconfig
        self.qconfig_str = qconfig_str
        # this can be turned off to reduce memory usage during calibration
        self.save_activations = True

    # Note: cannot annotate the type of x because TorchScript does not support
    #   the Union type.
    def forward(self, x):
        """
        """  # blank docblock to make autodoc happy
        # TODO(future PR): consider designing this better, as the difference
        # between these two flags is subtle and not obvious.
        if not self.enabled:
            return x
        if not self.save_activations:
            return x
        # TODO(future PR): consider refactoring this to better reuse the parent
        # class
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            # value shaped like RNNReturnType: detach every tensor in it
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        # any other input is passed through without being recorded
        return x

    def __repr__(self):
        clean_dict = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != 'training') and not k.startswith('_')
        }
        # Use the runtime class name so that subclasses which do not override
        # __repr__ are not mislabeled as "OutputLogger".
        return f"{type(self).__name__}({clean_dict})"
class OutputComparisonLogger(OutputLogger):
    """
    Variant of OutputLogger whose ``forward`` also receives the original
    (reference) activation, so that the comparison between the two can be
    calculated at calibration time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO(future PR): make the comparison function configurable
        self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
        self.comparison_fn_name = 'sqnr'
        # results of `comparison_fn(x, x_ref)`, one entry per logged call
        self.comparisons = []

    def forward(self, x, x_ref):
        """
        """  # blank docblock to make autodoc happy
        if not self.enabled:
            return x
        assert isinstance(x, torch.Tensor), 'non-tensor inputs not yet supported'
        # the raw activation is optional (kept for debugging) ...
        if self.save_activations:
            self.stats.append(x.detach())
        # ... while the comparison against the reference is always recorded
        comparison = self.comparison_fn(x, x_ref)
        self.comparisons.append(comparison)
        return x

    def __repr__(self):
        visible_attrs = {}
        for attr_name, attr_value in self.__dict__.items():
            # skip nn.Module keys
            if attr_name == 'training' or attr_name.startswith('_'):
                continue
            visible_attrs[attr_name] = attr_value
        return f"OutputComparisonLogger({visible_attrs})"
class NSTracer(quantize_fx.QuantizationTracer):
    """
    FX quantization tracer which, in addition to the regular leaf rules of
    its parent, treats observer and fake_quantize modules as leaf modules.
    """

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        """  # blank docblock to make autodoc happy
        always_leaf_types = (
            torch.ao.quantization.ObserverBase,
            torch.ao.quantization.FakeQuantizeBase,
        )
        if isinstance(m, always_leaf_types):
            return True
        # everything else follows the parent tracer's decision
        return super().is_leaf_module(m, module_qualified_name)
def _extract_weights_one_model(
model_name: str,
model: GraphModule,
nodes_and_names_to_instrument: List[Tuple[Node, str]],
results: NSResultsType,
op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> None:
torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
for node, ref_name in nodes_and_names_to_instrument:
res_type = NSSingleResultValuesType.WEIGHT.value
extracted_weight = extract_weight_from_node(
node, model, op_to_type_to_weight_extraction_fn)
if extracted_weight:
if ref_name not in results:
results[ref_name] = {res_type: {}}
results[ref_name][res_type][model_name] = [extracted_weight]
def _extract_weights_impl(
model_name_a: str,
gm_a: GraphModule,
model_name_b: str,
gm_b: GraphModule,
base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
matched_subgraph_pairs = get_matching_subgraph_pairs(
gm_a, gm_b, base_name_to_sets_of_related_ops,
unmatchable_types_map)
# split the subgraph pairs into one data structure for each model
nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
for match_name, match in matched_subgraph_pairs.items():
subgraph_a, subgraph_b = match
nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
# populate the results, one model at a time
Loading ...