edgify/torch (Python): ao/quantization/_quantize_pt2e.py

from typing import Any, Dict, Tuple

from torch.fx import GraphModule

from .qconfig_mapping import QConfigMapping
from .backend_config import BackendConfig
from .fx import prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from ._pt2e.utils import (
    _get_renamed_nn_module_stack,
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_addmm,
)

def prepare_pt2e(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
):
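    """Prepare an ATen-level ``GraphModule`` for post-training quantization.

    Records the originating module scope for each node, fuses conv + bn
    pairs in place, then inserts observers according to ``qconfig_mapping``
    and ``backend_config``. Run representative inputs through the returned
    model to calibrate the observers before calling ``convert_pt2e``.
    """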
    # TODO: move this information to the fx node itself
    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
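    # Record, for every node, the innermost (module name, module type) scope
    # it was traced from, based on the nn_module_stack metadata on the node.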
    for n in model.graph.nodes:
        renamed_stack = _get_renamed_nn_module_stack(n.meta.get("nn_module_stack", None))
        current_scope = list(renamed_stack.items())[-1]
        node_name_to_scope[n.name] = current_scope

    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config
    )

    # TODO: remove this hack once pattern matching is better supported
    # move the weight observer for addmm into its expected position
    _rearrange_weight_observer_for_addmm(model)
    return model

def convert_pt2e(
    model: GraphModule
):
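    """Convert a model calibrated via ``prepare_pt2e`` into a reference
    quantized model that uses decomposed (ATen-level) quantize/dequantize
    ops.
    """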
    return _convert_to_reference_decomposed_fx(model)
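
# Example end-to-end flow (a sketch kept in comments so importing this module
# stays side-effect free; the model, the capture call, and `backend_config`
# below are illustrative assumptions, not APIs defined in this file):
#
#     import torch
#     from torch.ao.quantization import get_default_qconfig_mapping
#
#     example_inputs = (torch.randn(1, 3, 224, 224),)
#     # Capture the model into an ATen-level GraphModule, e.g. with
#     # torch._dynamo.export(..., aten_graph=True).
#     captured, _ = torch._dynamo.export(model, *example_inputs, aten_graph=True)
#
#     prepared = prepare_pt2e(
#         captured,
#         get_default_qconfig_mapping(),
#         example_inputs,
#         backend_config,  # a BackendConfig describing the target backend
#     )
#     prepared(*example_inputs)            # calibration pass
#     converted = convert_pt2e(prepared)   # reference quantized model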