Repository URL to install this package:
|
Version:
2.4.0 ▾
|
# mypy: allow-untyped-defs
from typing import Any, Dict, Set, Tuple, Callable, List
import torch
import torch.nn as nn
import torch.ao.nn.qat as nnqat
from abc import ABC, abstractmethod
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
from torch.ao.quantization.qconfig import (
QConfig,
default_qconfig,
_assert_valid_qconfig,
)
from torch.ao.quantization.observer import (
ObserverBase,
default_dynamic_quant_observer,
default_per_channel_weight_observer,
default_observer,
default_weight_observer,
)
from torch.ao.quantization.fx._equalize import (
default_equalization_qconfig,
EqualizationQConfig,
)
from torch.ao.quantization.observer import _is_activation_post_process
# Names for observer insert keys
# These are the keys of the per-observer info dicts that the detectors in this
# file return from determine_observer_insert_points (one dict per observer fqn)
# -> the node the observer is inserted around
DETECTOR_TARGET_NODE_KEY = "target_node"
# -> the observer instance that should be inserted
DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"
# -> True if the observer goes after the target node, False if it goes before it
DETECTOR_IS_POST_OBS_KEY = "is_post_observer"
# -> the arguments that are meant to be passed into the inserted observer
DETECTOR_OBS_ARGS_KEY = "observer_args"
# Mapping related code
class DetectorQConfigInfo:
    r"""
    Holds the QConfig related recommendations gathered for a single module.

    The set of tracked recommendations can grow together with the qconfig
    mapping feature set, but it currently covers:
    - whether the activation observer should be dynamic
    - whether the weight observer should be per channel
    - whether input-weight equalization is recommended

    Args:
        module_fqn (str): The fully qualified name (fqn) of the module that
            these recommendations refer to
    """

    def __init__(self, module_fqn: str):
        super().__init__()
        self.module_fqn = module_fqn

        # equalization related recommendation
        self.is_equalization_recommended = False

        # quantization related recommendations, every one starts out as "not recommended"
        # and is flipped by the detector that is responsible for it
        self.is_weight_per_channel = False
        self.is_activation_dynamic = False

    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
        r"""
        Builds the quantization QConfig that follows as many recommendations as are valid.

        Args:
            module (torch.nn.Module) The module we are generating
                the qconfig for

        Returns the first candidate QConfig that is valid for the module, or default_qconfig
        if none of the candidates is valid
        """
        # (dynamic activation, per channel weight) candidates in priority order
        candidates = (
            # follow both recommendations
            (self.is_activation_dynamic, self.is_weight_per_channel),
            # follow only the dynamic recommendation
            (self.is_activation_dynamic, False),
            # follow only the per channel recommendation
            (False, self.is_weight_per_channel),
        )

        for use_dynamic, use_per_channel in candidates:
            activation_ctr = default_dynamic_quant_observer if use_dynamic else default_observer
            weight_ctr = default_per_channel_weight_observer if use_per_channel else default_weight_observer
            candidate_qconfig = QConfig(activation_ctr, weight_ctr)

            try:
                _assert_valid_qconfig(candidate_qconfig, module)
            except AssertionError:
                # not valid for this module, fall through to the next candidate in priority
                continue

            return candidate_qconfig

        # none of the candidates was accepted, so we stay with the default
        return default_qconfig

    def generate_equalization_qconfig(self) -> EqualizationQConfig:
        r"""
        Gives the equalization configuration for the module.

        Only the default equalization config exists for now; once more equalization
        options become possible this method can pick between them.

        Returns the default equalization QConfig
        """
        # the default is used as is: only modules for which equalization
        # is valid would be offered this option in the first place
        return default_equalization_qconfig
