Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ ao / quantization / fx / convert.py

from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable
from torch.ao.quantization.quant_type import QuantType
import torch
import copy
import warnings
from torch.fx import (
    GraphModule,
)
from torch.fx.graph import (
    Graph,
    Node,
    Argument,
)
from ..utils import (
    activation_is_statically_quantized,
    weight_is_quantized,
    get_qparam_dict,
    _parent_name,
    get_swapped_custom_module_class,
)
from ..qconfig import (
    QConfigAny,
    qconfig_equals
)
from ..qconfig_mapping import QConfigMapping
from .qconfig_mapping_utils import (
    _generate_node_name_to_qconfig,
    _compare_prepare_convert_qconfig_mappings,
    _update_qconfig_for_fusion,
    _is_qconfig_supported_by_dtype_configs,
    _update_qconfig_for_qat,
)
from torch.ao.quantization.backend_config.utils import (
    get_root_module_to_quantized_reference_module,
    get_pattern_to_dtype_configs,
    get_fused_module_classes,
    get_qat_module_classes,
)
from torch.ao.quantization.backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from torch.ao.quantization.observer import _is_activation_post_process
from .graph_module import (
    _is_observed_module,
    _is_observed_standalone_module,
)
from ._equalize import update_obs_for_equalization, convert_eq_obs
from torch.nn.utils.parametrize import type_before_parametrizations
from .utils import (
    _get_module,
    _is_custom_module_lstm,
    get_custom_module_class_keys,
    create_getattr_from_value,
    collect_producer_nodes,
    graph_module_from_producer_nodes,
    node_arg_is_weight,
)
from torch.ao.quantization.utils import (
    is_per_channel,
    to_underlying_dtype,
)
from torch.ao.quantization.quantize import (
    _remove_qconfig,
)
from torch.ao.quantization.stubs import DeQuantStub
from .custom_config import (
    ConvertCustomConfig,
    PrepareCustomConfig,
)
from .lower_to_fbgemm import lower_to_fbgemm
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib  # noqa: F401
import operator

# Public API of this module: the names exported by a star-import
# (`from <this module> import *`).
__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
]

