Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ fx / graph.py

from collections import defaultdict
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
import torch.utils._pytree as pytree
from . import _pytree as fx_pytree
from ._compatibility import compatibility

import contextlib
from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
from dataclasses import dataclass
from contextlib import contextmanager
import copy
import torch
import keyword
import re
import builtins
import math
import warnings
import inspect

__all__ = ["PythonCode", "CodeGen", "Graph"]

# These names are only needed for type annotations, so they are imported
# under TYPE_CHECKING (presumably to avoid a runtime import cycle with
# graph_module / _symbolic_trace — confirm before moving them).
if TYPE_CHECKING:
    from .graph_module import GraphModule  # noqa: F401
    from ._symbolic_trace import Tracer   # noqa: F401

# Mapping of builtins to their `typing` equivalent.
_origin_type_map = {
    list: List,
    dict: Dict,
    set: Set,
    frozenset: FrozenSet,
    tuple: Tuple,

# Signature for functions that transform the body of the generated code:
# they receive the body as a list of source-line strings (`list[str]`) and
# return the (possibly rewritten) list of lines.
TransformCodeFunc = Callable[[List[str]], List[str]]

class _CustomBuiltin(NamedTuple):
    """Additional objs that we add to every graph's globals.

    The repr() for some standard library objects is not valid Python code without
    an import. For common objects of this sort, we bundle them in the globals of
    every FX graph.
    # How to import this object from the standard library.
    import_str: str
    # The actual object, produced from that import string.
    obj: Any

# Registry of extra objects placed in every graph's globals, keyed by the
# global name each object is bound to.
_custom_builtins: Dict[str, _CustomBuiltin] = {}

def _register_custom_builtin(name: str, import_str: str, obj: Any):
    """Record ``obj`` under ``name`` along with the statement that imports it.

    Registering a ``name`` that already exists replaces the earlier entry.
    """
    entry = _CustomBuiltin(import_str=import_str, obj=obj)
    _custom_builtins[name] = entry

# Default custom builtins: (global name, import statement, object).
for _cb_name, _cb_import, _cb_obj in (
    ('inf', 'from math import inf', math.inf),
    ('nan', 'from math import nan', math.nan),
    ('NoneType', 'NoneType = type(None)', type(None)),
    ('torch', 'import torch', torch),
    ('device', 'from torch import device', torch.device),
    ('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree),
    ('pytree', 'import torch.utils._pytree as pytree', pytree),
):
    _register_custom_builtin(_cb_name, _cb_import, _cb_obj)
# Don't leak the loop variables into the module namespace.
del _cb_name, _cb_import, _cb_obj

def _is_magic(x: str) -> bool:
    return x.startswith('__') and x.endswith('__')

def _snake_case(s: str) -> str:
    Transforms the given string ``s`` to a Python-style variable name

        ``mod.snake_case`` -> ``mod.snake_case``
        ``mod.pascalCase``-> ``mod.pascal_case``
        ``mod.ALL_CAPS`` -> ``mod.all_caps``
    chars = []
    prev_lower = False
    for c in s:
        if prev_lower and c.isupper():
        prev_lower = c.islower()
    return ''.join(chars)

def _is_from_torch(obj: Any) -> bool:
    module_name = getattr(obj, '__module__', None)
    if module_name is not None:
        base_module = module_name.partition('.')[0]
        return (
            base_module == 'torch' and
            not module_name.startswith("torch._dynamo.") and
            not module_name.startswith("torch._inductor.")

    name = getattr(obj, '__name__', None)
    # exclude torch because torch.torch.torch.torch works. idk mang
    if name is not None and name != 'torch':
        for guess in [torch, torch.nn.functional]:
            if getattr(guess, name, None) is obj:
                return True

    return False

class _Namespace:
    """A context for associating names uniquely with objects.

    The following invariants are enforced:
    - Each object gets a single name.
    - Each name is unique within a given namespace.
    - Names generated do not shadow builtins, unless the object is indeed that builtin.
    """
    def __init__(self):
        # object -> the unique name handed out for it
        self._obj_to_name: Dict[Any, str] = {}
        # names created with ``obj=None``; they may later be bound to an
        # object via ``associate_name_with_obj``
        self._unassociated_names: Set[str] = set()
        # every name handed out so far; used to guarantee uniqueness
        self._used_names: Set[str] = set()
        # base name -> last numeric suffix used, so repeated requests for the
        # same base resume counting instead of rescanning from zero
        self._base_count: Dict[str, int] = defaultdict(int)

        self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
        self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")

    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
        """Create a unique name.

        Arguments:
            candidate: used as the basis for the unique name, relevant to the user.
            obj: If not None, an object that will be associated with the unique name.

        Returns:
            A Python identifier not previously handed out by this namespace, or
            the name already associated with ``obj`` if there is one.
        """
        if obj is not None and obj in self._obj_to_name:
            return self._obj_to_name[obj]

        # delete all characters that are illegal in a Python identifier
        candidate = self._illegal_char_regex.sub('_', candidate)

        if not candidate:
            candidate = '_unnamed'

        # identifiers may not start with a digit
        if candidate[0].isdigit():
            candidate = f'_{candidate}'

        # split off an existing numeric suffix, e.g. "x_3" -> ("x", 3)
        match = self._name_suffix_regex.match(candidate)
        if match is None:
            base = candidate
            num = None
        else:
            base, num_str = match.group(1, 2)
            num = int(num_str)

        candidate = base if num is None else f'{base}_{num}'
        if not num:
            num = self._base_count[base]

        while candidate in self._used_names or self._is_illegal_name(candidate, obj):
            num += 1
            candidate = f'{base}_{num}'

        # record the name so later requests cannot reuse it
        self._used_names.add(candidate)
        self._base_count[base] = num
        if obj is None:
            self._unassociated_names.add(candidate)
        else:
            self._obj_to_name[obj] = candidate
        return candidate

    def associate_name_with_obj(self, name: str, obj: Any):
        """Associate a unique name with an object.

        Neither `name` nor `obj` should be associated already.
        """
        assert obj not in self._obj_to_name
        assert name in self._unassociated_names
        self._obj_to_name[obj] = name
        # the name is now bound, so it cannot be associated a second time
        self._unassociated_names.remove(name)

    def _is_illegal_name(self, name: str, obj: Any) -> bool:
        """Return True if ``name`` may not be used as the name for ``obj``."""
        # 1. keywords are never allowed as names.
        if name in keyword.kwlist:
            return True

        # 2. Can't shadow a builtin name, unless you *are* that builtin.
        if name in builtins.__dict__:
            return obj is not builtins.__dict__[name]

        # 3. Can't shadow our custom builtins either
        if name in _custom_builtins:
            return obj is not _custom_builtins[name].obj

        return False

# Compact abbreviations for torch dtypes.
dtype_abbrs = {
    torch.bfloat16: 'bf16',
    torch.float64: 'f64',
    torch.float32: 'f32',
    torch.float16: 'f16',
    torch.complex32: 'c32',
    torch.complex64: 'c64',
    torch.complex128: 'c128',
    torch.int8: 'i8',
    torch.int16: 'i16',
    torch.int32: 'i32',
    torch.int64: 'i64',
    torch.bool: 'b8',
    torch.uint8: 'u8',
}

@dataclass
class PythonCode:
    """
    Represents all the information necessary to exec or save a graph as Python code.
    """
    # Python source code for the forward function definition.
    src: str
    # Values in global scope during execution of `src`.
    globals: Dict[str, Any]

def _format_target(base: str, target: str) -> str:
    elems = target.split('.')
    r = base
    for e in elems:
        if not e.isidentifier():
            r = f'getattr({r}, "{e}")'
            r = f'{r}.{e}'
    return r

class _InsertPoint:
    def __init__(self, graph, new_insert):
        self.graph = graph
        self.orig_insert, graph._insert = graph._insert, new_insert

    def __enter__(self):

    def __exit__(self, type, value, tb):
        self.graph._insert = self.orig_insert

class _node_list:
    def __init__(self, graph: 'Graph', direction: str = '_next'):
        assert direction in ['_next', '_prev']
        self.graph = graph
        self.direction = direction

    def __len__(self):
        return self.graph._len

    def __iter__(self):
        root, direction = self.graph._root, self.direction
        cur = getattr(root, direction)
        while cur is not root:
            if not cur._erased:
                yield cur
            cur = getattr(cur, direction)

    def __reversed__(self):
        return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')

class _PyTreeInfo(NamedTuple):
    Contains extra info stored when we're using Pytrees
    orig_args: List[str]
    in_spec: pytree.TreeSpec
    out_spec: Optional[pytree.TreeSpec]

class CodeGen:
    def __init__(self):
        self._body_transformer: Optional[TransformCodeFunc] = None

    def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `gen_fn_def(['a', 'b'], '') == 'def forward(a, b):'`
        # If the original function didn't have self as its first argument, we
        # would have added it.
        if len(free_vars) == 0 or free_vars[0] != 'self':
            free_vars.insert(0, 'self')
        return f"def forward({', '.join(free_vars)}){maybe_return_annotation}:"

    def generate_output(self, output_args: Argument) -> str:
        Given the output arguments, generates the return statement of the FX function.
        Note: The returned statement should not be indented.
        return f'return {repr(output_args)}'

    def process_inputs(self, *args: Any) -> Any:
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true
        `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`
        return args

    def process_outputs(self, outputs: Any) -> Any:
        Transforms the outputs of the graph to be identical to the codegen.

        See ``process_inputs`` for more details.
        return outputs

    def additional_globals(self) -> List[Tuple[str, Any]]:
        If your codegen uses extra global values, add tuples of (identifier,reference to the value) here.
        For example, return ['List', typing.List] if you need ``List`` in the global context.
        return []

    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode:
        free_vars: List[str] = []
        body: List[str] = []
        globals_: Dict[str, Any] = {}
        wrapped_fns: Dict[str, None] = {}

        # Wrap string in list to pass by reference
        maybe_return_annotation : List[str] = ['']

        def add_global(name_hint: str, obj: Any):
            """Add an obj to be tracked as a global.

            We call this for names that reference objects external to the
            Graph, like functions or types.

            Returns: the global name that should be used to reference 'obj' in generated source.
            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                # HACK: workaround for how torch custom ops are registered. We
                # can't import them like normal modules so they must retain their
                # fully qualified name.
                return _get_qualified_name(obj)
Loading ...