Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ quantization / fx / quantization_patterns.py

import torch
from torch.fx.graph import (
    Node,
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import (
    default_affine_fixed_qparams_fake_quant,
    default_symmetric_fixed_qparams_fake_quant,
)

from ..quantization_mappings import (
    get_static_quant_module_class,
    get_dynamic_quant_module_class,
    get_quantized_operator,
)
from ..utils import (
    get_swapped_custom_module_class,
    activation_is_statically_quantized,
    weight_is_statically_quantized,
    weight_dtype,
    get_qconfig_dtypes,
)

from .pattern_utils import (
    register_quant_pattern,
    mark_input_output_not_observed,
)

from .utils import (
    _parent_name,
    quantize_node,
    get_per_tensor_qparams,
    get_linear_prepack_op_for_dtype,
    create_qparam_nodes,
    get_qconv_prepack_op,
    get_qconv_op,
)

from .quantization_types import QuantizerCls

from abc import ABC, abstractmethod
import operator
import warnings

from typing import Any, Callable, Dict

# -------------------------
# Pattern Registrations
# -------------------------

# 1. Post Training Static Quantization and Quantization Aware Training Patterns

# Base Pattern Handler
class QuantizeHandler(ABC):
    """Abstract base class for the FX quantizer's pattern handlers.

    A handler is instantiated for a matched pattern (``__init__`` records
    whatever it needs) and is later asked, via ``convert``, to emit the
    quantized replacement into the quantized graph.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        """Record information about the matched pattern for later use in ``convert``.

        Args:
            quantizer: the quantizer running the pass (not used by this base class).
            node: the node passed in for the matched pattern; its positional
                argument count is recorded.
        """
        # Some ops (e.g. add/mul) are lowered differently depending on whether
        # all of their inputs are Nodes (tensors) or not. This base class
        # always sets ``all_node_args`` to True and stores the raw positional
        # argument count; subclasses may overwrite these (the add/mul handlers
        # in this file recompute ``num_node_args``).
        positional_args = node.args
        self.all_node_args = True
        self.num_node_args = len(positional_args)

    @abstractmethod
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                debug: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """Emit the quantized replacement for the matched pattern.

        Subclasses must insert the quantized node(s) into the quantized graph
        and return the node that replaces ``node``.
        """
        return NotImplemented

@register_quant_pattern(operator.add)
@register_quant_pattern(torch.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
@register_quant_pattern((torch.nn.functional.relu, torch.add))
class Add(QuantizeHandler):
    """Handler for ``operator.add`` / ``torch.add``, optionally followed by a ReLU.

    The pattern is lowered to ``torch.ops.quantized.add`` or, when a ReLU
    (functional or ``nn.ReLU`` module) follows the add,
    ``torch.ops.quantized.add_relu``.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)
        self.relu_node = None
        # The matched node is either the add itself or a ReLU wrapping it.
        relu_as_function = node.op == 'call_function' and node.target is torch.nn.functional.relu
        relu_as_module = node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)
        if relu_as_function or relu_as_module:
            self.relu_node = node
            node = node.args[0]  # type: ignore
        assert node.op == 'call_function' and node.target in [operator.add, torch.add]
        self.add_node = node
        # Count how many of the two operands are Nodes (1 means tensor + scalar).
        self.num_node_args = sum(1 for operand in self.add_node.args[:2] if isinstance(operand, Node))

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                debug: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """Insert the quantized add (or add_relu) node into the quantized graph.

        With a single Node operand (tensor + scalar), only that operand is
        loaded as quantized and the original kwargs are forwarded. Otherwise
        both operands are loaded as quantized, and output scale / zero_point
        come from the activation post-process keyed by ``node.name``.
        """
        quantized_op = torch.ops.quantized.add if self.relu_node is None else torch.ops.quantized.add_relu

        if self.num_node_args == 1:
            # Tensor + scalar: quantize only the operand that is a Node.
            tensor_position = 0 if isinstance(self.add_node.args[0], Node) else 1
            return quantizer.quantized_graph.create_node(
                'call_function', quantized_op,
                load_arg(quantized=[tensor_position])(self.add_node.args), self.add_node.kwargs)

        output_observer = quantizer.activation_post_process_map[node.name]
        raw_scale, raw_zero_point = output_observer.calculate_qparams()
        scale_arg, zero_point_arg = create_qparam_nodes(
            quantizer, node.name, float(raw_scale), int(raw_zero_point))
        operands = (*load_arg(quantized=True)(self.add_node.args), scale_arg, zero_point_arg)
        return quantizer.quantized_graph.create_node(
            'call_function', quantized_op, operands, {**self.add_node.kwargs})

