import torch
from torch.fx import GraphModule
from torch.nn.utils.fusion import fuse_conv_bn_weights
# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (
_is_activation_post_process_node,
)
from collections import OrderedDict
import operator
# TODO[qihan]: longer term, this should happen in the dynamo stack as well
def _get_renamed_nn_module_stack(nn_module_stack):
# initialize with top level parent scope
nn_module_stack_renamed = OrderedDict([("", None)])
if nn_module_stack:
# Rename module_key, e.g. "self_layer1_1__conv1" to "self.layer1.1._conv1", for easier downstream parsing
prev_key = ""
for key, value in nn_module_stack.items():
if not prev_key:
if key.startswith("self_"):
new_key = key[5:]
prev_key = new_key
else:
new_key = prev_key + "." + key[len(prev_key) + 6 :]
nn_module_stack_renamed[new_key] = value
prev_key = new_key
return nn_module_stack_renamed
def _get_tensor_constant_from_node(node, m):
if node is None:
return None
assert node.op == "get_attr"
return getattr(m, node.target)
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    """Fold batch-norm statistics into the preceding convolution, in place.

    For every ``aten.native_batch_norm.default`` node whose input is an
    ``aten.convolution.default`` node, the conv weight and bias stored on
    ``m`` are overwritten with the fused values computed by
    ``fuse_conv_bn_weights``, and users of the batch norm's first output are
    rewired to consume the conv output directly.  A conv without a bias gets
    a new ``<weight_attr_name>_bias`` attribute plus a ``get_attr`` node.
    The graph is dead-code-eliminated and ``m`` is recompiled before
    returning.

    The conv / batch-norm tensor arguments must be ``get_attr`` nodes (or
    ``None``); ``_get_tensor_constant_from_node`` asserts this.

    NOTE(review): nothing here checks that the batch norm runs in inference
    mode or that the conv output has no other consumers; presumably callers
    only pass graphs where both hold -- confirm.
    """
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target != torch.ops.aten.native_batch_norm.default:
            continue
        bn_op = n
        conv_op = bn_op.args[0]
        if conv_op.op != "call_function" or conv_op.target != torch.ops.aten.convolution.default:
            continue

        # conv weight
        conv_w = _get_tensor_constant_from_node(conv_op.args[1], m)
        # conv bias
        conv_b = _get_tensor_constant_from_node(conv_op.args[2], m)
        # whether this is a transposed convolution
        transpose = conv_op.args[6]

        # bn weight
        bn_w = _get_tensor_constant_from_node(bn_op.args[1], m)
        # bn bias
        bn_b = _get_tensor_constant_from_node(bn_op.args[2], m)
        # bn running mean
        bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
        # bn running variance
        bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
        bn_eps = bn_op.args[7]

        # Forward the conv's own `transposed` flag: a transposed conv keeps
        # its output channels on a different weight dimension, so fusing it
        # with transpose=False would scale the wrong axis.
        fused_weight, fused_bias = fuse_conv_bn_weights(
            conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
        )

        # update the weight and bias for conv
        conv_args = list(conv_op.args)
        weight_attr_name = conv_args[1].target
        setattr(m, weight_attr_name, fused_weight)
        if conv_args[2] is not None:
            bias_attr_name = conv_args[2].target
        else:
            # The conv had no bias: register one next to the weight and feed
            # it to the conv through a fresh get_attr node.
            bias_attr_name = weight_attr_name + "_bias"
            with m.graph.inserting_before(conv_op):
                get_bias_node = m.graph.get_attr(bias_attr_name)
            conv_args[2] = get_bias_node
        setattr(m, bias_attr_name, fused_bias)
        conv_op.args = tuple(conv_args)

        # native_batch_norm has 3 outputs, we expect getitem calls on the output
        # and we want to replace the uses of getitem 0 with the output of conv
        #
        # Before:
        # conv -> bn - (first output) -> users1
        #          \ - (second output) -> users2
        #          \ - (third output) -> users3
        # After:
        # conv -> (first output) -> users1
        #       bn -
        #          \ - (second output) -> users2
        #          \ - (third output) -> users3
        # if users2 and users3 are empty then bn will be removed through dead code elimination
        for user in bn_op.users:
            if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
                continue
            user.replace_all_uses_with(conv_op)

    m.graph.eliminate_dead_code()
    m.recompile()
def _rearrange_weight_observer_for_addmm(
model: GraphModule,
) -> None:
"""
before:
weight - t - observer \
input - observer - addmm
after:
weight - observer - t \
input - observer - addmm
"""
named_modules = dict(model.named_modules(remove_duplicate=False))
for node in model.graph.nodes:
if node.target != torch.ops.aten.addmm.default:
continue
addmm = node
maybe_weight_obs = addmm.args[2]
if not _is_activation_post_process_node(maybe_weight_obs, named_modules):
continue
transpose_node = maybe_weight_obs.args[0]
if transpose_node.target != torch.ops.aten.t.default:
continue
# swap the order of transpose and observation
maybe_weight_obs.replace_input_with(transpose_node, transpose_node.args[0])
# remove the transpose node
with model.graph.inserting_after(maybe_weight_obs):
args = list(transpose_node.args)
args[0] = maybe_weight_obs
new_transpose_node = model.graph.create_node(
"call_function",
torch.ops.aten.t.default,
tuple(args),
transpose_node.kwargs
)
addmm.replace_input_with(maybe_weight_obs, new_transpose_node)
model.graph.eliminate_dead_code()
model.graph.lint()