def _replace_observer_with_quantize_dequantize_node_decomposed(
        model: torch.nn.Module,
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node working with decomposed Tensor

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
    torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...

    or quantize_per_channel and dequantize_per_channel

    The observer node is simply removed (its users are rewired to its input)
    when every producer/consumer of the observer has a ``None`` qconfig, or
    when ``_is_conversion_supported`` rejects the observer module.

    Args:
        model: module passed to ``create_getattr_from_value`` so that static
            scale / zero_point values can be attached to it
        graph: graph that owns ``node``; it is modified in place
        node: observer call node, ``node.target`` must be a ``str`` key of
            ``modules``
        modules: mapping from node target to the module it refers to
        node_name_to_scope: forwarded to ``_get_module_path_and_prefix``
        node_name_to_qconfig: qconfig lookup used for the skip check and
            forwarded to ``_get_module_path_and_prefix``

    Raises:
        AssertionError: ``modules`` is None, ``node.target`` is not a str, or
            the observer is dynamic and its underlying dtype is neither
            uint8 nor int8
        NotImplementedError: the observer is a non-dynamic float16 observer
    """
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
    activation_post_process = modules[node.target]
    # skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        _has_none_qconfig(n, node_name_to_qconfig) for n in
        list(node.args) + list(node.users.keys()))
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node

    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
    # observers without an `is_dynamic` attribute are treated as static
    is_dynamic = getattr(activation_post_process, "is_dynamic", False)

    # insertion order of this dict is the positional argument order of the
    # quantize / dequantize ops, so it must not be reordered
    qparams: Dict[str, Any]

    if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
            (not is_dynamic):
        # TODO: probably should cleanup this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract the information from activation_post_process module for
        # generating the quantize and dequantize operator
        quantize_op : Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
            }
        else:
            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor
            # the per tensor ops take scale / zero_point as python scalars
            qparams = {
                "_scale_": float(scale),
                "_zero_point_": int(zero_point),
            }
        # the trailing arguments are the same for both variants
        qparams["_quant_min_"] = activation_post_process.quant_min  # type: ignore[attr-defined]
        qparams["_quant_max_"] = activation_post_process.quant_max  # type: ignore[attr-defined]
        qparams["_dtype_"] = to_underlying_dtype(dtype)

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            quantize_op_inputs = [node.args[0]]
            for key, value in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ('_scale_', '_zero_point_'):
                    # For scale and zero_point values we register them as buffers in the root module.
                    # TODO: maybe need more complex attr name here
                    quantize_op_inputs.append(create_getattr_from_value(
                        model, graph, module_path + prefix + key, value))
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value)

            quantized_node = graph.create_node(
                "call_function", quantize_op, tuple(quantize_op_inputs), {})
            # use the same qparams from quantize op
            dequantized_node = graph.call_function(
                dequantize_op,
                (quantized_node, *quantize_op_inputs[1:]),
                {}
            )
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif is_dynamic:

        # dynamic quantization branch, only uint8/int8 pass the assertion below

        # 1. extract information for inserting q/dq node from activation_post_process
        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
        # we only use choose_qparams for is_decomposed now,
        # but we should probably align the non-decomposed path with this as well,
        # and that can be done after we remove reduce_range flag
        dtype_ = to_underlying_dtype(dtype)
        assert dtype_ in [torch.uint8, torch.int8], \
            "only uint8 and int8 are supported in reference flow for " \
            "dynamic quantization right now"
        quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
        quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]

        with graph.inserting_before(node):
            input_node = node.args[0]

            # 2. insert choose_qparams op, quant_min, quant_max and dtype are
            # all stored as literals
            choose_qparams_node = graph.create_node(
                "call_function",
                torch.ops.quantized_decomposed.choose_qparams.tensor,
                (input_node, quant_min, quant_max, dtype_),
                {}
            )
            # choose_qparams returns (scale, zero_point)
            scale_node = graph.create_node(
                "call_function",
                operator.getitem,
                (choose_qparams_node, 0),
                {}
            )
            zero_point_node = graph.create_node(
                "call_function",
                operator.getitem,
                (choose_qparams_node, 1),
                {}
            )

            # 3. replace activation_post_process node to quantize and dequantize node
            # scale and zero_point are nodes here since they are dynamically
            # computed from the input with choose_qparams op, the remaining
            # qparams are stored as literals in the graph
            quantize_op_inputs = [
                input_node,
                scale_node,
                zero_point_node,
                quant_min,
                quant_max,
                dtype_,
            ]
            quantized_node = graph.create_node(
                "call_function", quantize_op, tuple(quantize_op_inputs), {})
            # need to use the tensor variant of this op, since scale and zero_point
            # from choose_qparam are Tensors, instead of float/int, this is to
            # prevent these nodes being traced away by downstream systems
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
            # use the same qparams from quantize op
            dequantized_node = graph.call_function(
                dequantize_op,
                (quantized_node, *quantize_op_inputs[1:]),
                {}
            )
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif dtype == torch.float16:
        raise NotImplementedError("decomposed to float16 op not implemented yet")

    # should not reach since we have checks in the beginning to make sure the
    # activation_post_process is supported

def _replace_observer_with_quantize_dequantize_node(
        model: torch.nn.Module,
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
    """
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
    activation_post_process = modules[node.target]
    # skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all([
        _has_none_qconfig(n, node_name_to_qconfig) for n in
        list(node.args) + list(node.users.keys())])
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find correponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

    if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
            (not is_dynamic):
        # TODO: probably should cleanup this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract the information from activation_post_process module for generating
        # the quantize and dequantize operator
Loading ...