from collections import defaultdict
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
import torch.utils._pytree as pytree
from . import _pytree as fx_pytree
from ._compatibility import compatibility
import contextlib
from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
from dataclasses import dataclass
from contextlib import contextmanager
import copy
import torch
import keyword
import re
import builtins
import math
import warnings
import inspect
# Names exported by `from <this module> import *`.
__all__ = ["PythonCode", "CodeGen", "Graph"]
if TYPE_CHECKING:
    # Imported only for static type checkers; these imports do not run at
    # runtime (presumably to avoid an import cycle -- confirm).
    from .graph_module import GraphModule # noqa: F401
    from ._symbolic_trace import Tracer # noqa: F401
# Mapping of builtins to their `typing` equivalent, e.g. ``list`` -> ``typing.List``.
_origin_type_map = {
    list: List,
    dict: Dict,
    set: Set,
    frozenset: FrozenSet,
    tuple: Tuple,
}
# Signature for functions that transform the body (`list[str]`) of the
# generated code: such a function takes the body as a list of strings and
# returns the transformed list of strings.
TransformCodeFunc = Callable[[List[str]], List[str]]
class _CustomBuiltin(NamedTuple):
    """An object injected into the globals of every generated FX graph.

    Some standard-library objects have a ``repr()`` that only becomes valid
    Python after an import (``inf``, ``nan``, ``device``, ...). The common ones
    are bundled into every FX graph's globals so the generated source runs
    without extra imports.
    """
    # Statement that brings ``obj`` into scope when the code stands alone.
    import_str: str
    # The object that ``import_str`` produces.
    obj: Any


# Registry of custom builtins, keyed by the global name generated code uses.
_custom_builtins: Dict[str, _CustomBuiltin] = {}


def _register_custom_builtin(name: str, import_str: str, obj: Any):
    """Record ``obj`` under ``name``, along with the import line that recreates it."""
    entry = _CustomBuiltin(import_str=import_str, obj=obj)
    _custom_builtins[name] = entry


