# neilisaac/torch: quantization/fx/quantize.py

import torch
from torch.fx import (  # type: ignore
    GraphModule,
    Proxy,
    map_arg
)

from torch.fx.graph import (
    Graph,
    Node,
)

from torch.fx.node import Argument

from torch.quantization import (
    propagate_qconfig_,
    convert,
)

from ..quantization_mappings import (
    get_default_qat_module_mappings,
)

from ..quantize import (
    _remove_qconfig,
    is_activation_post_process
)

from ..utils import (
    get_combined_dict,
    get_swapped_custom_module_class,
    activation_is_statically_quantized,
)

from .pattern_utils import (
    is_match,
    get_default_quant_patterns,
    get_default_output_activation_post_process_map,
    input_output_observed,
    Pattern,
)

from .observed_module import (
    mark_observed_module,
    is_observed_module,
    mark_observed_standalone_module,
    is_observed_standalone_module,
)

from .quantization_patterns import *

from .utils import (
    _parent_name,
    quantize_node,
    get_custom_module_class_keys,
    get_new_attr_name_with_prefix,
    collect_producer_nodes,
    graph_module_from_producer_nodes,
    assert_and_get_unique_device,
)

from .qconfig_utils import *

from typing import Optional, Dict, Any, List, Tuple, Set, Callable

# Define helper types
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]
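# A MatchResult entry records, in order: the last node of the matched
# pattern, the list of nodes the pattern consumed, the matched pattern
# itself (when one applies), the QuantizeHandler instance constructed
# for the match, and the qconfig assigned to the node.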

# ------------------------
# Helper Functions
# ------------------------

def insert_observer(
        node: Node, observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Insert observer for node by modifying the observed_graph and
       attach observer module to the model
       Args:
         node: Node
         observer: observer/fake_quantize module instance
    """
    # respect device affinity when adding observers
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    prefix = node.name + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    # put the observer instance in the activation_post_process map
    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer
    # insert observer call
    env[node.name] = observed_graph.create_node(
        'call_module', observer_name, (load_arg(node),), {})
    observed_node_names_set.add(node.name)
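
def _example_insert_observer():
    """Illustrative sketch (hypothetical helper, not part of the upstream
    file): rebuild a one-linear graph while observing the linear's output,
    mirroring how the prepare pass drives `insert_observer`.
    """
    class _M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    model = torch.fx.symbolic_trace(_M())
    observed_graph = Graph()
    env: Dict[Any, Any] = {}
    activation_post_process_map: Dict[str, torch.quantization.ObserverBase] = {}
    observed_node_names_set: Set[str] = set()

    def load_arg(n):
        return env[n.name]

    for node in model.graph.nodes:
        # copy the node into the new graph; insert_observer then rebinds
        # env[node.name] to the observer call, so later nodes consume the
        # observed value
        env[node.name] = observed_graph.node_copy(node, load_arg)
        if node.op == 'call_module':
            insert_observer(
                node, torch.quantization.MinMaxObserver(), model,
                activation_post_process_map, env, observed_graph,
                load_arg, observed_node_names_set)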

def maybe_insert_observer_for_special_module(
        quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module],
        prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]:
    """ Insert observer for custom module and standalone module
      Returns: standalone_module_input_idxs: the indexs for inputs that
      needs to be observed by parent module
    """
    assert modules is not None
    standalone_module_input_idxs = None
    if isinstance(quantize_handler, CustomModuleQuantizeHandler):
        custom_module = modules[node.target]  # type: ignore
        custom_module_class_mapping = prepare_custom_config_dict.get(
            "float_to_observed_custom_module_class", {})
        observed_custom_module_class = \
            get_swapped_custom_module_class(
                custom_module, custom_module_class_mapping, qconfig)
        observed_custom_module = \
            observed_custom_module_class.from_float(custom_module)
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, observed_custom_module)
    elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
        # observe standalone module
        standalone_module = modules[node.target]  # type: ignore
        standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", [])
        standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", [])
        class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs}
        name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs}
        config = class_config_map.get(type(standalone_module), (None, None))
        config = name_config_map.get(node.target, config)
        sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
        sm_prepare_config_dict = {} if config[1] is None else config[1]
        prepare = \
            torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
        observed_standalone_module = \
            prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict)
        standalone_module_input_idxs = observed_standalone_module.\
            _standalone_module_input_quantized_idxs.int().tolist()
        observed_standalone_module = mark_observed_standalone_module(
            observed_standalone_module)
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name,
                observed_standalone_module)
        modules[node.target] = observed_standalone_module  # type: ignore
    return standalone_module_input_idxs
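
# Illustrative sketch (module names hypothetical): the shape of the
# prepare_custom_config_dict entries consumed above. Standalone-module
# entries are (name-or-class, qconfig_dict, prepare_custom_config_dict)
# triples; a None qconfig_dict falls back to the parent's qconfig, and a
# None config dict falls back to {}.
#
#   prepare_custom_config_dict = {
#       "float_to_observed_custom_module_class": {
#           FloatCustomModule: ObservedCustomModule,
#       },
#       "standalone_module_name": [
#           ("submodule", {"": qconfig}, None),
#       ],
#       "standalone_module_class": [
#           (StandaloneModule, None, {}),
#       ],
#   }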

def insert_observer_for_output_of_the_node(
        node: Node,
        quantize_handler: QuantizeHandler,
        qconfig: Any,
        modules: Dict[str, torch.nn.Module],
        model: torch.nn.Module,
        pattern: Any,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any],
        observed_graph: Graph,
        load_arg: Callable,
        observed_node_names_set: Set[str],
        matched_nodes: Optional[List[Node]],
        standalone_module_input_idxs: Optional[List[int]]):
    """ Insert observer/fake_quantize module for output of the observed
    module if needed
    """
    # don't need to insert observer for output if activation does not
    # need to be statically quantized
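    # For example (illustrative): a static qconfig such as
    # torch.quantization.get_default_qconfig('fbgemm') makes
    # activation_is_statically_quantized(qconfig) return True, so output
    # observation proceeds below; a dynamic qconfig such as
    # torch.quantization.default_dynamic_qconfig returns False and the
    # output is left unobserved.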
    assert modules is not None
    if activation_is_statically_quantized(qconfig):
        if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
                and model.training:
            # we only insert fake quantize module in qat
            assert pattern is not None
            activation_post_process_ctr = \
                get_default_output_activation_post_process_map().get(
                    pattern, None)
            assert activation_post_process_ctr is not None, \
                "activation_post_process constructor not provided " + \
                "for pattern:" + str(pattern)
            insert_observer(
                node, activation_post_process_ctr(),
                model, activation_post_process_map, env, observed_graph,
                load_arg, observed_node_names_set)
        elif (isinstance(quantize_handler,
                         FixedQParamsOpQuantizeHandler) and
              not model.training) or \
                isinstance(quantize_handler, CopyNode):
            # inserting observers for output of observed module, or
            # mark the output as observed
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'CopyNode of type ' + node.op + ' is not handled'

            def is_observed(input_arg):
                if isinstance(input_arg, Node):
                    return input_arg.name in observed_node_names_set
                elif isinstance(input_arg, list):
                    return all(map(is_observed, input_arg))
                # non-node, non-list arguments (e.g. constants) are not observed
                return False
            # propagate observed property from input
            if is_observed(node.args[0]):
                observed_node_names_set.add(node.name)
        elif ((isinstance(quantize_handler, Add) or
                isinstance(quantize_handler, Mul)) and
              quantize_handler.num_node_args == 1):
            assert matched_nodes is not None
            # matched_nodes is in reverse order, so the last entry is the
            # first node in the matched sequence
            input_node = matched_nodes[-1]

            def input_is_observed(arg):
                return (isinstance(arg, Node) and
                        arg.name in observed_node_names_set)
            # This checks whether one of the arguments of add/mul
            # is an observed node.
            # If both inputs are numbers, we do not consider the
            # output to be observed.
            if (input_is_observed(input_node.args[0]) or
                    input_is_observed(input_node.args[1])):
                observed_node_names_set.add(node.name)
        elif isinstance(quantize_handler,
                        StandaloneModuleQuantizeHandler):
            assert node.op == "call_module"
            assert isinstance(node.target, str)
            sm_out_qidxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # type: ignore
            output_is_quantized = 0 in sm_out_qidxs

            if output_is_quantized:
                observed_node_names_set.add(node.name)
        elif (quantize_handler.all_node_args and
              input_output_observed(quantize_handler)):
            # observer for outputs
            new_observer = qconfig.activation()
            insert_observer(
                node, new_observer, model,
                activation_post_process_map, env, observed_graph,
                load_arg, observed_node_names_set)

        # insert observers for the inputs of the standalone module
        if standalone_module_input_idxs is not None:
            for idx in standalone_module_input_idxs:
                if node.args[idx].name not in observed_node_names_set:  # type: ignore
                    new_observer = qconfig.activation()
                    insert_observer(
                        node, new_observer, model,
                        activation_post_process_map, env, observed_graph,
                        load_arg, observed_node_names_set)

def insert_observer_for_input_arg_of_observed_node(
        node: Node, observed_node_names_set: Set[str],
        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]],
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[str, str], observed_graph: Graph,
        load_arg: Callable):
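    """Insert an observer for `node` if it has not been observed yet and
    the quants map provides an observer constructor for it."""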
    if node.name not in observed_node_names_set and node.name in quants:
        _, activation_post_process_ctr = quants[node.name]
        if activation_post_process_ctr is not None:
            insert_observer(
                node, activation_post_process_ctr(),
                model, activation_post_process_map,
                env, observed_graph, load_arg, observed_node_names_set)

# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv1d : [1],
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.conv3d : [1],
    torch.nn.functional.linear : [1],
}

def node_arg_is_weight(node: Node, arg: Any) -> bool:
    if isinstance(node, Node) and node.op == 'call_function' and \
            node.target in WEIGHT_INDEX_DICT:
        for i, node_arg in enumerate(node.args):
            if arg is node_arg and i in \
                    WEIGHT_INDEX_DICT[node.target]:  # type: ignore
                return True
    return False

CONV_OPS_WITH_BIAS = {
    torch.nn.functional.conv1d,
    torch.nn.functional.conv2d,
    torch.nn.functional.conv3d,
}
CONV_BIAS_ARG_INDEX = 2

def node_arg_is_bias(node: Node, arg: Any) -> bool:
    if isinstance(node, Node) and node.op == 'call_function':
        if node.target in CONV_OPS_WITH_BIAS:
            for i, node_arg in enumerate(node.args):
                if arg is node_arg and i == CONV_BIAS_ARG_INDEX:
                    return True
        elif node.target is torch.nn.functional.linear:
            for kwarg_name, kwarg_value in node.kwargs.items():
                if kwarg_name == 'bias' and arg is kwarg_value:
                    return True
    return False
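
def _example_weight_and_bias_classification():
    """Illustrative sketch (hypothetical helper, not part of the upstream
    file): classify the arguments of a traced functional linear call.
    """
    def f(x, w, b):
        return torch.nn.functional.linear(x, w, bias=b)

    traced = torch.fx.symbolic_trace(f)
    linear_node = next(n for n in traced.graph.nodes
                       if n.target is torch.nn.functional.linear)
    x_arg, w_arg = linear_node.args
    assert node_arg_is_weight(linear_node, w_arg)      # arg index 1 is the weight
    assert not node_arg_is_weight(linear_node, x_arg)  # arg index 0 is the input
    assert node_arg_is_bias(linear_node, linear_node.kwargs['bias'])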

# weight prepacking ops
WEIGHT_PREPACK_OPS = {
    torch._ops.ops.quantized.linear_prepack,
    torch._ops.ops.quantized.linear_prepack_fp16,
    torch._ops.ops.quantized.conv2d_prepack,
}

class Quantizer:
    def __init__(self):
        # mapping from matched node to activation_post_process
        # must be filled before convert
        self.activation_post_process_map: Optional[
            Dict[str, torch.quantization.observer.ObserverBase]] = None
        # mapping from node name to qconfig that should be used for that node
        # filled out for a model during _generate_qconfig_map
        self.qconfig_map: Optional[Dict[str, QConfigAny]] = None
        # mapping from fully qualified module name to module instance
        # for example,
        # {
        #   '': Model(...),
        #   'linear': Linear(...),
        #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
        # }
        self.modules: Optional[Dict[str, torch.nn.Module]] = None
        # mapping from a tuple of nodes in reverse order to uninitialized
        #   QuantizeHandler subclass. For example,
        # {
        #   # match a single node
        #   (<class 'torch.nn.modules.conv.Conv3d'>:
        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
        #   # match multiple nodes in reverse order
        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
        #     <class 'torch.quantization.fx.quantize.Add'>),
        # }
        self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None
        self.prepare_custom_config_dict: Dict[str, Any] = {}

        # mapping from node name to the scope of the module which contains the node.
        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}


    def _qat_swap_modules(
            self, root: torch.nn.Module,
            additional_qat_module_mapping: Dict[Callable, Callable]) -> None:
        all_mappings = get_combined_dict(
            get_default_qat_module_mappings(), additional_qat_module_mapping)
        convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)
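    # For example (illustrative, hypothetical module names): passing
    #   additional_qat_module_mapping = {MyFloatModule: MyQATModule}
    # swaps every qconfig-carrying MyFloatModule in `root` for
    # MyQATModule.from_float(...), in addition to the default float-to-QAT
    # swaps such as nn.Linear -> nn.qat.Linear.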