from torch.fx import GraphModule
from .qconfig_mapping import QConfigMapping
from .backend_config import BackendConfig
from .fx import prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from ._pt2e.utils import (
_get_renamed_nn_module_stack,
_fuse_conv_bn_,
_rearrange_weight_observer_for_addmm,
)
from typing import Tuple, Any, Dict
def prepare_pt2e(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
):
    """Prepare an FX ``GraphModule`` for PT2E post-training quantization.

    Fuses conv+bn in place, inserts observers according to
    ``qconfig_mapping``/``backend_config`` via ``prepare``, and applies the
    addmm weight-observer workaround. Returns the prepared model.
    """
    # Map each graph node to its innermost module scope (name, module type).
    # TODO: move this information to fx node itself
    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
    for node in model.graph.nodes:
        module_stack = _get_renamed_nn_module_stack(node.meta.get("nn_module_stack", None))
        # The last entry of the stack is the node's current (innermost) scope.
        *_, innermost_scope = module_stack.items()
        node_name_to_scope[node.name] = innermost_scope
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config,
    )
    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_addmm(model)
    return model
def convert_pt2e(
    model: GraphModule
):
    """Convert a prepared model by delegating to the reference decomposed
    FX conversion (``_convert_to_reference_decomposed_fx``)."""
    converted = _convert_to_reference_decomposed_fx(model)
    return converted