# TODO: merge with Add
@register_quant_pattern(operator.mul)
@register_quant_pattern(torch.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
class Mul(QuantizeHandler):
    """Handler for ``operator.mul`` / ``torch.mul``, optionally followed by a ReLU.

    The pattern is lowered to ``torch.ops.quantized.mul`` or, when a ReLU
    (functional or ``nn.ReLU`` module) follows the mul,
    ``torch.ops.quantized.mul_relu``.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)
        self.relu_node = None
        # Decide whether the matched root node is a ReLU applied on top of the mul.
        if node.op == 'call_function':
            is_relu = node.target is torch.nn.functional.relu
        elif node.op == 'call_module':
            is_relu = isinstance(quantizer.modules[node.target], torch.nn.ReLU)
        else:
            is_relu = False
        if is_relu:
            self.relu_node = node
            node = node.args[0]  # type: ignore
        assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
        self.mul_node = node
        # Number of the first two operands that are Nodes (1 means tensor * scalar).
        self.num_node_args = sum(isinstance(operand, Node) for operand in self.mul_node.args[:2])

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                debug: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """Insert the quantized mul (or mul_relu) node into the quantized graph.

        With a single Node operand (tensor * scalar), only that operand is
        loaded as quantized and the original kwargs are forwarded. Otherwise
        both operands are loaded as quantized, and output scale / zero_point
        come from the activation post-process keyed by ``node.name``.
        """
        if self.relu_node is not None:
            target_op = torch.ops.quantized.mul_relu
        else:
            target_op = torch.ops.quantized.mul
        mul_args = self.mul_node.args

        if self.num_node_args != 1:
            # Tensor * tensor: append the requantization scale / zero_point.
            observer = quantizer.activation_post_process_map[node.name]
            scale, zero_point = observer.calculate_qparams()
            scale_arg, zero_point_arg = create_qparam_nodes(
                quantizer, node.name, float(scale), int(zero_point))
            full_args = (*load_arg(quantized=True)(mul_args), scale_arg, zero_point_arg)
            return quantizer.quantized_graph.create_node(
                'call_function', target_op, full_args, {**self.mul_node.kwargs})

        # Tensor * scalar: only the Node operand is quantized; the scalar passes through.
        node_position = 0 if isinstance(mul_args[0], Node) else 1
        return quantizer.quantized_graph.create_node(
            'call_function', target_op, load_arg(quantized=[node_position])(mul_args), self.mul_node.kwargs)

@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
    """Handler that lowers ``torch.cat`` to ``torch.ops.quantized.cat``."""
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                debug: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """Insert a ``quantized.cat`` node into the quantized graph.

        Positional argument 0 is loaded as quantized. The remaining positional
        arguments and the kwargs are loaded unquantized. Output ``scale`` and
        ``zero_point`` are taken from the activation post-process keyed by
        ``node.name`` and passed as kwargs.
        """
        # Bail out unless every input was recorded as a Node.
        if not self.all_node_args:
            return NotImplemented

        observer = quantizer.activation_post_process_map[node.name]
        raw_scale, raw_zero_point = observer.calculate_qparams()
        scale_node, zero_point_node = create_qparam_nodes(
            quantizer, node.name, float(raw_scale), int(raw_zero_point))

        cat_kwargs = dict(load_arg(quantized=False)(node.kwargs))
        cat_kwargs.update(scale=scale_node, zero_point=zero_point_node)
        cat_args = load_arg(quantized=[0])(node.args)
        return quantizer.quantized_graph.create_node(
            'call_function', torch.ops.quantized.cat, cat_args, cat_kwargs)

# handle conv, maybe followed by relu
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv1d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.functional.conv3d)
# TODO: add qat.Conv1d and qat.Conv3d
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
    def __init__(self, quantizer: QuantizerCls, node: Node):
        """Record the conv node (and an optional trailing ReLU) of the matched pattern.

        Sets ``relu_node`` to the ReLU node when the matched root is a
        functional or ``nn.ReLU`` module call, otherwise None. Sets
        ``conv_node`` to the conv node. Sets ``conv`` to the conv module
        instance for ``call_module`` nodes, or to the function target for
        ``call_function`` nodes; ``conv`` is left unset for any other op kind.
        """
        super().__init__(quantizer, node)
        self.relu_node = None
        functional_relu = node.op == 'call_function' and node.target is torch.nn.functional.relu
        module_relu = node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)
        if functional_relu or module_relu:
            # The ReLU wraps the conv; step down to the conv node.
            self.relu_node = node
            node = node.args[0]  # type: ignore
        self.conv_node = node
        op_kind = node.op
        if op_kind == "call_module":
            self.conv = quantizer.modules[node.target]
        elif op_kind == "call_function":
            self.conv = node.target

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                debug: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # Supported combinations are:
        # quant_type | activation (compute_type) | weight
        #  static       quint8                      qint8

        # tuple (activation_dtype, weight_dtype, compute_dtype)
        supported_dtypes = [
            (torch.quint8, torch.qint8, None),
        ]

        # TODO: debug option for conv module
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        # leave the op unquantized if the dtype combination is not supported
        if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not "
                "supported by Conv "
                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
            if self.relu_node:
                conv_out = quantizer.quantized_graph.node_copy(self.conv_node, load_arg(quantized=False))
                relu_args = [conv_out]
                relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
                relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
                return quantizer.quantized_graph.create_node(
                    "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
            else:
                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))

        activation_statically_quantized = activation_is_statically_quantized(qconfig)

        if self.conv_node.op == 'call_module':
            # note that relu should already be fused into conv module in the fusion step
            assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
            if convert_custom_config_dict is None:
                convert_custom_config_dict = {}
            additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
            # 1. attach activation post process to module
            self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
            # 2. select quantized class
            qconv_cls = get_static_quant_module_class(
                type(self.conv), additional_static_quant_mapping)
            quantized = qconv_cls.from_float(self.conv)
            parent_name, name = _parent_name(self.conv_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.conv_node.target,
                (load_arg(quantized=True)(self.conv_node.args[0]),),
                {})
        else:  # call_function
            assert self.conv_node.op == "call_function"
            if debug:
                args = load_arg(quantized=[0, 1])(self.conv_node.args)
                args = load_arg(quantized=False)(self.conv_node.args)
                kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                op_out = quantizer.quantized_graph.create_node(
                    "call_function", self.conv, args, kwargs)
                if self.relu_node:
                    relu_args = [op_out]
                    relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
                    relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
                    op_out = quantizer.quantized_graph.create_node(
                        "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)

                if activation_statically_quantized:
                    root_module = quantizer.modules['']
                    act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
                    act_post_process_node = self.relu_node if self.relu_node else self.conv_node
                    return quantize_node(
                        quantizer, op_out, quantizer.activation_post_process_map[act_post_process_name],
                        act_post_process_node, is_input=False)
                else:
                    # output for dynamically quantized conv op is not quantized
                    return op_out
            else:
                assert len(self.conv_node.args) >= 7, \
                    "only conv2d calls with all arguments specified is supported right now in debug=False option"
                args = load_arg(quantized=[0, 1])(self.conv_node.args)
                # pack weight
                weight = load_arg(quantized=True)(self.conv_node.args[1])
                other_args = load_arg(quantized=False)(self.conv_node.args[2:])
                prepack_args = tuple([weight] + list(other_args))
                prepack_op = get_qconv_prepack_op(self.conv)
                packed_weight = quantizer.quantized_graph.create_node(
                    "call_function", prepack_op, prepack_args, {})
                assert activation_statically_quantized, \
                    "currently only static quantization is supported for conv"
                # construct conv input
                if activation_statically_quantized:
                    qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
                    conv_input = load_arg(quantized=True)(self.conv_node.args[0])
                    act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
                    activation_post_process = quantizer.activation_post_process_map[act_post_process_name]
                    scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
                    scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
                    qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
                    kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                    op = quantizer.quantized_graph.create_node(
                        'call_function', qconv_op, qconv_args, kwargs)
                    # Store the name of the fused op to get the path of node after fusion as well.
                    # TODO: may need to change the key to Node regenerate the map in each transformation,
Loading ...