# Registration order is preserved (the registry is an insertion-ordered dict).
for _name, _import_str, _obj in (
    ('inf', 'from math import inf', math.inf),
    ('nan', 'from math import nan', math.nan),
    ('NoneType', 'NoneType = type(None)', type(None)),
    ('torch', 'import torch', torch),
    ('device', 'from torch import device', torch.device),
    ('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree),
    ('pytree', 'import torch.utils._pytree as pytree', pytree),
):
    _register_custom_builtin(_name, _import_str, _obj)
# Do not leak the loop variables as module globals.
del _name, _import_str, _obj
def _is_magic(x: str) -> bool:
    """Return True when ``x`` both starts and ends with ``__`` (e.g. ``__add__``)."""
    return x[:2] == '__' and x[-2:] == '__'
def _snake_case(s: str) -> str:
    """
    Convert ``s`` into a Python-style (snake_case) variable name.

    Every character is lowercased, and an underscore is inserted wherever an
    uppercase character directly follows a lowercase one.

    Examples:
        ``mod.snake_case`` -> ``mod.snake_case``
        ``mod.pascalCase`` -> ``mod.pascal_case``
        ``mod.ALL_CAPS`` -> ``mod.all_caps``
    """
    pieces = []
    previous = ''
    for ch in s:
        if ch.isupper() and previous.islower():
            pieces.append('_')
        pieces.append(ch.lower())
        previous = ch
    return ''.join(pieces)
def _is_from_torch(obj: Any) -> bool:
    """Decide whether ``obj`` belongs to the ``torch`` package.

    If ``obj`` has a ``__module__``, only that is consulted. It must be
    ``torch`` or a ``torch.*`` submodule, excluding anything under
    ``torch._dynamo.`` or ``torch._inductor.``.

    Otherwise, ``obj`` counts as coming from torch if it is the very object
    found under its ``__name__`` in ``torch`` or ``torch.nn.functional``.
    """
    module_name = getattr(obj, '__module__', None)
    if module_name is not None:
        if module_name.partition('.')[0] != 'torch':
            return False
        return not module_name.startswith(("torch._dynamo.", "torch._inductor."))

    name = getattr(obj, '__name__', None)
    # ``torch.torch.torch...`` resolves, so a bare "torch" name would match
    # trivially; it is excluded on purpose.
    if name is None or name == 'torch':
        return False
    return any(getattr(ns, name, None) is obj for ns in (torch, torch.nn.functional))
class _Namespace:
    """Hands out unique Python identifiers and remembers which object owns each.

    Invariants maintained:
    - Each object gets a single name.
    - Each name is unique within a given namespace.
    - Generated names never shadow a Python builtin, or an entry of
      ``_custom_builtins``, unless the object being named *is* that builtin.
    """

    def __init__(self):
        # object -> the name it was given
        self._obj_to_name: Dict[Any, str] = {}
        # names created with ``obj=None``; associate_name_with_obj may bind them later
        self._unassociated_names: Set[str] = set()
        # every name handed out by this namespace
        self._used_names: Set[str] = set()
        # base name -> suffix number recorded for it by the last create_name call
        self._base_count: Dict[str, int] = defaultdict(int)
        # runs of characters that cannot appear in a Python identifier
        self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
        # splits ``base_123`` into ``base`` and ``123``
        self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")

    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
        """Return a unique, legal identifier derived from ``candidate``.

        Arguments:
            candidate: used as the basis for the unique name, relevant to the user.
            obj: If not None, an object that will be associated with the unique name.
                An object that already owns a name gets that same name back.
        """
        if obj is not None:
            known = self._obj_to_name.get(obj)
            if known is not None:
                return known

        # Replace illegal characters, and make sure the result is non-empty
        # and does not start with a digit.
        name = self._illegal_char_regex.sub('_', candidate) or '_unnamed'
        if name[0].isdigit():
            name = '_' + name

        # Split off an existing ``_<digits>`` suffix so numbering continues from
        # it. The suffix is re-rendered through int(), dropping leading zeros.
        suffix_match = self._name_suffix_regex.match(name)
        if suffix_match is not None:
            base = suffix_match.group(1)
            counter = int(suffix_match.group(2))
            name = f'{base}_{counter}'
        else:
            base = name
            counter = None

        # Without an explicit non-zero suffix, resume from this base's counter.
        if not counter:
            counter = self._base_count[base]

        while name in self._used_names or self._is_illegal_name(name, obj):
            counter += 1
            name = f'{base}_{counter}'

        self._used_names.add(name)
        self._base_count[base] = counter
        if obj is None:
            self._unassociated_names.add(name)
        else:
            self._obj_to_name[obj] = name
        return name

    def associate_name_with_obj(self, name: str, obj: Any):
        """Bind ``name``, previously created with ``obj=None``, to ``obj``.

        Neither ``name`` nor ``obj`` may already be associated; both
        preconditions are checked with ``assert``.
        """
        assert obj not in self._obj_to_name
        assert name in self._unassociated_names
        self._obj_to_name[obj] = name
        self._unassociated_names.remove(name)

    def _is_illegal_name(self, name: str, obj: Any) -> bool:
        """Whether ``name`` may not be used as the identifier for ``obj``."""
        # Keywords can never be identifiers.
        if keyword.iskeyword(name):
            return True
        # A builtin's name is only acceptable for that builtin itself.
        if name in builtins.__dict__:
            return builtins.__dict__[name] is not obj
        # The same rule applies to the extra globals every FX graph carries.
        if name in _custom_builtins:
            return _custom_builtins[name].obj is not obj
        return False
# Short spellings for torch dtypes (e.g. ``torch.float32`` -> ``'f32'``).
# NOTE(review): the consumer is not visible in this section; presumably this
# table is used when printing tensor metadata in verbose output -- confirm.
dtype_abbrs = {
    torch.bfloat16: 'bf16',
    torch.float64: 'f64',
    torch.float32: 'f32',
    torch.float16: 'f16',
    torch.complex32: 'c32',
    torch.complex64: 'c64',
    torch.complex128: 'c128',
    torch.int8: 'i8',
    torch.int16: 'i16',
    torch.int32: 'i32',
    torch.int64: 'i64',
    torch.bool: 'b8',
    torch.uint8: 'u8',
}
@compatibility(is_backward_compatible=True)
@dataclass
class PythonCode:
    """
    Represents all the information necessary to exec or save a graph as Python code.
    """
    # Python source code for the forward function definition.
    src: str
    # Values in global scope during execution of `src`.
    globals: Dict[str, Any]
def _format_target(base: str, target: str) -> str:
    """Render a dotted attribute path as a Python expression rooted at ``base``.

    Each ``.``-separated component of ``target`` becomes a plain attribute
    access (``base.attr``) when it is a valid identifier that is not a
    reserved keyword. Otherwise it becomes a ``getattr(base, "attr")`` call:
    numeric submodule names such as ``0``, or keyword names such as ``in``,
    would be a syntax error after a dot.

    Example: ``_format_target('self', 'layers.0.in')`` ->
    ``'getattr(getattr(self.layers, "0"), "in")'``
    """
    r = base
    for e in target.split('.'):
        if e.isidentifier() and not keyword.iskeyword(e):
            r = f'{r}.{e}'
        else:
            # Keep the historical double-quoted literal when it is valid as-is;
            # otherwise let repr() produce a correctly escaped string literal.
            if e.isprintable() and '"' not in e and '\\' not in e:
                literal = f'"{e}"'
            else:
                literal = repr(e)
            r = f'getattr({r}, {literal})'
    return r
class _InsertPoint:
    """Context manager that swaps ``graph._insert`` and restores it on exit.

    The swap happens as soon as the object is constructed, not in
    ``__enter__``. ``__exit__`` puts the previous value back and does not
    suppress exceptions.
    """

    def __init__(self, graph, new_insert):
        self.graph = graph
        # Remember the current value before overwriting it.
        self.orig_insert = graph._insert
        graph._insert = new_insert

    def __enter__(self):
        return None

    def __exit__(self, type, value, tb):
        self.graph._insert = self.orig_insert
class _node_list:
    """Iterable view over a graph's circular, doubly-linked list of nodes.

    Iteration starts at the neighbour of ``graph._root`` in the chosen
    direction (``_next`` or ``_prev``) and stops on returning to the root.
    Nodes flagged ``_erased`` are skipped. ``len()`` reports ``graph._len``.
    """

    def __init__(self, graph: 'Graph', direction: str = '_next'):
        assert direction in ['_next', '_prev']
        self.graph = graph
        self.direction = direction

    def __len__(self):
        return self.graph._len

    def __iter__(self):
        sentinel = self.graph._root
        link = self.direction
        node = getattr(sentinel, link)
        while node is not sentinel:
            if not node._erased:
                yield node
            # Advance only after yielding, reading the link at that moment.
            node = getattr(node, link)

    def __reversed__(self):
        if self.direction == '_prev':
            return _node_list(self.graph, '_next')
        return _node_list(self.graph, '_prev')
class _PyTreeInfo(NamedTuple):
    """
    Contains extra info stored when we're using Pytrees
    """
    # Argument names of the original function (presumably its user-facing
    # signature before pytree flattening -- confirm against callers).
    orig_args: List[str]
    # Pytree spec for the inputs (presumably the structure used to flatten them).
    in_spec: pytree.TreeSpec
    # Pytree spec for the outputs; may be None.
    out_spec: Optional[pytree.TreeSpec]
@compatibility(is_backward_compatible=False)
class CodeGen:
def __init__(self):
self._body_transformer: Optional[TransformCodeFunc] = None
def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
    """
    Given the free variables and a return annotation, generates the beginning of the FX function.

    The generated signature always starts with ``self``, which is prepended
    when ``free_vars`` does not already begin with it. The caller's
    ``free_vars`` list is not modified.

    By default, `gen_fn_def(['a', 'b'], '') == 'def forward(self, a, b):'`
    """
    # Work on a copy so the caller's list is never mutated as a side effect.
    params = list(free_vars)
    # If the original function didn't have self as its first argument, add it.
    if not params or params[0] != 'self':
        params.insert(0, 'self')
    return f"def forward({', '.join(params)}){maybe_return_annotation}:"
def generate_output(self, output_args: Argument) -> str:
"""
Given the output arguments, generates the return statement of the FX function.
Note: The returned statement should not be indented.
"""
return f'return {repr(output_args)}'
def process_inputs(self, *args: Any) -> Any:
    """
    Map the caller's inputs onto the arguments the graph itself takes.

    A non-default codegen may give the generated function a different
    signature than the graph, so subclasses can override this to translate
    between the two. The base implementation does no translation; it returns
    the positional arguments as the tuple it received.

    If the graph was directly runnable, this invariant should hold true
    `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`
    """
    return args
def process_outputs(self, outputs: Any) -> Any:
    """
    Convert the graph's raw outputs into what the generated code returns.

    This is the counterpart of ``process_inputs``. The base implementation
    returns ``outputs`` unchanged (the very same object).
    """
    return outputs
def additional_globals(self) -> List[Tuple[str, Any]]:
    """
    Extra ``(identifier, value)`` pairs that the generated code needs in its globals.

    Override this when a custom codegen references names that are not
    otherwise available. For example, return ``[('List', typing.List)]`` if
    the generated code uses ``List``. The base implementation needs none and
    returns a new empty list on each call.
    """
    return list()
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation : List[str] = ['']
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
return _get_qualified_name(obj)
Loading ...