import operator
from typing import Any, Callable, Dict, Tuple, Optional
import torch
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
normalize_module,
normalize_function,
create_type_hint,
)
from .schema_type_annotation import AnnotateTypesWithSchema


class NormalizeArgs(Transformer):
"""
Normalize arguments to Python targets. This means that
`args/kwargs` will be matched up to the module/functional's
signature and rewritten to exclusively kwargs in positional order
if `normalize_to_only_use_kwargs` is true. Also populates default
values. Does not support positional-only parameters or varargs
parameters (*args, **kwargs).
If the nodes have 'type' metadata, it will use it to disambiguate
overloads. Otherwise, it will throw an error.
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = NormalizeArgs(traced).transform()
"""

    def __init__(
self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
):
super().__init__(module)
self.node_map: Dict[Proxy, Node] = {}
self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)
        def get_type(arg):
            # For Node arguments, look up the type recorded in the *argument*
            # node's metadata (not the node currently being run); for
            # concrete values, fall back to their runtime type.
            if isinstance(arg, fx.Node):
                return arg.meta.get("type")
            return type(arg)
arg_types = map_aggregate(n.args, get_type)
assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(t) for t in arg_types)
kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
if n.op == "call_function":
out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
else:
out = super().run_node(n)
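        # For all nodes except the output, propagate the original node's
        # metadata and type annotation onto the node backing the new Proxy,
        # and record the Proxy -> original Node mapping.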
if n.op != "output":
self.node_map[out] = n
out.node.meta = n.meta
out.node.type = n.type
return out

    def call_function(
self,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
arg_types: Optional[Tuple[Any, ...]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
):
assert callable(target)
new_args_and_kwargs = normalize_function(
target,
args, # type: ignore[arg-type]
kwargs,
arg_types, # type: ignore[arg-type]
kwarg_types,
self.normalize_to_only_use_kwargs,
)
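        # normalize_function returns None if it could not match the target's
        # signature; in that case, emit the call unchanged.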
if new_args_and_kwargs:
new_args, new_kwargs = new_args_and_kwargs
return self.tracer.create_proxy(
"call_function", target, new_args, new_kwargs
)
else:
return super().call_function(target, args, kwargs)

    def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
assert isinstance(target, str)
new_args_and_kwargs = normalize_module(
self.module,
target,
args, # type: ignore[arg-type]
kwargs,
self.normalize_to_only_use_kwargs,
)
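        # normalize_module returns None if the submodule's signature could
        # not be matched; fall back to the original call in that case.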
if new_args_and_kwargs:
new_args, new_kwargs = new_args_and_kwargs
return super().call_module(target, new_args, new_kwargs)
else:
return super().call_module(target, args, kwargs)


class NormalizeOperators(AnnotateTypesWithSchema):
"""
Normalize callsites that are different ways of "spelling" the same
invocation into a single, canonical call. Currently supports:
1. Normalize operators (e.g. operator.add) to the `torch` ops they
ultimately invoke (e.g. torch.add) when it is possible to statically
reason that
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = NormalizeOperators(traced).transform()
"""
binary_magic_method_remap: Dict[
Callable[[Any, Any], Any], Callable[[Any, Any], Any]
] = {
torch.add: operator.add,
torch.mul: operator.mul,
torch.sub: operator.sub,
torch.div: operator.truediv,
torch.floor_divide: operator.floordiv,
torch.remainder: operator.mod,
torch.eq: operator.eq,
torch.ne: operator.ne,
torch.lt: operator.lt,
torch.le: operator.le,
torch.gt: operator.gt,
torch.ge: operator.ge,
}

    def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
# Normalize operators according to the magic methods implemented on tensors here:
# https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
assert callable(target)
        if target in self.binary_magic_method_remap:
            # Only rewrite the exact two-positional-argument, no-kwargs form;
            # anything else (e.g. `torch.add(x, y, alpha=2)`) has no
            # equivalent magic-method spelling.
            if len(args) != 2 or kwargs:
                return super().call_function(target, args, kwargs)
            lhs, rhs = args

            return super().call_function(
                target=self.binary_magic_method_remap[target],
                args=(lhs, rhs),
                kwargs={},
            )
return super().call_function(target, args, kwargs)
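

# A minimal, self-contained demo of both passes. This is an illustrative
# sketch: the `_ArgsToy` and `_OpsToy` modules below are hypothetical and not
# part of the library.
if __name__ == "__main__":
    import torch.nn.functional as F

    class _ArgsToy(torch.nn.Module):
        def forward(self, x):
            # Positional call; NormalizeArgs rewrites it to
            # relu(input=x, inplace=True), with defaults populated.
            return F.relu(x, True)

    class _OpsToy(torch.nn.Module):
        def forward(self, x):
            # Two positional args, no kwargs; NormalizeOperators rewrites the
            # call to operator.add, i.e. the `x + x` spelling.
            return torch.add(x, x)

    print(NormalizeArgs(fx.symbolic_trace(_ArgsToy())).transform().graph)
    print(NormalizeOperators(fx.symbolic_trace(_OpsToy())).transform().graph)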