Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ ao / quantization / fx / utils.py

import copy
import torch
import torch.nn as nn
from torch.ao.quantization import (
from torch.ao.quantization.backend_config import (
from torch.ao.quantization.fake_quantize import (
from torch.ao.quantization.observer import (
from torch.ao.quantization.qconfig import (
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig_mapping import QConfigMapping

from torch.fx import GraphModule, map_arg

from torch.fx.graph import (
from .custom_config import PrepareCustomConfig
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib  # noqa: F401

from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
from dataclasses import dataclass
from collections import namedtuple
import operator
import warnings

# TODO: revisit this list. Many helper methods shouldn't be public
# NOTE(review): in this copy of the file the entries and the closing bracket
# of ``__all__`` are missing, leaving the list unterminated; restore it from
# upstream before relying on this module's declared public API.
__all__ = [

# Functional normalization ops that this module treats as having weights that
# are not quantized.
NON_QUANTIZABLE_WEIGHT_OPS = {
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.instance_norm,
}

@dataclass
class ObservedGraphModuleAttrs:
    """Attributes recorded for an observed (prepared) GraphModule.

    Declared as a dataclass so that the generated ``__init__`` accepts every
    field below by keyword. The last three fields have defaults and may be
    omitted.
    """
    # Map from FX node name to the qconfig assigned to that node.
    node_name_to_qconfig: Dict[str, QConfigAny]
    # Map from FX node name to a (str, type) pair; presumably the scope's
    # module path and module type -- confirm against callers.
    node_name_to_scope: Dict[str, Tuple[str, type]]
    prepare_custom_config: PrepareCustomConfig
    # Map from node name to its equalization qconfig (value type not pinned).
    equalization_node_name_to_qconfig: Dict[str, Any]
    qconfig_mapping: QConfigMapping
    # Presumably True when prepared for quantization-aware training.
    is_qat: bool
    # Names of the nodes that were observed.
    observed_node_names: Set[str]
    is_observed_standalone_module: bool = False
    standalone_module_input_quantized_idxs: Optional[List[int]] = None
    standalone_module_output_quantized_idxs: Optional[List[int]] = None

def node_arg_is_weight(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
    """Return True if ``arg`` is the "weight" input of ``node``.

    Only ``call_function`` nodes whose target is a key of
    ``backend_config._pattern_complex_format_to_config`` are considered.
    The weight position is read from that config's ``_input_type_to_index``.
    ``arg`` matches if it is (by identity) the positional arg at that index,
    or otherwise the ``weight`` keyword argument.
    """
    if not (isinstance(node, Node) and node.op == "call_function"):
        return False
    configs = backend_config._pattern_complex_format_to_config
    if node.target not in configs:
        return False
    idx = configs[node.target]._input_type_to_index.get("weight")
    positional_match = idx is not None and idx < len(node.args) and node.args[idx] is arg
    return positional_match or node.kwargs.get("weight") is arg

def node_arg_is_bias(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
    """Return True if ``arg`` is the "bias" input of ``node``.

    Mirrors ``node_arg_is_weight`` for the ``"bias"`` input type. Only
    ``call_function`` nodes registered in
    ``backend_config._pattern_complex_format_to_config`` are checked.
    ``arg`` matches if it is (by identity) the positional arg at the
    configured bias index, or otherwise the ``bias`` keyword argument.
    """
    is_candidate = isinstance(node, Node) and node.op == "call_function"
    if not is_candidate:
        return False
    pattern_configs = backend_config._pattern_complex_format_to_config
    if node.target not in pattern_configs:
        return False
    bias_pos = pattern_configs[node.target]._input_type_to_index.get("bias")
    if bias_pos is not None and bias_pos < len(node.args) and node.args[bias_pos] is arg:
        return True
    return node.kwargs.get("bias") is arg

def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
    r""" Get all the unique custom module keys in the custom config dict
    e.g.
    Input:
    {
        QuantType.STATIC: {
            CustomModule1: ObservedCustomModule
        },
        QuantType.DYNAMIC: {
            CustomModule2: DynamicObservedCustomModule
        },
        QuantType.WEIGHT_ONLY: {
            CustomModule3: WeightOnlyObservedCustomModule
        },
    }

    Output:
    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
    [CustomModule1, CustomModule2, CustomModule3]

    Keys are de-duplicated and returned in first-seen order (STATIC, then
    DYNAMIC, then WEIGHT_ONLY). Quant types missing from the mapping are
    skipped, and entries under any other quant type are ignored.
    """
    # dict.fromkeys de-duplicates while preserving insertion order, so the
    # result order is deterministic (a set would yield an arbitrary order).
    ordered_keys: Dict[Any, None] = {}
    for quant_mode in (QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY):
        ordered_keys.update(dict.fromkeys(custom_module_mapping.get(quant_mode, {})))
    return list(ordered_keys)

def get_linear_prepack_op_for_dtype(dtype):
    """Return the quantized linear weight-prepack op for ``dtype``.

    ``torch.float16`` maps to ``torch.ops.quantized.linear_prepack_fp16`` and
    ``torch.qint8`` maps to ``torch.ops.quantized.linear_prepack``.

    Raises:
        Exception: if ``dtype`` is neither of the supported dtypes.
    """
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    if dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    # Unsupported dtype: fail loudly instead of falling through to None.
    raise Exception("can't get linear prepack op for dtype:", dtype)

def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    """Return the quantized weight-prepack op matching a functional conv op.

    Supported inputs are ``torch.nn.functional.conv1d``, ``conv2d`` and
    ``conv3d``.

    Raises:
        AssertionError: if ``conv_op`` is not one of the supported convs.
    """
    prepack_ops = {
        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
    }
    prepack_op = prepack_ops.get(conv_op, None)
    assert prepack_op is not None, "Didn't find prepack op for {}".format(conv_op)
    return prepack_op

def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    """Build a helper that picks an unused attribute name on a module.

    Any ``.`` in ``prefix`` is replaced with ``_``. The returned callable
    takes a module and returns ``<prefix><i>``, where ``i`` is the smallest
    non-negative integer for which the module has no attribute of that name.
    For example::

        get_new_observer_name = get_new_attr_name_with_prefix('_observer')
        new_name = get_new_observer_name(module)
        # '_observer0' if free, else '_observer1', '_observer2', ...
    """
    sanitized_prefix = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        index = 0
        while hasattr(module, f"{sanitized_prefix}{index}"):
            index += 1
        return f"{sanitized_prefix}{index}"
    return get_new_attr_name

def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r''' Starting from a target node, trace back until we hit input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example
    def forward(self, x):
      observed = self.observer(self.weight)
      return F.linear(x, observed)
    collect_producer_nodes(observed) will either return a list of nodes that
    produces the observed node or None if we can't extract a self contained
    graph without free variables(inputs of the forward function).

    The returned list starts with ``node`` itself, followed by its producers
    in the order they were discovered. Non-Node args are ignored. A
    ``call_function`` node whose target is ``getattr`` is collected, but its
    own inputs are not traced further.
    '''
    nodes = [node]
    frontier = [node]
    while frontier:
        node = frontier.pop()
        all_args = list(node.args) + list(node.kwargs.values())
        for arg in all_args:
            if not isinstance(arg, Node):
                # constants (ints, strings, ...) are not producer nodes
                continue
            if arg.op == 'placeholder':
                # hit input, can't fold in this case
                return None
            nodes.append(arg)
            # keep walking back through everything except getattr calls,
            # which terminate the chain
            if not (arg.op == 'call_function' and arg.target == getattr):
                frontier.append(arg)
    return nodes

def graph_module_from_producer_nodes(
        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function
    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph; the
        first element is the target node and the rest are its producers,
        as returned by `collect_producer_nodes`. The list is not modified.
    Return:
      A graph module constructed from the producer nodes, whose output is
      the copy of the target node ``producer_nodes[0]``
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])
    # since we traced back from node to getattr, every producer appears after
    # the node consuming it; walk the list backwards so each input is copied
    # (and present in `env`) before the node that uses it. `reversed` avoids
    # mutating the caller's list.
    for producer_node in reversed(producer_nodes):
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    # the copied target node is the value the new graph returns
    graph.output(load_arg(producer_nodes[0]))
    graph_module = GraphModule(root, graph)
    return graph_module

def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.

    The device set is gathered from both the module's parameters and its
    buffers.

    Raises:
        AssertionError: if parameters/buffers live on more than one device.
    """
    devices = {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        "but got devices {}".format(devices)
    )
    device = next(iter(devices)) if len(devices) > 0 else None
    return device

def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.

    Args:
        module: module that receives the new buffer.
        graph: graph in which the ``get_attr`` node is created.
        prefix: attribute-name prefix; the buffer name is chosen by
            ``get_new_attr_name_with_prefix(prefix)`` so it is unused on
            ``module``.
        value: a tensor (registered as a detached clone) or any value accepted
            by ``torch.tensor`` (created on the module's unique device, or
            with ``device=None`` if the module has no parameters/buffers).

    Returns:
        The newly created ``get_attr`` node referring to the buffer.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    device = assert_and_get_unique_device(module)
    new_value = value.clone().detach() if isinstance(value, torch.Tensor) \
        else torch.tensor(value, device=device)
    module.register_buffer(attr_name, new_value)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node

def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Optional[Dict[Node, bool]]) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True.  If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.

    Args:
        node: the value to inspect. Anything that is not an FX ``Node`` is
            treated as a primitive and yields True.
        modules: maps ``call_module`` targets to module instances. It is used
            to look through modules recognized by
            ``_is_activation_post_process`` to their first input.
        cache: memo of results keyed by ``Node``. It is read and filled in by
            this call and its recursive calls, even when it starts empty.
            ``None`` disables memoization.
    """
    if not isinstance(node, Node):
        # Primitives hold no tensors. Checked before the cache lookup because
        # such values may be unhashable (e.g. lists) and are never cached.
        return True
    if cache is not None and node in cache:
        return cache[node]

    if node.op == 'placeholder':
        result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        if _is_activation_post_process(modules[node.target]):
            # look through to the value being observed
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
        else:
            result = False
    elif node.op == 'call_function' and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == 'get_attr':
        result = False
    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
        # x1 = x0.ndim
        result = True
    elif node.op == 'call_method' and node.target == 'size':
        # x1 = x0.size(0)
        result = True
    else:
        # Only positional args are inspected; a node without positional args
        # keeps the conservative answer False.
        result = bool(node.args) and not _node_args_may_contain_tensor(node, modules, cache)

    # TODO(future PR): remove this entire function and
    # change to dtype inference without recursion.
    if cache is not None:
        cache[node] = result
    return result


def _node_args_may_contain_tensor(
        node: Node, modules: Dict[str, torch.nn.Module], cache: Optional[Dict[Node, bool]]) -> bool:
    """Return True if any positional arg of ``node`` may carry a tensor.

    The args are classified as follows:

    - Lists are scanned for ``Node`` elements; other list elements are ignored.
    - Ints are treated as primitives.
    - ``Node`` args are checked recursively.
    - Any other arg is conservatively assumed to be a tensor.

    Stops at the first possible tensor, since the overall answer cannot change
    after that.
    """
    for arg in node.args:
        if isinstance(arg, list):
            if any(isinstance(list_el, Node) and
                   not all_node_args_have_no_tensors(list_el, modules, cache)
                   for list_el in arg):
                return True
        elif isinstance(arg, int):
            continue
        elif isinstance(arg, Node):
            if not all_node_args_have_no_tensors(arg, modules, cache):
                return True
        else:
            return True
    return False

def all_node_args_except_first(node: Node) -> List[int]:
    """
    Returns all node arg indices after first

    i.e. ``[1, 2, ..., len(node.args) - 1]``; empty when ``node`` has at most
    one positional arg.
    """
    return list(range(1, len(node.args)))

def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """
    Constructs a function that takes a node as arg and returns the arg_indices
    that are valid for node.args

    The returned function keeps each index ``i`` from ``arg_indices`` (in
    the given order) for which ``i < len(node.args)``.
    """
    def arg_indices_func(node: Node) -> List[int]:
        return [i for i in arg_indices if i < len(node.args)]
    return arg_indices_func

# Immutable record with two fields, ``op`` and ``target``.
NodeInfo = namedtuple("NodeInfo", ["op", "target"])
Loading ...