Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ _dynamo / variables / builder.py

import collections
import dataclasses
import enum
import functools
import inspect
import operator
import re
import types
from typing import Any, Optional, Union

import torch

from torch import SymInt
from torch._guards import GuardSource
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.immutable_collections import immutable_list

from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..side_effects import SideEffects
from ..source import (
    AttrSource,
    ConstantSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    is_constant_source,
    LocalInputSource,
    LocalSource,
    RandomValueSource,
    Source,
    TupleIteratorGetItemSource,
)
from ..utils import (
    clone_input,
    get_fake_value,
    getfile,
    global_key_name,
    HAS_NUMPY,
    is_namedtuple,
    is_numpy_int_type,
    is_typing,
    istensor,
    istype,
    np,
    odict_values,
    preserve_rng_state,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    wrap_fake_exception,
)

from .base import MutableLocal, typestr
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .dicts import (
    ConstDictVariable,
    DataClassVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
)
from .functions import UserFunctionVariable
from .lists import (
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    SizeVariable,
    SliceVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    GetAttrVariable,
    InspectSignatureVariable,
    LambdaVariable,
    NumpyVariable,
    PythonModuleVariable,
    SkipFilesVariable,
    TypingVariable,
)
from .nn_module import UnspecializedNNModuleVariable
from .tensor import (
    SymNodeVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
)
from .torch import (
    tensor_dunder_fns,
    torch_special_class_types,
    TorchPyOperator,
    TorchVariable,
)
from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable


class _missing:
    """Private marker type; it defines no attributes or behavior of its own."""


@dataclasses.dataclass
class GraphArg:
    """One example input recorded for the captured FX graph.

    Pairs the ``source`` used to reload the value (see ``load``) with the real
    ``example`` value and, when that example is a ``torch.Tensor``, the
    corresponding ``FakeTensor`` in ``fake_tensor``.
    """

    # Reconstructs this argument in ``load`` via ``source.reconstruct(tx)``.
    source: Source
    # The real (never fake) example value; ``erase`` resets it to None.
    example: Any
    is_unspecialized: bool
    # Must be a FakeTensor whenever ``example`` is a torch.Tensor (checked in
    # __post_init__); may be None otherwise.
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True

    def __post_init__(self):
        """Validate the example/fake pairing and record the graph-arg position.

        Raises:
            AssertionError: if ``example`` is itself a FakeTensor, or if a
                tensor ``example`` is not paired with a FakeTensor.
        """
        # Reject fake examples before doing anything else. FakeTensor is a
        # torch.Tensor subclass, so checking it later (as the tensor branch
        # below would) could first mutate ``fake_tensor.__dict__`` and then
        # fail, or fail on a message-less assert instead of this clear error.
        if isinstance(self.example, torch._subclasses.fake_tensor.FakeTensor):
            raise AssertionError("Fake Tensor observed in TorchDynamo Fx graph inputs")
        if isinstance(self.example, torch.Tensor):
            assert isinstance(
                self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
            )
            # Mapping for downstream systems to remap back into dynamo arg positions
            if isinstance(self.source, LocalInputSource):
                self.fake_tensor.__dict__.setdefault("graph_arg_pos", []).append(
                    self.source.pos
                )

    def load(self, tx):
        """Delegate to ``self.source.reconstruct(tx)`` and return its result."""
        return self.source.reconstruct(tx)

    def get_examples(self):
        """Return the real example value wrapped in a one-element list."""
        return [self.example]

    def get_fake_examples(self):
        """Return ``[fake_tensor]``, or None when no fake tensor is attached."""
        if self.fake_tensor is None:
            return None
        assert isinstance(self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor)
        return [self.fake_tensor]

    def __len__(self):
        # Each GraphArg contributes exactly one example (see get_examples).
        return 1

    def erase(self):
        """Drop this object's reference to the real example value."""
        self.example = None