# Adding base class for detectors
class DetectorBase(ABC):
    r""" Base Detector Module
    Every detector class should derive from this class.

    Concrete detectors share the same general API, made up of:
    - A method that calculates and returns the observer insertion points
        - It should return both the fqns and the Observer class to insert
    - A method that returns a report based on the detector
        - It should return a str-based report and dict info in Tuple[str,Dict] format
    """

    def __init__(self):
        super().__init__()
        self.detector_config_info = None

    @abstractmethod
    def determine_observer_insert_points(self, model) -> Dict:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
            This dict maps string keys to detector specific information
        """

    @abstractmethod
    def get_detector_name(self) -> str:
        r""" Returns the name of the current detector """

    @abstractmethod
    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """

    def _get_targeting_node(self, prepared_fx_model: GraphModule, target_fqn: str) -> torch.fx.node.Node:
        r"""
        Finds the graph node whose target is target_fqn, or the target of its closest parent fqn.

        A module that got fused keeps its own fqn (e.g. x.linear.0) while the graph only
        holds a node for the fused module (e.g. x.linear), so the two would not match.
        Whenever no node matches, the last piece of the fqn is dropped
        (foo.bar.baz -> foo.bar) and the search is repeated, which locates the
        correct node even when fusion happened.
        Once there is no "." left to drop, the fqn does not exist in the graph.

        Args:
            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
            target_fqn (str): The fqn of the layer we are trying to target

        Returns the node object we are trying to add observers around

        Raises ValueError if neither target_fqn nor any of its parent fqns is the target of a node
        """
        candidate_fqn = target_fqn

        while True:
            # look for a node that targets the fqn we are currently trying
            for graph_node in prepared_fx_model.graph.nodes:
                if graph_node.target == candidate_fqn:
                    return graph_node

            # no match: move one level up, unless we already are at the base
            parent_fqn, separator, _ = candidate_fqn.rpartition(".")
            if not separator:
                raise ValueError("passed in target_fqn not found in graph's targets.")
            candidate_fqn = parent_fqn

    @abstractmethod
    def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Tuple of two elements:
            Str: string report of the suggested improvements
            Dict: contains useful data collected by the observer pertinent to this report
        """
class PerChannelDetector(DetectorBase):
    r""" This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
    Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.

    per_channel quantization can lead to major benefits in the form of accuracy.
    Therefore, if the backend used by the user supports it, it is recommended to use

    Args:
        backend (str, optional): the backend the user wishes to use in production
            Default value is torch.backends.quantized.engine, read once when this class is defined
    """

    # Keys for return dictionary
    BACKEND_KEY = "backend"
    PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
    PER_CHAN_USED_KEY = "per_channel_quantization_used"

    # Default map for representing supported per channel quantization modules for different backends
    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
        "fbgemm": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
        "qnnpack": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
        "onednn": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
        "x86": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
    }

    def __init__(self, backend: str = torch.backends.quantized.engine):
        r"""
        Args:
            backend (str, optional): the backend the user wishes to use in production

        Raises ValueError if the backend has no entry in DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES
        """
        super().__init__()

        # store the backend information
        self.backend_chosen = backend
        self.supported_modules = set()
        if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
            self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen]
        else:
            raise ValueError(f"Not configured to work with {self.backend_chosen}. Try a different default backend")

    def get_detector_name(self) -> str:
        r""" returns the string name of this detector"""
        return "per_channel_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from module fqns to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn, module_info in per_channel_info.items():
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # per channel weights are recommended whenever the backend supports them
            per_chan_supported: bool = module_info[self.PER_CHAN_SUPPORTED_KEY]
            detector_qconfig_info.is_weight_per_channel = per_chan_supported
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(self, model: nn.Module) -> Dict:
        r"""
        There is no observers inserted for the PerChannelDetector.

        Returns an empty dictionary since no observers are added or needed
        """
        return {}

    def _detect_per_channel_helper(self, model: nn.Module):
        r"""
        determines if per_channel quantization is supported in modules and submodules.

        Each entry of the returned dictionary maps the fully-qualified-name of a module
        of a supported type to information on per_channel quantization for it.
        Modules of other types get no entry.

        Args:
            model: The current module that is being checked to see if it is per_channel quantizable

        Returns dictionary mapping fqns to:
            whether per_channel quantization is supported by the backend
            whether per_channel quantization is used by the module's qconfig
            the backend that was checked
        """
        # create dict we will return
        per_channel_info: Dict = {}

        # isinstance wants a tuple, build it once instead of once per module
        supported_types = tuple(self.supported_modules)

        for fqn, module in model.named_modules():
            # only modules the chosen backend can quantize per channel are reported
            if not isinstance(module, supported_types):
                continue

            # assert statement for MyPy
            q_config_file = module.qconfig
            assert isinstance(q_config_file, QConfig)

            # Build the weight observer / fake quant exactly as the qconfig configures it.
            # Going through `weight.p.func()` would call the bare class and drop every
            # argument bound with `with_args`, so a FakeQuantize configured with a
            # per channel observer would come out as a default per tensor one.
            q_or_s_obj = q_config_file.weight()
            # this object should either be fake quant or observer
            assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))

            # only objects that carry a channel axis can be per channel
            has_ch_axis = hasattr(q_or_s_obj, "ch_axis")
            if isinstance(q_or_s_obj, FakeQuantize):
                # all fake quants have channel axis so need to check is_per_channel
                per_channel_used = has_ch_axis and bool(getattr(q_or_s_obj, "is_per_channel", False))
            else:
                # an observer only has a channel axis if it is a per channel observer
                per_channel_used = has_ch_axis

            per_channel_info[fqn] = {
                self.PER_CHAN_SUPPORTED_KEY: True,
                self.PER_CHAN_USED_KEY: per_channel_used,
                self.BACKEND_KEY: self.backend_chosen,
            }

        return per_channel_info

    def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any]]:
        r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
        Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.

        Looks at q_config format and backend to determine if per_channel can be utilized.
        Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support

        Args:
            model: The prepared and calibrated model we want to check if using per_channel

        Returns a tuple with two elements:
            String report of potential actions to improve model (if per_channel quantization is available in backend)
            Dictionary mapping per_channel quantizable elements to:
                whether per_channel quantization is supported by the backend
                if it is being utilized in the current model
        """
        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # String to let the user know of further optimizations
        further_optims_str = f"Further Optimizations for backend {self.backend_chosen}: \n"

        optimizations_possible = False
        for fqn, fqn_dict in per_channel_info.items():
            # supported but unused per channel quantization is the optimization we point out
            if fqn_dict[self.PER_CHAN_SUPPORTED_KEY] and not fqn_dict[self.PER_CHAN_USED_KEY]:
                optimizations_possible = True
                further_optims_str += f"Module {fqn} can be configured to use per_channel quantization.\n"

        if optimizations_possible:
            further_optims_str += (
                "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
            )
        else:
            further_optims_str += "No further per_channel optimizations possible."

        # return the string and the dictionary form of same information
        return (further_optims_str, per_channel_info)
