import copy
import itertools
import warnings
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantizable as nnqa
from torch.nn.intrinsic import _FusedModule
from .quantization_mappings import (
get_default_dynamic_quant_module_mappings,
get_default_static_quant_module_mappings,
get_default_qat_module_mappings,
get_default_qconfig_propagation_list,
_has_special_act_post_process,
_get_special_act_post_process,
)
from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig
def is_activation_post_process(module):
    r"""Return True when ``module`` is an observer or a fake-quantize module.

    Both ``torch.quantization.ObserverBase`` and
    ``torch.quantization.FakeQuantizeBase`` subclasses count as an
    "activation post process" for the purposes of this file.
    """
    post_process_types = (torch.quantization.ObserverBase,
                          torch.quantization.FakeQuantizeBase)
    return isinstance(module, post_process_types)
def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
                              qconfig_parent=None, prefix=''):
    r"""Recursively resolve a ``qconfig`` for ``module`` and every descendant.

    This is the worker behind `propagate_qconfig_`. For each module the
    qconfig is picked from the following sources, where a later source
    overrides an earlier one:

    1. ``qconfig_parent`` (the qconfig resolved for the parent module),
    2. ``qconfig_dict[type(module)]``,
    3. ``qconfig_dict[prefix]`` (the root module's prefix is ``''``),
    4. a ``qconfig`` attribute already present on the module.

    Args:
        module: input module
        qconfig_dict: dictionary that maps from type or dotted name of a
                      submodule to a quantization configuration
        allow_list: list of quantizable modules; here it is only forwarded to
                    the recursive calls (defaults to
                    ``get_default_qconfig_propagation_list()``)
        qconfig_parent: quantization config of the parent module, used as the
                        fallback when nothing more specific applies
        prefix: dotted name of the current module relative to the root,
                used as a key into ``qconfig_dict``

    Return:
        None, module is modified inplace with qconfig attached
    """
    # TODO: Add test
    if allow_list is None:
        allow_list = get_default_qconfig_propagation_list()

    # A name entry beats a type entry, and both beat the parent's qconfig.
    inherited = qconfig_dict.get(prefix,
                                 qconfig_dict.get(type(module), qconfig_parent))
    # A qconfig the user already set on the module wins over everything.
    chosen = getattr(module, 'qconfig', inherited)

    torch.quantization.qconfig.assert_valid_qconfig(chosen, module)

    module.qconfig = chosen
    for child_name, child in module.named_children():
        child_prefix = '.'.join((prefix, child_name)) if prefix else child_name
        _propagate_qconfig_helper(child, qconfig_dict, allow_list,
                                  qconfig_parent=chosen, prefix=child_prefix)
# TODO(jerryzh): expose allow_list
def propagate_qconfig_(module, qconfig_dict=None, allow_list=None):
    r"""Propagate qconfig through the module hierarchy, setting a `qconfig`
    attribute on every module in the tree (not only on leaves).

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration. A qconfig applies to all submodules of
            a given module unless a qconfig is specified for those submodules,
            either through this dictionary or because the submodule already
            has a qconfig attribute. ``None`` is treated as an empty dict.
        allow_list: list of quantizable modules, passed through unchanged to
            the recursive helper

    Return:
        None, module is modified inplace with qconfig attached
    """
    effective_dict = {} if qconfig_dict is None else qconfig_dict
    _propagate_qconfig_helper(module, effective_dict, allow_list)
def _observer_forward_hook(self, input, output):
    r"""Forward hook that feeds the module output to ``self.activation_post_process``.

    ``input`` is ignored. The hook returns whatever the post-process module
    returns, and PyTorch uses that value as the module's output.
    """
    post_process = self.activation_post_process
    return post_process(output)
