import copy
import logging
import random
import weakref
import torch
import torch._dynamo.config as dynamo_config
import torch.nn as nn
from torch import _prims
from torch._dynamo.utils import fake_mode_from_tensors
from torch.fx.experimental.optimization import (
matches_module_pattern,
replace_node_module,
)
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
from torch.overrides import TorchFunctionMode
from . import config
from .fx_utils import matches_module_function_pattern
from .mkldnn import mkldnn_fuse_fx
log = logging.getLogger(__name__)
class AutogradMonkeypatch(TorchFunctionMode):
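    """
    A ``TorchFunctionMode`` that swaps supported ops for the entries in the
    module-level ``replacements`` table as they are called. Replacements that
    rely on triton random are skipped when ``config.fallback_random`` is set.
    """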
def __torch_function__(self, func, types, args=(), kwargs=None):
if not kwargs:
kwargs = {}
if func in replacements and not (
config.fallback_random
and replacements[func] in replacements_using_triton_random
):
return replacements[func](*args, **kwargs)
return func(*args, **kwargs)
patch_functions = AutogradMonkeypatch
def replace_fx(gm: torch.fx.GraphModule):
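    """
    Graph-level counterpart of patch_functions(): rewrite any call_function
    node whose target has an entry in ``replacements`` to call the replacement
    instead, honoring the same ``config.fallback_random`` carve-out.
    """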
# Sometimes patch_functions() misses things already in the graph
for node in reversed(list(gm.graph.nodes)):
if node.op == "call_function" and node.target in replacements:
if (
config.fallback_random
and replacements[node.target] in replacements_using_triton_random
):
continue
with gm.graph.inserting_before(node):
node.replace_all_uses_with(
gm.graph.call_function(
replacements[node.target], node.args, node.kwargs
)
)
gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()
return gm
def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
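    """
    Apply FX-level fusions ahead of inductor compilation: sink cat below
    pointwise ops, optionally fuse permute with linear/matmul (non-CPU),
    and, for CPU inference graphs, remove identities, fold BatchNorm into
    convolutions, and run mkldnn fusion.
    """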
is_cpu = all(
example_input.device == torch.device("cpu")
for example_input in example_inputs
if isinstance(example_input, torch.Tensor)
)
fake_mode = fake_mode_from_tensors(example_inputs)
gm = sink_cat_after_pointwise(gm)
if config.permute_fusion and not is_cpu:
        # The permute/linear fusions need per-node shape metadata to identify
        # valid permutations/transposes, so run ShapeProp first.
ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
gm = linear_permute_fusion(gm)
gm = permute_linear_fusion(gm)
gm = permute_matmul_fusion(gm)
    # The remaining fusions are inference-only and CPU-only; bail out early
    # if autograd is enabled or any input lives on a non-CPU device.
if torch.is_grad_enabled():
return gm
if not is_cpu:
return gm
gm = remove_identity(gm)
gm = fuse_conv_bn(gm)
    # Do mkldnn fusion (conv/linear + unary/binary post-ops).
    # This is skipped when dynamic shapes are enabled, as the resulting
    # mkl packing ops don't support dynamic shapes. Once they do, this
    # check can be removed. A good test case is wav2vec2, see
    # https://github.com/pytorch/pytorch/issues/91719
if not dynamo_config.dynamic_shapes:
gm = mkldnn_fuse_fx(gm, example_inputs)
return gm
def fetch_attr(target: str, mod):
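    """
    Resolve a dotted attribute path such as ``"sub.weight"`` against ``mod``
    and return the attribute, raising if any component is missing.
    """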
target_atoms = target.split(".")
attr_itr = mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
attr_itr = getattr(attr_itr, atom)
return attr_itr
def remove_identity(gm: torch.fx.GraphModule):
"""
Removes all identity layers from the module.
"""
class IdentityRemover(torch.fx.Transformer):
def call_module(self, target, args, kwargs):
if isinstance(self.submodules[target], nn.Identity):
assert len(args) == 1
return args[0]
else:
return super().call_module(target, args, kwargs)
return IdentityRemover(gm).transform()
def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False):
"""
Fuses Convolution/BN layers for inference purposes.
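
    BatchNorm folding rewrites ``(conv(x) - running_mean) /
    sqrt(running_var + eps) * weight + bias`` into the convolution's own
    weight and bias, so the BN node can be dropped at inference time.

    Example (a minimal sketch; the model is illustrative)::

        model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
        gm = fuse_conv_bn(torch.fx.symbolic_trace(model))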
"""
modules_patterns = [
(torch.nn.Conv1d, torch.nn.BatchNorm1d),
(torch.nn.Conv2d, torch.nn.BatchNorm2d),
(torch.nn.Conv3d, torch.nn.BatchNorm3d),
]
module_function_patterns = [
(torch.nn.Conv1d, F.batch_norm),
(torch.nn.Conv2d, F.batch_norm),
(torch.nn.Conv3d, F.batch_norm),
]
modules = dict(gm.named_modules())
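    # Pass 1: fuse nn.Conv{1,2,3}d modules followed directly by a matching
    # nn.BatchNorm{1,2,3}d module (both in eval mode).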
for pattern in modules_patterns:
for node in gm.graph.nodes:
if matches_module_pattern(pattern, node, modules):
if len(node.args[0].users) > 1: # Output of conv is used by other nodes
continue
conv = modules[node.args[0].target]
bn = modules[node.target]
eval_mode = all(not n.training for n in [conv, bn])
if not eval_mode:
continue
if not bn.track_running_stats:
continue
fused_conv = fuse_conv_bn_eval(conv, bn)
replace_node_module(node.args[0], modules, fused_conv)
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
gm.graph.lint()
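    # Pass 2: fuse nn.Conv{1,2,3}d modules followed by a functional
    # F.batch_norm call whose stats and affine parameters are constant
    # get_attr nodes.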
for pattern in module_function_patterns:
for node in gm.graph.nodes:
if matches_module_function_pattern(pattern, node, modules):
# TODO: support kwargs.
if len(node.args) != 8:
continue
conv = modules[node.args[0].target]
bn_training = node.args[5]
bn_eps = node.args[7]
if conv.training or bn_training:
continue
if type(bn_eps) is not float:
continue
bn_args_is_constant = all(
n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5]
)
if not bn_args_is_constant:
continue
bn_running_mean = fetch_attr(node.args[1].target, gm)
bn_running_var = fetch_attr(node.args[2].target, gm)
bn_weight = fetch_attr(node.args[3].target, gm)
bn_bias = fetch_attr(node.args[4].target, gm)
if bn_running_mean is None or bn_running_var is None:
continue
fused_conv = copy.deepcopy(conv)
fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
fused_conv.weight,
fused_conv.bias,
bn_running_mean,
bn_running_var,
bn_eps,
bn_weight,
bn_bias,
)
replace_node_module(node.args[0], modules, fused_conv)
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()
return gm
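# Helpers for a philox-seeded rand_like: the meta function propagates tensor
# metadata, while the eager version is only a placeholder used during tracing.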
def _philox_rand_like_meta(input, seed, offset):
return _prims.TensorMeta(input)
def _philox_rand_like(input, seed, offset):
# placeholder only used in tracing
return torch.rand_like(input)
class NormalizedLinearNode:
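    """
    Thin accessor around an ``F.linear`` call_function node that returns
    input/weight/bias regardless of whether they were passed positionally
    or as keyword arguments.
    """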
def __init__(self, node: torch.fx.Node) -> None:
assert node.op == "call_function"
assert node.target in [torch.nn.functional.linear]
self.node: torch.fx.Node = node
def get_input(self) -> torch.fx.Node:
if len(self.node.args) > 0:
return self.node.args[0]
else:
return self.node.kwargs["input"]
def get_weight(self) -> torch.fx.Node:
if len(self.node.args) > 1:
return self.node.args[1]
else:
return self.node.kwargs["weight"]
def get_bias(self) -> torch.fx.Node:
if len(self.node.args) > 2:
return self.node.args[2]
else:
return self.node.kwargs["bias"]
class NormalizedMatmulNode:
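    """
    Thin accessor around a ``torch.bmm``/``torch.matmul`` call_function node
    that returns input/other regardless of whether they were passed
    positionally or as keyword arguments.
    """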
def __init__(self, node: torch.fx.Node) -> None:
assert node.op == "call_function"
assert node.target in [torch.bmm, torch.matmul]
self.node: torch.fx.Node = node
def get_input(self) -> torch.fx.Node:
if len(self.node.args) > 0:
return self.node.args[0]
else:
return self.node.kwargs["input"]
def get_other(self) -> torch.fx.Node:
if len(self.node.args) > 1:
return self.node.args[1]
else:
return self.node.kwargs["other"]
def check_permute(node: torch.fx.Node):
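    """
    Return True if this permute node only swaps the last two dimensions
    (e.g. permutation ``(0, 2, 1)`` for a rank-3 tensor), i.e. it behaves
    like a transpose of the final two dims. Relies on shape metadata
    produced by ShapeProp.
    """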
ranks = len(node.meta["tensor_meta"].shape)
if len(node.args) > 3:
permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]
elif (
"permutation" in node.kwargs
and node.kwargs["permutation"] is not None
and len(node.kwargs["permutation"]) > 2
):
permutation = [i % ranks for i in node.kwargs["permutation"]]
else:
return False
allowed_permutation = list(range(ranks))
allowed_permutation[-1] = ranks - 2
allowed_permutation[-2] = ranks - 1
return permutation == allowed_permutation
def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
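    """
    Rewrite ``unary(torch.cat(xs).view(...))`` (for relu/tanh) into
    ``torch.cat([unary(x) for x in xs]).view(...)``, pushing the pointwise
    op onto the individual cat inputs instead of the cat output.
    """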
def one_user(node):
users = list(node.users)
return users[0] if len(users) == 1 else None
def is_view(node):
view = {"view"}
return node.op == "call_method" and node.target in view
def is_pointwise_unary(node):
pointwise = {torch.relu, torch.tanh, "relu", "tanh"}
return node.op in {"call_function", "call_method"} and node.target in pointwise
g = module.graph
for node in g.nodes:
if node.op != "call_function" or node.target != torch.cat:
continue
cat_or_view = node
while True:
user = one_user(cat_or_view)
if not user or not is_view(user):
break
cat_or_view = user
if user and is_pointwise_unary(user):
with g.inserting_before(node):
new_tensors = [
g.create_node(user.op, user.target, args=(arg,), kwargs=user.kwargs)
for arg in node.args[0]
]
node.args = (new_tensors,) + node.args[1:]
user.replace_all_uses_with(cat_or_view)
g.erase_node(user)
g.lint()
module.recompile()
return module
def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
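    """
    Replace ``F.linear(input, weight, bias).permute(...)`` (where the permute
    only swaps the last two dims) with a single call to ``linear_transpose``,
    avoiding the separate permute.
    """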
for node in module.graph.nodes:
if (
node.op == "call_method"
and node.target == "permute"
and check_permute(node)
):
if len(node.args) > 0:
input_node = node.args[0]
else:
input_node = node.kwargs["input"]
if (
input_node.op == "call_function"
and input_node.target == torch.nn.functional.linear
):
normalized = NormalizedLinearNode(input_node)
input = normalized.get_input()
weight = normalized.get_weight()
bias = normalized.get_bias()
with module.graph.inserting_before(node):
fused_node = module.graph.call_function(
linear_transpose, args=(input, weight, bias)
)
node.replace_all_uses_with(fused_node)
module.graph.erase_node(node)
if len(input_node.users) == 0:
module.graph.erase_node(input_node)
module.graph.lint()
module.recompile()
return module