class DynamicStaticDetector(DetectorBase):
    r"""
    Recommends dynamic or static quantization for each supported module.

    Relies on the ModelReportObserver instances placed before and after a module and on
    their comparison statistic:

        S = average_batch_activation_range/epoch_activation_range

    Data whose S is strictly above the tolerance level is classified as stationary,
    data whose S is below or at the tolerance level is classified as non-stationary.

    If the data right after the module is non-stationary, dynamic quantization is recommended
        Otherwise static quantization is recommended

    Args:
        tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5
    """
    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
    DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"

    # naming conventions for stationary vs non-stationary data
    STATIONARY_STR = "stationary"
    NON_STATIONARY_STR = "non-stationary"

    # naming for activation
    INPUT_ACTIVATION_PREFIX = "input_activation_"
    OUTPUT_ACTIVATION_PREFIX = "output_activation_"

    # naming conventions for the keys of the return module info
    TOLERANCE_KEY = "dynamic_static_tolerance"
    DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
    PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    PRE_OBS_DATA_DIST_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    POST_OBS_DATA_DIST_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"

    # modules that are supported both dynamic and static for this report function
    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}

    # modules that will be supported soon for both
    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}

    def __init__(self, tolerance=0.5):
        super().__init__()

        # fqn locations found to be useful are collected in this set
        self.useful_observer_fqns: Set[str] = set()

        # threshold of the S metric between stationary and non-stationary
        self.tolerance = tolerance

    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
        r"""
        Works out where observers have to go for the Dynamic vs Static detector.
        This detector wants one observer on either side of every supported layer in the model.

        Args:
            prepared_fx_model (GraphModule):  The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """
        insert_points: Dict[str, Dict[str, Any]] = {}

        for module_fqn, submodule in prepared_fx_model.named_modules():
            # skip everything that is not of a supported type
            if not self._is_supported(submodule, insert=True):
                continue

            # the graph node both observers are placed around
            node_to_observe = self._get_targeting_node(prepared_fx_model, module_fqn)

            # pre-observer: sees what the node receives
            insert_points[f"{module_fqn}.{self.DEFAULT_PRE_OBSERVER_NAME}"] = {
                DETECTOR_TARGET_NODE_KEY: node_to_observe,
                DETECTOR_OBS_TO_INSERT_KEY: ModelReportObserver(),
                DETECTOR_IS_POST_OBS_KEY: False,
                DETECTOR_OBS_ARGS_KEY: node_to_observe.args,
            }

            # post-observer: sees what the node produces
            insert_points[f"{module_fqn}.{self.DEFAULT_POST_OBSERVER_NAME}"] = {
                DETECTOR_TARGET_NODE_KEY: node_to_observe,
                DETECTOR_OBS_TO_INSERT_KEY: ModelReportObserver(),
                DETECTOR_IS_POST_OBS_KEY: True,
                DETECTOR_OBS_ARGS_KEY: (node_to_observe,),
            }

        return insert_points

    def get_detector_name(self) -> str:
        r""" returns the string name of this detector"""
        return "dynamic_vs_static_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from module fqns to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        fqn_to_qconfig_info = {}

        for module_fqn, module_info in self._generate_dict_info(model).items():
            qconfig_info = DetectorQConfigInfo(module_fqn)

            # carry over whether dynamic quantization was recommended for this module
            is_dynamic_recommended: bool = module_info[self.DEFAULT_DYNAMIC_REC_KEY]
            qconfig_info.is_activation_dynamic = is_dynamic_recommended

            fqn_to_qconfig_info[module_fqn] = qconfig_info

        return fqn_to_qconfig_info

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args
            module: The module to check and ensure is supported
            insert: True if this is check for observer insertion, false if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # both the currently supported and the soon to be supported types count
        accepted_types = tuple(
            self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED | self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED
        )
        type_is_accepted = isinstance(module, accepted_types)

        if insert:
            # for observer insertion the type is all that matters
            return type_is_accepted

        # for report gen the module must also carry both of its observers
        observers_attached = all(
            hasattr(module, obs_name)
            for obs_name in (self.DEFAULT_PRE_OBSERVER_NAME, self.DEFAULT_POST_OBSERVER_NAME)
        )
        return type_is_accepted and observers_attached

    def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level to decided whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """
        collected_info = {}

        # walk every submodule (nested ones included) and keep the ones
        # that are of a supported type and have both observers attached
        for module_fqn, submodule in model.named_modules():
            if not self._is_supported(submodule):
                continue

            # S metric recorded in front of and behind the module
            input_stat = getattr(submodule, self.DEFAULT_PRE_OBSERVER_NAME).get_batch_to_epoch_ratio()
            output_stat = getattr(submodule, self.DEFAULT_POST_OBSERVER_NAME).get_batch_to_epoch_ratio()

            # strictly above the tolerance means stationary, everything else is non-stationary
            if input_stat > self.tolerance:
                input_classification = self.STATIONARY_STR
            else:
                input_classification = self.NON_STATIONARY_STR

            if output_stat > self.tolerance:
                output_classification = self.STATIONARY_STR
            else:
                output_classification = self.NON_STATIONARY_STR

            collected_info[module_fqn] = {
                self.TOLERANCE_KEY: self.tolerance,
                # dynamic is recommended when the data behind the module is non-stationary
                self.DEFAULT_DYNAMIC_REC_KEY: output_stat <= self.tolerance,
                self.PRE_OBS_COMP_STAT_KEY: input_stat,
                self.PRE_OBS_DATA_DIST_KEY: input_classification,
                self.POST_OBS_COMP_STAT_KEY: output_stat,
                self.POST_OBS_DATA_DIST_KEY: output_classification,
                # True for currently supported types, False for the future supported ones
                self.IS_CURRENTLY_SUPPORTED_KEY: isinstance(
                    submodule, tuple(self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED)
                ),
            }

        return collected_info

    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
        r"""
        Recommends dynamic or static quantization for every module that has observers around it.

        Relies on the ModelReportObserver that records range information.
        Stationary distribution of data are strictly above tolerance level for the comparison statistic:

            S = average_batch_activation_range/epoch_activation_range

        Nonstationary distributions are below or at the tolerance level for this metric.

        If the distribution of data right after the module is non-stationary, recommend dynamic quantization
            Otherwise recommend static quantization

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of of whether dynamic or static quantization is recommended for certain modules
            Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level to decided whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """
        # the information the string report is put together from
        module_dynamic_static_info = self._generate_dict_info(model)

        # pieces of text that are the same for every module
        suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n"
        reasoning_template = "the distribution of data before {} is {} and the distribution after is {}."
        dynamic_benefit = " You will get more accurate results if you use dynamic quantization"
        static_benefit = " You can increase model efficiency if you use static quantization"
        future_support_str = ". This layer is not yet supported for dynamic quantization"

        # text for when a dynamic quantize per tensor layer is advised in front of the module
        recommend_per_tensor = ". We recommend to add a {} before this module if it is static."
        rec_lay_to_add = "dynamic quantize per tensor layer"
        dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
        dynamic_per_tensor_reasoning_string = (
            " This is because the input to this module has a non-stationary distribution"
        )

        report_parts = ["Dynamic vs. Static Quantization suggestions: \n"]

        for module_fqn, module_info in module_dynamic_static_info.items():
            input_classification = module_info[self.PRE_OBS_DATA_DIST_KEY]
            output_classification = module_info[self.POST_OBS_DATA_DIST_KEY]

            # pick the quantization type together with the benefit it brings
            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
                quantization_type = "dynamic"
                benefit_str = dynamic_benefit
                # let the user know when the layer is only planned to be supported
                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
                    benefit_str += future_support_str
            else:
                quantization_type = "static"
                benefit_str = static_benefit

            # the explanation: what the data looks like, followed by the benefit
            explanation = reasoning_template.format(module_fqn, input_classification, output_classification)
            explanation += benefit_str

            # non-stationary input -> module -> stationary output gets a static suggestion,
            # on top of which we advise a dynamic quantize per tensor layer in front of it
            needs_per_tensor_hint = (
                input_classification == self.NON_STATIONARY_STR
                and output_classification == self.STATIONARY_STR
            )
            if needs_per_tensor_hint:
                explanation += dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string

            report_parts.append(
                suggestion_string_template.format(module_fqn, quantization_type, explanation)
            )

        # only the header is present when not a single module got a suggestion
        if len(report_parts) == 1:
            report_parts.append("No applicable layers for suggestions. Only linear and conv are valid.\n")

        # return the string as well as the dictionary of information
        return ("".join(report_parts), module_dynamic_static_info)
class InputWeightEqualizationDetector(DetectorBase):
r"""
Determines whether input-weight equalization can help improve quantization for certain modules.
Specifically, this list of modules includes:
linear
conv
Determines whether input-weight equalization is recommended based on the comp stat:
s_c = sqrt(w_c/W)/sqrt(i_c/I)
where:
w_c is range of weight for channel c, W is range of weight over all channels
i_c is range of input for channel c, I is range of input over all channels
if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization
Args:
ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
Should be between 0 and 1 (both non-inclusive)
ch_axis (int, optional): The channel axis being observed to determine input weight equalization
Default: 1
* :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
Should be between 0 and 1
* :attr:`ch_axis`: The channel axis being observed to determine input weight equalization
* :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization
* :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
"""
SUPPORTED_MODULES: Set[Callable] = {nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nnqat.Linear,
nnqat.Conv1d,
nnqat.Conv2d,
nnqat.Conv3d}
# names for the pre and post observers that are inserted
DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
# weight / activation prefix for each of the below info
WEIGHT_PREFIX = "weight_"
ACTIVATION_PREFIX = "input_activation_"
# string names for keys of info dictionaries
PER_CHANNEL_MAX_KEY = "per_channel_max"
PER_CHANNEL_MIN_KEY = "per_channel_min"
GLOBAL_MAX_KEY = "global_max"
GLOBAL_MIN_KEY = "global_min"
# keys for return dict of recommendations
RECOMMENDED_KEY = "input_weight_equalization_recommended"
COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
THRESHOLD_KEY = "input_weight_threshold"
CHANNEL_KEY = "input_weight_channel_axis"
# default weight and info strings
WEIGHT_STR = "weight"
INPUT_STR = "input"
# default for what ratio we recommend input weight
DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4
def __init__(self, ratio_threshold: float, ch_axis: int = 1):
# ensure passed in inputs are valid
if ratio_threshold <= 0 or ratio_threshold >= 1:
raise ValueError("Make sure threshold is > 0 and < 1")
# initialize attributes based on args
self.ratio_threshold: float = ratio_threshold
self.ch_axis: int = ch_axis
def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
r"""Returns whether the given module is supported for observers
Args
module: The module to check and ensure is supported
insert: True if this is check for observer insertion, false if for report gen
Returns True if the module is supported by observer, False otherwise
"""
# check to see if module is of a supported type
is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES)
# this is check for observer insertion
if insert:
return is_supported_type
else:
# this is for report gen and we also need to check if it contains observers
has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
return is_supported_type and has_obs
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
r""" Returns the DetectorQConfigInfo for each module_fqn relevant
Args
model (nn.Module or subclass): model to find observer insertion points
Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
A DetectorQConfigInfo with the information to generate a QConfig for a specific module
"""
# run the helper function to populate the dictionary
# find the range of inputs
input_values: Dict[str, Dict] = self._extract_input_info(model)
# find the range of weights
weight_values: Dict[str, Dict] = self._extract_weight_info(model)
# calculate per_channel comparison statistic s_c
comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)
# generate the return dictionary
input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)
# we actually have a qconfig info object we are populating
module_fqn_to_detector_qconfig_info = {}
for module_fqn in input_weight_equalization_info:
# create a detector info instance
detector_qconfig_info = DetectorQConfigInfo(module_fqn)
# see if per channel quantization is supported
input_weight_recommended: bool = input_weight_equalization_info[module_fqn][self.RECOMMENDED_KEY]
detector_qconfig_info.is_equalization_recommended = input_weight_recommended
module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
return module_fqn_to_detector_qconfig_info
def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
    r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.

    A pre-observer entry is produced for every submodule that passes
    ``self._is_supported(module, insert=True)``.

    Args:
        prepared_fx_model (GraphModule): The prepared Fx GraphModule

    Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
        key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
        key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
        key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
        key "observer_args" -> The arguments that are meant to be passed into the observer
    """
    insert_points: Dict[str, Dict[str, Any]] = {}

    for module_fqn, submodule in prepared_fx_model.named_modules():
        # only supported module types get an observer
        if not self._is_supported(submodule, insert=True):
            continue

        # the node that targets this module is the one we observe
        node = self._get_targeting_node(prepared_fx_model, module_fqn)

        # the observer lives directly under the observed module
        observer_fqn = f"{module_fqn}.{self.DEFAULT_PRE_OBSERVER_NAME}"
        insert_points[observer_fqn] = {
            DETECTOR_TARGET_NODE_KEY: node,
            DETECTOR_OBS_TO_INSERT_KEY: ModelReportObserver(ch_axis=self.ch_axis),
            DETECTOR_IS_POST_OBS_KEY: False,
            DETECTOR_OBS_ARGS_KEY: node.args,
        }

    return insert_points
