# edgify/torch: ao/ns/fx/graph_passes.py

import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

from .utils import (
    get_node_first_input_and_output_type,
    getattr_from_fqn,
    NodeInputOrOutputType,
    return_first_non_observer_node,
    get_number_of_non_param_args,
    get_target_type_str,
    get_arg_indices_of_inputs_to_log,
    get_node_input_qparams,
    op_type_supports_shadowing,
    get_normalized_nth_input,
)

from .ns_types import (
    NSSingleResultValuesType,
    NSSubgraph,
    NSNodeTargetType,
)
from torch.ao.ns.fx.mappings import (
    get_node_type_to_io_type_map,
)
from torch.ao.quantization.observer import _is_activation_post_process

from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set

def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        # FQNs are not recorded for observers, because observers do not
        # exist yet when the FQNs are created during tracing. If this is
        # an observer, get the FQN of the node being observed.
        node_to_use_for_fqn = node
        if node.op == 'call_module':
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(module):
                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
    return fqn  # type: ignore[return-value]
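
# Illustrative usage sketch (assumes `gm` was produced by the
# torch.ao.quantization FX prepare flow, which attaches `_node_name_to_scope`
# to the GraphModule; otherwise this returns None):
#
#     for n in gm.graph.nodes:
#         fqn = _maybe_get_fqn(n, gm)  # e.g. 'features.0.conv', or None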

def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Given a starting graph of

    prev_node -> node -> next_node

    This function creates a new logger_cls obj and adds it
    after node, resulting in

    prev_node -> node -> logger_obj -> next_node
    """
    # create new name
    logger_node_name = \
        get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
    target_type = get_target_type_str(node, gm)
    # create the logger object
    logger_obj = logger_cls(
        ref_node_name, node.name, model_name, ref_name, target_type,
        ref_node_target_type,
        results_type, index_within_arg, index_of_arg, fqn)
    # attach the logger object to the parent module
    setattr(gm, logger_node_name, logger_obj)
    logger_node = node.graph.create_node(
        'call_module', logger_node_name, (node,), {})
    return logger_node
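
# Illustrative usage sketch (`MyLogger` is a hypothetical logger class whose
# constructor matches the positional arguments used above):
#
#     logger_node = _insert_logger_after_node(
#         node, gm, MyLogger, '_ns_logger_',
#         ref_node_name=node.name, model_name='model_a', ref_name='subgraph_0',
#         ref_node_target_type='call_module',
#         results_type=NSSingleResultValuesType.NODE_OUTPUT.value,
#         index_within_arg=0, index_of_arg=0, fqn=None)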

def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name, ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0, fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
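
# Illustrative usage sketch (`OutputLogger` and `model` are hypothetical; the
# dict values are (ref_name, ref_node_type) pairs, as unpacked above):
#
#     gm = torch.fx.symbolic_trace(model)
#     node = next(n for n in gm.graph.nodes if n.op == 'call_module')
#     gm_logged = add_loggers_to_model(
#         gm,
#         node_to_instrument_inputs_to_ref_node_name={},
#         node_to_instrument_outputs_to_ref_node_name={node: ('subgraph_0', 'Conv2d')},
#         logger_cls=OutputLogger,
#         model_name='model_a')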

def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node(
        'get_attr', scale_node_name, (), {}, scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node(
        'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)
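
# Net effect on graph_c (sketch): `scale` and `zero_point` are registered as
# new attributes on gm_b, read back through 'get_attr' nodes, and passed to
#
#     torch.quantize_per_tensor(prev_node_c, scale_node, zero_point_node,
#                               torch.quint8)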

def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = 'to'
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c,), {},
                    new_dtype_cast_name)
        elif dtype_cast_method:
            return graph_c.create_node(
                'call_method', dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                'call_module', new_dtype_cast_name, (prev_node_c,), {},
                new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type f{type(prev_node_c)} is not handled")

# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.
    """
    if node_a.op == 'get_attr':
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(
            node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
        return node_a_copy
    elif node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        if node_a.target == 'dequantize':
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0),
# ... (remainder of file truncated)