def register_activation_post_process_hook(module):
    r"""Register the observer forward hook on ``module``.

    After each forward call, the registered hook passes the module's output
    through ``module.activation_post_process`` and returns that result in
    place of the original output.

    Args:
        module: module that already carries an ``activation_post_process``
                attribute

    Return:
        The handle returned by ``module.register_forward_hook``. Its ``id``
        can be used to locate the hook in ``module._forward_hooks``.

    Raises:
        AssertionError: if ``module`` has no ``activation_post_process``.
    """
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    return module.register_forward_hook(_observer_forward_hook)
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
    r"""Add observers (activation post process modules) to the children of ``module``.

    Walks ``module`` recursively. Any leaf module whose type is in
    ``qconfig_propagation_list`` and that has a non-None ``qconfig`` gets an
    ``activation_post_process`` submodule plus a forward hook that calls it.
    The following children are instrumented as a unit instead of being
    recursed into:
      - FloatFunctional / QFunctional (attribute only, no hook),
      - fused modules (``_FusedModule``),
      - modules with a special activation post process,
      - types listed in ``non_leaf_module_list``,
      - types listed in ``custom_module_class_mapping`` (swapped out first).

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: module types that receive an observer when
            they are leaf modules; defaults to
            ``get_default_qconfig_propagation_list()``
        non_leaf_module_list: list of non-leaf module types we want to add an
            observer to as a whole
        device: parent device, if any; when None it is inferred from the
            parameters/buffers of ``module``
        custom_module_class_mapping: dict mapping a float module type to an
            observed module class whose ``from_float`` builds the replacement
            child; defaults to ``{}``

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = get_unique_devices_(module)
        assert len(devices) <= 1, (
            "add_observer_ only works with cpu or single-device CUDA modules, "
            "but got devices {}".format(devices)
        )
        # A module with no parameters or buffers leaves device as None.
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        # Build the observer from the qconfig's activation factory, unless a
        # special factory was supplied for this module type.
        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
        if device is not None:
            # Return value ignored: assumes `activation` is an nn.Module,
            # whose .to() moves it in place.
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Attach an activation post process module to ``m`` and register
        a forward hook that calls it. No-op when ``m`` has no qconfig or is
        a DeQuantStub.
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module('activation_post_process', get_activation_post_process(m.qconfig, device, special_act_post_process))
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            handle = register_activation_post_process_hook(m)
            m._forward_hooks.move_to_end(handle.id, last=False)

    for name, child in module.named_children():
        if type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
            # Only the attribute is set here; no forward hook is registered
            # (presumably these modules call activation_post_process in their
            # own ops -- confirm in nnq.FloatFunctional).
            if needs_observation(child):
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif needs_observation(child) and type(child) in custom_module_class_mapping:
            # Swap the float child for its observed counterpart in the parent.
            observed_child = custom_module_class_mapping[type(child)].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            # Once there are more, we should move them to a separate list
            if custom_module_class_mapping[type(child)] != nnqa.LSTM:
                insert_activation_post_process(observed_child)
        else:
            # Ordinary container or leaf: recurse, keeping the resolved device.
            add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
       and type(module) in qconfig_propagation_list:
        insert_activation_post_process(module)
def get_unique_devices_(module):
    r"""Return the set of devices holding any parameter or buffer of ``module``.

    An empty set means the module (including its submodules) has neither
    parameters nor buffers.
    """
    tensors = itertools.chain(module.parameters(), module.buffers())
    return {tensor.device for tensor in tensors}
def add_quant_dequant(module):
    r"""Wrap every leaf submodule that has a truthy ``qconfig`` in `QuantWrapper`.

    The children of ``module`` are replaced in place. If ``module`` is itself
    a leaf with a truthy ``qconfig``, a new `QuantWrapper` around it is
    returned. Otherwise ``module`` itself is returned.

    Args:
        module: input module with qconfig attributes for all the leaf modules
                that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if not module._modules:
        # Leaf module: wrap it only if it carries a (truthy) qconfig.
        if getattr(module, 'qconfig', None):
            return QuantWrapper(module)
        return module

    for child_name, child_module in module.named_children():
        module._modules[child_name] = add_quant_dequant(child_module)
    return module
def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model; it is deep-copied first unless `inplace` is True
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: module types that receive an observer when they are leaf
                      modules (defaults to ``get_default_qconfig_propagation_list()``)
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    Return:
        The prepared model: ``model`` itself when `inplace` is True, otherwise
        a deep copy of it.

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    # NOTE: allow_list is not forwarded to propagate_qconfig_; it only decides
    # which leaf module types add_observer_ instruments below.
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    add_observer_(
        model, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return model
def _remove_activation_post_process(module):
    r"""Detach the observer/fake-quant module and its hook from ``module``.

    This is not recursive. ``module.activation_post_process`` is deleted only
    when it is an observer or fake-quantize instance. Every forward hook that
    is exactly `_observer_forward_hook` is unregistered, and all other hooks
    are left in place.
    """
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    post_process = getattr(module, 'activation_post_process', None)
    if post_process is not None and is_activation_post_process(post_process):
        delattr(module, 'activation_post_process')

    # remove activation_post_process hook
    hooks = module._forward_hooks
    stale_hook_ids = [hook_id for hook_id, hook_fn in hooks.items()
                      if hook_fn is _observer_forward_hook]
    for hook_id in stale_hook_ids:
        del hooks[hook_id]
# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Walks the tree bottom-up: every descendant is cleaned before ``module``
    itself. Each module loses its ``qconfig`` attribute (if any) and its
    observer/fake-quant post-process module and hook.

    Args:
        module: module to be cleaned up
    """
    # Post-order traversal: handle the subtree before the current node.
    for submodule in module.children():
        _remove_qconfig(submodule)

    has_qconfig = hasattr(module, "qconfig")
    if has_qconfig:
        del module.qconfig

    _remove_activation_post_process(module)
def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    The model is put in eval mode and prepared with observers. ``run_fn`` is
    then called to calibrate it, and finally the model is converted to its
    quantized counterpart.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model;
                called as ``run_fn(prepared_model, *run_args)``
        run_args: positional arguments for `run_fn`
        mapping: correspondence between original module types and quantized
                 counterparts (defaults to
                 ``get_default_static_quant_module_mappings()``)
        inplace: carry out model transformations in-place, the original module is mutated

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()

    target = model if inplace else copy.deepcopy(model)
    target.eval()
    # Attach observers, calibrate on user data, then swap in quantized modules.
    prepare(target, inplace=True)
    run_fn(target, *run_args)
    convert(target, mapping, inplace=True)
    return target
def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
mapping=None, inplace=False):
r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
by default is performed for layers with large weights size - i.e. Linear and RNN variants.
Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
If `qconfig` is provided, the `dtype` argument is ignored.
Args:
model: input model
qconfig_spec: Either:
- A dictionary that maps from name or type of submodule to quantization
configuration, qconfig applies to all submodules of a given
module unless qconfig for the submodules are specified (when the
submodule already has qconfig attribute). Entries in the dictionary
need to be QConfigDynamic instances.
- A set of types and/or submodule names to apply dynamic quantization to,
in which case the `dtype` argument is used to specify the bit-width
inplace: carry out model transformations in-place, the original module is mutated
mapping: maps type of a submodule to a type of corresponding dynamically quantized version
with which the submodule needs to be replaced
"""
torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
if qconfig_spec is None:
if dtype == torch.qint8:
qconfig_spec = {
Loading ...