def get_detector_name(self) -> str:
    r"""Returns the identifying name of the input weight equalization detector"""
    detector_name = "input_weight_equalization_detector"
    return detector_name
def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]:
r"""
Takes in a calibrated GraphModule and then finds the relevant observers.
It then extracts the input information for each observer returns it
Args
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
Returns a dict mapping relevant module fqns (str) to a dict with keys:
"input_activation_per_channel_max" : maps to the per_channel max values
"input_activation_per_channel_min" : maps to the per_channel min values
"input_activation_global_max" : maps to the global max recorded
"input_activation_global_min" : maps to the global min recorded
"""
# return dictionary mapping observer fqns to desired info
input_info: Dict[str, Dict] = {}
for fqn, module in model.named_modules():
# if module is supported and it has a pre-observer
if self._is_supported(module):
# get pre observer for the module
pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
input_info[fqn] = {
self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val,
self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val,
self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val),
self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val),
}
return input_info
def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
r"""
Takes in a calibrated GraphModule and then finds the relevant observers.
It then extracts the weight information for each layer an observer is attached to.
Args
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
Returns a dict mapping module fqns (str) to a dict with keys:
"per_channel_max" : maps to the per_channel max values
"per_channel_min" : maps to the per_channel min values
"global_max" : maps to the global max recorded
"global_min" : maps to the global min recorded
"""
# return dictionary mapping observer fqns to desired info
weight_info: Dict[str, Dict] = {}
for fqn, module in model.named_modules():
# if module is supported and it has a pre-observer
if self._is_supported(module):
# we don't need actual observer, just the module weights
# calculate min and max vals
device = module.weight.device
min_val: torch.Tensor = torch.tensor([float('inf')], device=device)
max_val: torch.Tensor = torch.tensor([float('-inf')], device=device)
x_copy = module.weight
x_dim = x_copy.size()
new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
new_axis_list[self.ch_axis] = 0
new_axis_list[0] = self.ch_axis
y = x_copy.permute(new_axis_list)
# Need to match dtype of min/max because the updates to buffers
# are done in place and types need to match for comparisons
y = y.to(min_val.dtype)
y = torch.flatten(y, start_dim=1)
if min_val.numel() == 0 or max_val.numel() == 0:
min_val, max_val = torch.aminmax(y, dim=1)
else:
min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
min_val = torch.min(min_val_cur, min_val)
max_val = torch.max(max_val_cur, max_val)
weight_info[fqn] = {
self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
}
return weight_info
def _calculate_range_ratio(self, info_dict: Dict, info_str: str, module_fqn: str) -> torch.Tensor:
r"""
Takes in an info dict and calculates the s_c matrix.
Args:
info_dict (dict): A dictionary of either input or weight range info
info_str (str): A str describing whether currently looking at weight or input info
Either "weight" or "input"
module_fqn (str): The fqn of the module we are looking at
Returns a tensor of values, where each value is the s_c stat for a different channel
"""
# calculate the ratios of the info
# get the prefix str
prefix_str = self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX
per_channel_range = info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
global_range = info_dict[prefix_str + self.GLOBAL_MAX_KEY] - info_dict[prefix_str + self.GLOBAL_MIN_KEY]
if global_range == 0:
range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
raise ValueError(
f"The range of the {info_str} data for module {module_fqn} is 0, "
f"which means you have a constant value channel. {range_zero_explanation}"
)
ratio = per_channel_range / global_range
return ratio
def _generate_comparison_values(self, input_info: Dict, weight_info: Dict) -> Dict[str, torch.Tensor]:
r"""
Takes in the information on the min and max values of the inputs and weights and:
Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I)
Args:
input_info (dict): A dict mapping each observer to input range information
weight_info (dict): A dict mapping each observer to weight range information
Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
Each value is a different s_c value for a different channel
"""
# create return dictionary for each observer
module_fqn_to_channel: Dict[str, torch.Tensor] = {}
# for each module (both passed in dicts should have same keys)
for module_fqn in input_info:
# raise error if not in weight info
if module_fqn not in weight_info:
raise KeyError(f"Unable to find weight range stats for module {module_fqn}")
# calculate the ratios of the weight info and input info
weight_ratio = self._calculate_range_ratio(weight_info[module_fqn], self.WEIGHT_STR, module_fqn)
input_ratio = self._calculate_range_ratio(input_info[module_fqn], self.INPUT_STR, module_fqn)
# if mismatched size, because of grouping, we want to replicate weight enough times
weight_channels = len(weight_ratio)
input_channels = len(input_ratio)
if weight_channels != input_channels:
# we try to replicate
assert input_channels % weight_channels == 0, "input channels should be divisible by weight channels."
# get replication factor
rep_factor: int = input_channels // weight_channels
# weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n
weight_ratio = weight_ratio.repeat(rep_factor)
# calculate the s metric per channel
s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)
module_fqn_to_channel[module_fqn] = s
# return compiled observer ratios
return module_fqn_to_channel
def _generate_dict_info(self, input_info: Dict, weight_info: Dict, comp_stats: Dict) -> Dict[str, Dict]:
r"""
Helper function for generate_detector_report that does the generation of the dictionary.
This process is done as specified in generate_detector_report documentation
Args:
input_info (dict): A dict mapping each module to input range information
weight_info (dict): A dict mapping each module to weight range information
comp_stats (dict): A dict mapping each module to its corresponding comp stat
Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
whether input weight equalization is recommended
their s_c metric compared to the threshold
the threshold used to make the recommendation
the channel used for recording data
the input channel range info
the weight channel range info
"""
# store modules input weight equalization info
input_weight_equalization_info: Dict[str, Dict] = {}
# for each module we add separate set of suggestions
for module_fqn in input_info:
# get relevant info for this module
mod_input_info: Dict = input_info[module_fqn]
mod_weight_info: Dict = weight_info[module_fqn]
mod_comp_stat: Dict = comp_stats[module_fqn]
# decide if each channel should have input weight equalization or not
channel_rec_vals: list = []
for val in mod_comp_stat:
float_rep: float = val.item()
# decide if recommending input weight equalization
recommended: bool = float_rep >= self.ratio_threshold and float_rep <= 1 / self.ratio_threshold
channel_rec_vals.append(recommended)
# build the return dict input
# also unpack input and weight dicts into it
input_weight_equalization_info[module_fqn] = {
self.RECOMMENDED_KEY: channel_rec_vals,
self.COMP_METRIC_KEY: mod_comp_stat,
self.THRESHOLD_KEY: self.ratio_threshold,
self.CHANNEL_KEY: self.ch_axis,
**mod_input_info,
**mod_weight_info,
}
# return our compiled info for each module
return input_weight_equalization_info
def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
    r"""
    Determines whether input weight equalization is appropriate for a given module.

    Takes advantage of the ModelReport Observer which records per channel information of input range
    It then uses the weight info in conjunction with it to compute the desired ratio
    Finally, it gives suggestions based on this information for each module of interest

    Args:
        model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

    Returns a tuple with two elements:
        String report of whether input weight equalization is recommended for certain modules
        Dictionary mapping modules of interest to:
            whether input weight equalization is recommended
            their s_c metric compared to the threshold
            the threshold used to make the recommendation
            the channel used for recording data
            the input channel range info
            the weight channel range info
    """
    # gather input ranges, weight ranges and the per-channel comparison statistic
    input_values: Dict[str, Dict] = self._extract_input_info(model)
    weight_values: Dict[str, Dict] = self._extract_weight_info(model)
    comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)
    report_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)

    # sentence templates, filled with (recommended channels, total channels)
    module_header = "For Module {} looked at with axis {}: \n"
    use_template = (
        "\tWe suggest to use input weight equalization because "
        "{}/{} channels would benefit and we expect significant reduction in quantization error.\n"
    )
    no_use_template = (
        "\tWe suggest to not use input weight equalization because "
        "we don't expect much improvement from input-weight equalization based on "
        "{}/{} channels benefitting from input-weight equalization being applied.\n"
    )

    # collect the pieces of the report and join them once at the end
    pieces = ["Input-Weight Equalization suggestions: \n"]
    for module_fqn, mod_info in report_info.items():
        pieces.append(module_header.format(module_fqn, self.ch_axis))

        # count how many channels would benefit from input weight equalization
        channel_recs = mod_info[self.RECOMMENDED_KEY]
        num_channels = len(channel_recs)
        num_recs = sum(channel_recs)

        enough_channels = num_recs / num_channels >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO
        template = use_template if enough_channels else no_use_template
        pieces.append(template.format(num_recs, num_channels))

    # only the title was added, so no module was looked at
    if len(pieces) == 1:
        pieces.append("No applicable layers for suggestions. Only linear and conv valid.\n")

    # the string explanation together with the compiled dict info
    return ("".join(pieces), report_info)