class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

    def __init__(self, tx, source: Source):
        """Bind this builder to the translator ``tx`` and a non-None ``source``.

        The source's ``name()`` is computed once and stored as ``self.name``.
        """
        assert source is not None
        super().__init__()
        self.tx, self.source = tx, source
        self.name = source.name()

    def __call__(self, value):
        """Return the VariableTracker for ``value``.

        A value already recorded in the output's side effects gets its existing
        tracker back. Any other value is built by ``_wrap`` and then cloned with
        this builder's ``options()`` (i.e. its source).
        """
        tracked = self.tx.output.side_effects
        if value not in tracked:
            return self._wrap(value).clone(**self.options())
        # TODO(jansel): add guard for alias relationship
        return tracked[value]

    @staticmethod
    @functools.lru_cache(None)
    def _common_constants():
        """Return the fixed set of "common" numeric constants.

        The set holds every integer in ``range(17)`` plus the extra ints and
        floats listed below. Because of ``lru_cache``, every call returns the
        same set object, so mutating it would affect all later callers.
        """
        extra_ints = {
            20, 30, 32, 40, 64, 96, 128, 144, 240, 256, 672, 800,
            1024, 2048, 4096,
        }
        extra_floats = {
            0.1,
            0.01,
            0.001,
            0.5,
            0.05,
            1.873536229133606,
            # Work around for vision_maskrcnn where torch.clamp can't be on different devices
            4.135166556742356,
        }
        return set(range(17)) | extra_ints | extra_floats

    @staticmethod
    def list_type(value):
        """Return the VariableTracker constructor for the sequence ``value``.

        * namedtuple -> ``NamedTupleVariable`` partially applied with its class
        * exactly ``tuple`` -> ``TupleVariable``
        * exactly ``list``, ``odict_values``, ``torch.nn.ParameterList`` or
          ``torch.nn.ModuleList`` -> ``ListVariable``

        Raises:
            KeyError: for any other type (exact type match, no subclasses).
        """
        if is_namedtuple(value):
            return functools.partial(NamedTupleVariable, tuple_cls=type(value))
        value_type = type(value)
        if value_type is tuple:
            return TupleVariable
        list_like_types = (
            list,
            odict_values,
            torch.nn.ParameterList,
            torch.nn.ModuleList,
        )
        if value_type in list_like_types:
            return ListVariable
        raise KeyError(value_type)

    def get_source(self):
        """Return the ``Source`` this builder was constructed with."""
        return self.source

    def options(self):
        """Return the keyword arguments applied when cloning wrapped trackers.

        Currently this is only ``source``, taken from ``get_source()``.
        """
        return dict(source=self.get_source())

    def make_guards(self, *guards):
        """Build guards of the requested kinds on this builder's source.

        Returns None when the source is a ``ConstantSource`` or reports
        ``GuardSource.CONSTANT`` from ``guard_source()``. Otherwise returns a set
        containing ``source.make_guard(g)`` for each ``g`` in ``guards``.
        """
        source = self.get_source()
        # ``or`` short-circuits, so guard_source() is only queried for
        # non-ConstantSource sources.
        is_constant = isinstance(source, ConstantSource) or (
            source.guard_source() == GuardSource.CONSTANT
        )
        if is_constant:
            return None
        return set(map(source.make_guard, guards))

    def _wrap(self, value):
        from ..comptime import comptime

        make_guards = self.make_guards
        if istype(value, (torch.SymInt, torch.SymFloat)):
            return self.wrap_sym(value)
        if istensor(value):
            return self.wrap_tensor(value)
        elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value):
            # One can index a tensor with a list/tuple. Therefore, we need to
            # have a stricter match.
            if istype(value, (tuple, list)) and all(
                [isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value]
            ):
                guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
            else:
                guards = self.make_guards(GuardBuilder.LIST_LENGTH)
            output = [
                VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
                    item
                ).add_guards(guards)
                for i, item in enumerate(value)
            ]
            result = self.list_type(value)(output, guards=guards)
            if istype(value, list):
                return self.tx.output.side_effects.track_list(
                    self.source, value, result
                )
            return result
        elif istype(value, tuple_iterator):
            guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
            output = [
                VariableBuilder(
                    self.tx, TupleIteratorGetItemSource(self.get_source(), i)
                )(tuple_iterator_getitem(value, i)).add_guards(guards)
                for i in range(tuple_iterator_len(value))
            ]
            return TupleIteratorVariable(
                output, mutable_local=MutableLocal(), guards=guards
            )
        elif istype(value, (slice, range)):
            items = [
                VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
                    getattr(value, k)
                )
                for k in ("start", "stop", "step")
            ]
            if isinstance(value, slice):
                return SliceVariable(items, guards=make_guards(GuardBuilder.TYPE_MATCH))
            else:
                return RangeVariable(
                    items, guards=make_guards(GuardBuilder.EQUALS_MATCH)
                )
        elif istype(
            value, (dict, collections.defaultdict, collections.OrderedDict)
        ) and all(
            map(
                lambda k: ConstantVariable.is_literal(k)
                or self.tensor_can_be_dict_key(k)
                or isinstance(k, enum.Enum),
                value.keys(),
            )
        ):
            guards = self.make_guards(GuardBuilder.DICT_KEYS)

            # store key variables in global location for reconstruction
            for key in value.keys():
                if self.tensor_can_be_dict_key(key):
                    self.tx.store_dict_key(global_key_name(key), key)

            def index_source(key):
                if self.tensor_can_be_dict_key(key):
                    return GlobalWeakRefSource(global_key_name(key))
                else:
                    return key

            result = {
                k: VariableBuilder(
                    self.tx, GetItemSource(self.get_source(), index_source(k))
                )(value[k]).add_guards(guards)
                for k in value.keys()
            }

            if istype(value, collections.defaultdict):
                result = DefaultDictVariable(
                    result, type(value), value.default_factory, guards=guards
                )
            else:
                result = ConstDictVariable(result, type(value), guards=guards)

            return self.tx.output.side_effects.track_dict(self.source, value, result)
        elif isinstance(value, torch.nn.Module):
            if (
                isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
                and not config.allow_rnn
            ):
                unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
            if mutation_guard.is_dynamic_nn_module(value):
                # created dynamically, don't specialize on it
                result = UnspecializedNNModuleVariable(
                    value, guards=make_guards(GuardBuilder.TYPE_MATCH)
                )
                if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                    # don't allow STORE_ATTR mutation with custom __setattr__
                    return result
                return self.tx.output.side_effects.track_object_existing(
                    self.source, value, result
                )
            elif getattr(value, "_is_fsdp_managed_module", False) or issubclass(
                value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
            ):
                if getattr(value, "_is_fsdp_managed_module", False):
                    # Note: we can't do this assert inside FSDP constructor,
                    # since we don't know yet whether dynamo will be used
                    assert getattr(
Loading ...