Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ ao / quantization / _pt2e / utils.py

import torch
from torch.fx import GraphModule
from torch.nn.utils.fusion import fuse_conv_bn_weights
# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (
    _is_activation_post_process_node,
)
from collections import OrderedDict
import operator

# TODO[qihan]: longer term, this should happen in the dynamo stack as well
def _get_renamed_nn_module_stack(nn_module_stack):
    """Return a copy of ``nn_module_stack`` keyed by dotted module paths.

    The keys are expected to be nested, underscore-joined scope names listed
    from the outermost to the innermost scope, e.g. ``"self_layer1"``,
    ``"self_layer1_1"``, ``"self_layer1_1__conv1"``.  They are rewritten to
    ``"layer1"``, ``"layer1.1"``, ``"layer1.1._conv1"`` for easier downstream
    parsing.  Values are copied through unchanged.

    Args:
        nn_module_stack: ordered mapping of scope key to value, or a falsy
            value (``None`` / empty mapping).

    Returns:
        An ``OrderedDict`` that always starts with the entry ``"" -> None``
        (the top-level parent scope) followed by the renamed entries.
    """
    # initialize with top level parent scope
    renamed = OrderedDict([("", None)])
    if not nn_module_stack:
        return renamed

    root_prefix = "self_"
    # Number of leading characters stripped from every key; decided when the
    # outermost scope is seen.
    prefix_len = len(root_prefix)
    prev_key = ""
    for key, value in nn_module_stack.items():
        if not prev_key:
            # Outermost scope: drop the "self_" prefix when present.  A key
            # without that prefix is kept verbatim instead of leaving
            # ``new_key`` unbound (which used to raise UnboundLocalError).
            prefix_len = len(root_prefix) if key.startswith(root_prefix) else 0
            new_key = key[prefix_len:]
        else:
            # A nested key repeats the parent's path; skip the stripped
            # prefix, the parent path, and the "_" that joins parent and child.
            new_key = prev_key + "." + key[prefix_len + len(prev_key) + 1:]
        renamed[new_key] = value
        prev_key = new_key
    return renamed

def _get_tensor_constant_from_node(node, m):
    """Look up the attribute on ``m`` that a ``get_attr`` node refers to.

    Args:
        node: a graph node whose ``op`` is ``"get_attr"``, or ``None``.
        m: the object (module) holding the attribute named by ``node.target``.

    Returns:
        ``getattr(m, node.target)``, or ``None`` when ``node`` is ``None``.

    Raises:
        AssertionError: if ``node`` is not a ``get_attr`` node.
    """
    if node is not None:
        assert node.op == "get_attr"
        attr_name = node.target
        return getattr(m, attr_name)
    return None

# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    """Fold batch-norm parameters into the preceding convolution, in place.

    For every ``aten.native_batch_norm.default`` node whose input is an
    ``aten.convolution.default`` node, the conv weight/bias attributes on
    ``m`` are overwritten with the fused values and users of the batch norm's
    first output are redirected to the convolution.  The graph is then
    dead-code-eliminated and the module recompiled.

    Args:
        m: the ``GraphModule`` to modify; both the module attributes and its
            graph are updated.
    """
    for node in m.graph.nodes:
        if node.op != "call_function" or node.target != torch.ops.aten.native_batch_norm.default:
            continue
        bn_op = node
        conv_op = bn_op.args[0]
        if conv_op.op != "call_function" or conv_op.target != torch.ops.aten.convolution.default:
            continue

        conv_args = list(conv_op.args)
        # conv weight
        conv_w = _get_tensor_constant_from_node(conv_args[1], m)
        # conv bias (may be absent)
        conv_b = _get_tensor_constant_from_node(conv_args[2], m)
        # whether this is a transposed convolution; decides which weight
        # dimension the per-channel BN scale is applied to
        transpose = conv_args[6]

        # bn weight
        bn_w = _get_tensor_constant_from_node(bn_op.args[1], m)
        # bn bias
        bn_b = _get_tensor_constant_from_node(bn_op.args[2], m)
        # bn running mean
        bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
        # bn running variance
        bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
        bn_eps = bn_op.args[7]

        # Forward the conv's own transposed flag instead of a hard-coded
        # False, otherwise transposed convolutions are fused incorrectly.
        fused_weight, fused_bias = fuse_conv_bn_weights(
            conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
        )

        # update the weight and bias for conv
        weight_attr_name = conv_args[1].target
        setattr(m, weight_attr_name, fused_weight)
        if conv_args[2] is not None:
            bias_attr_name = conv_args[2].target
            setattr(m, bias_attr_name, fused_bias)
        else:
            # The conv had no bias: store the fused bias under a new attribute
            # and feed it to the conv through a new get_attr node.  The
            # attribute is set first so the get_attr node refers to something
            # that already exists on the module.
            bias_attr_name = weight_attr_name + "_bias"
            setattr(m, bias_attr_name, fused_bias)
            with m.graph.inserting_before(conv_op):
                conv_args[2] = m.graph.get_attr(bias_attr_name)
        conv_op.args = tuple(conv_args)

        # native_batch_norm has 3 outputs, we expect getitem calls on the output
        # and we want to replace the uses of getitem 0 with the output of conv
        #
        # Before:
        # conv -> bn - (first output) -> users1
        #          \ - (second output) -> users2
        #          \ - (third output) -> users3
        # After:
        # conv -> (first output) -> users1
        #       bn -
        #          \ - (second output) -> users2
        #          \ - (third output) -> users3
        # if users2 and users3 are empty then bn will be removed through dead code elimination

        for user in list(bn_op.users):
            if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
                continue
            user.replace_all_uses_with(conv_op)
    m.graph.eliminate_dead_code()
    m.recompile()

def _rearrange_weight_observer_for_addmm(
    model: GraphModule,
) -> None:
    r"""Move the weight observer of each ``aten.addmm`` in front of the transpose.

    before:
         weight - t - observer \
          input - observer - addmm
    after:
         weight - observer - t \
           input - observer - addmm

    Only ``addmm`` nodes whose third argument is an observer node fed by an
    ``aten.t.default`` node are rewritten.  The graph is modified in place,
    dead nodes are eliminated, the graph is linted, and the module is
    recompiled so its generated ``forward`` matches the new graph.

    Args:
        model: the ``GraphModule`` to rewrite.
    """
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        if node.target != torch.ops.aten.addmm.default:
            continue
        addmm = node
        maybe_weight_obs = addmm.args[2]
        if not _is_activation_post_process_node(maybe_weight_obs, named_modules):
            continue
        transpose_node = maybe_weight_obs.args[0]
        if transpose_node.target != torch.ops.aten.t.default:
            continue

        # swap the order of transpose and observation:
        # first make the observer consume the transpose's own input ...
        maybe_weight_obs.replace_input_with(transpose_node, transpose_node.args[0])
        # ... then create a new transpose right after the observer.  The old
        # transpose node is left in place here and is dropped by dead code
        # elimination below once it has no remaining users.
        with model.graph.inserting_after(maybe_weight_obs):
            new_args = list(transpose_node.args)
            new_args[0] = maybe_weight_obs
            new_transpose_node = model.graph.create_node(
                "call_function",
                torch.ops.aten.t.default,
                tuple(new_args),
                transpose_node.kwargs,
            )
        addmm.replace_input_with(maybe_weight_obs, new_transpose_node)

    model.graph.eliminate_dead_code()
    model.graph.lint()
    # Regenerate forward(); without this the module keeps executing the
    # code generated from the graph as it was before the rewrite.
    model.recompile()