class OutlierDetector(DetectorBase):
r"""
Determines whether there are significant outliers in activation data around a certain layer.
This is ideally used in conjunction with information on stationary vs. non-stationary distribution:
If the data is stationary, and there are significant outliers, then we want to flag them
We want to do this on a per channel basis for detecting outliers
Determines whether activation data is flagged as outlier based on if data is stationary and:
p_r = avg(100th percentile / "reference_percentile"th percentile)
where:
p_r is average percentile ratio across all batches in the epoch
reference_percentile is a percentile value expressed as a fraction between 0 and 1 (e.g. 0.975)
if p_r is above some threshold, then we consider the activations to have significant outliers
Args:
ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
Should be >= 1
Default: 3.5
reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
Should be between 0 and 1
Default: 0.975
fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user
regardless of whether we detected outliers or not in channel to take a closer look at channel results
Should be between 0 and 1
Default: 0.95
ch_axis (int, optional): The channel axis being observed to determine outliers
Default: 1
* :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold
If it is significantly greater, then we consider it an outlier
This threshold was calculated based on the ratio of the percentiles in a normal distribution
The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
* :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
Should be between 0 and 1
The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
* :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
Should be between 0 and 1
* :attr:`ch_axis`: The channel axis being observed to determine outliers
* :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
"""
# names for the pre observers that are inserted
DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
# pre activation prefix
INPUT_ACTIVATION_PREFIX = "input_activation_"
# names for dict keys
OUTLIER_KEY = "outliers_detected"
NUM_BATCHES_KEY = "outlier_detection_batches_used"
IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
CONSTANT_COUNTS_KEY = "constant_batch_counts"
def __init__(
self,
ratio_threshold: float = 3.5,
reference_percentile: float = 0.975,
fraction_batches_used_threshold: float = 0.95,
ch_axis: int = 1,
):
# initialize the variables of interest
self.ratio_threshold = ratio_threshold
# make sure passed in percentile is valid
assert reference_percentile >= 0 and reference_percentile <= 1
assert fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1
self.reference_percentile = reference_percentile
self.fraction_batches_used_threshold = fraction_batches_used_threshold
self.ch_axis = ch_axis
def get_detector_name(self) -> str:
r"""Returns the name of this detector"""
return "outlier_detector"
def _supports_insertion(self, module: nn.Module) -> bool:
r"""Returns whether the given module is supported for observers insertion
Any module that doesn't have children and isn't an observer itself is supported
Args
module: The module to check and ensure is supported
Returns True if the module is supported by observer, False otherwise
"""
# case for insertion of module
# check if the module has any children and isn't observer
num_children = len(list(module.children()))
return num_children == 0 and not _is_activation_post_process(module)
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
r""" Returns the DetectorQConfigInfo for each module_fqn relevant
Args
model (nn.Module or subclass): model to find observer insertion points
Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
A DetectorQConfigInfo with the information to generate a QConfig for a specific module
"""
# currently doesn't do anything for outlier detector
return {}
def _supports_report_gen(self, module: nn.Module) -> bool:
r"""Returns whether the given module is supported for report generation
Any module that has a model report pre-observer is supported
Args
module: The module to check and ensure is supported
Returns True if the module is supported by observer, False otherwise
"""
return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
r""" Determines where observers need to be inserted for the Outlier Detector.
For this detector, we want to place observers in front of supported layers.
Currently inserts observers for:
all layers that do not have children (leaf level layers)
Args:
prepared_fx_model (GraphModule): The prepared Fx GraphModule
Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
key "observer_args" -> The arguments that are meant to be passed into the observer
"""
# observer for this detector is ModelReportObserver
obs_ctr = ModelReportObserver
# return dict
obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}
for fqn, module in prepared_fx_model.named_modules():
# check to see if module is of a supported type
if self._supports_insertion(module):
# if it's a supported type, we want to get node and add observer insert locations
targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
# add entry for pre-observer
pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
obs_fqn_to_info[pre_obs_fqn] = {
DETECTOR_TARGET_NODE_KEY: targeted_node,
DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis, comp_percentile=self.reference_percentile),
DETECTOR_IS_POST_OBS_KEY: False,
DETECTOR_OBS_ARGS_KEY: targeted_node.args,
}
return obs_fqn_to_info
def _calculate_outlier_info(
self,
percentile_ratios: torch.Tensor,
counted_batches: torch.Tensor,
total_batches: int,
) -> Dict[str, List[bool]]:
r"""
Gives info on whether the percentile ratios calculated would be considered outliers
Also gives information on whether the collected data is statistically significant to make this claim
Args:
percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
total_batches (int): The total number of batches that passed through observer in this epoch
Returns a dictionary mapping:
"outliers_detected" : list of bools per channel that are true if it is considered an outlier
"is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
where o_r = counted_batches / total_batches
"""
outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.IS_SUFFICIENT_BATCHES_KEY: []}
# get both as flattened lists for easy mapping
ratios_list: List = percentile_ratios.tolist()
num_batches_list: List = counted_batches.tolist()
# calculate whether channels were statistically significant
significant_size = [
batch_size / total_batches >= self.fraction_batches_used_threshold for batch_size in num_batches_list
]
outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size
# calculate for each channel whether it's an outlier or not based on ratio
outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
outlier_dict[self.OUTLIER_KEY] = outlier_detected
# return the dictionary with the two lists
return outlier_dict
def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]:
r"""
Helper function for generate_detector_report that does the generation of the dictionary.
This process is done as specified in generate_detector_report documentation
Args:
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
Returns a dict mapping relevant module fqns to:
whether there were outliers found in activation before
the number of batches used for each channel
whether fraction of applicable batches used is above fraction_batches_used_threshold
their p_r metric compared to the threshold
the threshold used to make the recommendation
the reference_percentile used to make the recommendation
the channel axis used to determine individual channels
the constant batch counts per channel
the per channel max values
"""
# return dictionary mapping observer fqns to desired info
info_dict: Dict[str, Dict] = {}
for fqn, module in model.named_modules():
# if module is supported and it has a pre-observer
if self._supports_report_gen(module):
# get pre observer for the module
pre_obs: ModelReportObserver = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
# get the number of batches and calculated ratio thresholds
num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
total_batches: int = pre_obs.num_batches_tracked
# also get the max values
max_vals: torch.Tensor = pre_obs.max_val
# we have to specifically modify how we are recording negative ratio for pre-relu layers
for index, ratio_val in enumerate(average_ratios):
# check if we have a negative ratio
# a ratio might be negative if we have a situation where the 100th percentile is
# > 0 while the nth percentile is < 0, in which case this would not be detected
# as an outlier. Since we care more about magnitude, we make it positive.
if ratio_val.item() < 0:
# first make it positive
average_ratios[index] = -ratio_val
if ratio_val.item() < 1:
# if it's less than 1 we have the flip it as well
average_ratios[index] = 1 / ratio_val
outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches)
# calculate whether ratios were outliers
info_dict[fqn] = {
self.CHANNEL_AXIS_KEY: self.ch_axis,
self.REF_PERCENTILE_KEY: self.reference_percentile,
self.RATIO_THRES_KEY: self.ratio_threshold,
self.COMP_METRIC_KEY: average_ratios,
self.NUM_BATCHES_KEY: num_batches,
self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY],
self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
self.MAX_VALS_KEY: max_vals
}
return info_dict
def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
r"""
Determines whether input weight equalization is appropriate for a given module.
Takes advantage of the ModelReport Observer which records the relevant percentile information
Args:
model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
Returns a tuple with two elements:
String report of of whether there are outliers in the activations around certain modules
Dictionary mapping modules of interest to:
whether there were outliers found in activation before
the number of batches used for each channel
whether fraction of applicable batches used is above fraction_batches_used_threshold
their p_r metric compared to the threshold
the threshold used to make the recommendation
the reference_percentile used to make the recommendation
the channel axis used to determine individual channels
the constant batch counts per channel
the per channel max values
"""
# generate the information dictionary of outlier information
info_dict = self._generate_info_dict(model)
# now we can generate report based on this information
outlier_string = "Outlier detection report: \n"
# added module check
added_module: bool = False
# some strings to be formatted depending on module we are adding
module_suggestion_str = "For Module {} looked at with axis {}: \n"
channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
channel_max_value_str = "a max value across all batches of {}"
note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
note_distribution = "stationary distributions"
note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"
# suggestion for constant batch check since that can make it no outliers
constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why."
# compile the suggestion string
for module_fqn in info_dict:
# get module specific info
mod_info: Dict[str, Any] = info_dict[module_fqn]
# check to see if we already added high level model desc
added_model_desc = False
# look at each individual channel and add a suggestion
for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
if outlier_detected:
# we found at least 1 outlier
if not added_model_desc:
# add the module level description
outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
added_model_desc = True
# we mark that we found at least one outlier
added_module = True
max_value_found_str = channel_max_value_str.format(mod_info[self.MAX_VALS_KEY][index])
channel_str = channel_suggestion_str.format(index, max_value_found_str)
outlier_string += channel_str
# also check if we found constant batch
if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0:
# make sure we add a module level highlight.
if not added_model_desc:
# add the module level description
outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
added_model_desc = True
constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][index]
formatted_str = constant_str.format(index, constant_values_for_channel, constant_suggestion)
outlier_string += formatted_str
# we also added at least one thing to description
added_module = True
# if found outlier, give suggestion, else give default response
if added_module:
# compose the note string
note_composed = note_string.format(note_distribution, note_rec)
outlier_string += note_composed
else:
outlier_string += "There were no outliers found in the activations.\n"
return (outlier_string, info_dict)