import collections
import dataclasses
import enum
import functools
import inspect
import operator
import re
import types
from typing import Any, Optional, Union
import torch
from torch import SymInt
from torch._guards import GuardSource
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.immutable_collections import immutable_list
from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..side_effects import SideEffects
from ..source import (
AttrSource,
ConstantSource,
GetItemSource,
GlobalSource,
GlobalWeakRefSource,
is_constant_source,
LocalInputSource,
LocalSource,
RandomValueSource,
Source,
TupleIteratorGetItemSource,
)
from ..utils import (
clone_input,
get_fake_value,
getfile,
global_key_name,
HAS_NUMPY,
is_namedtuple,
is_numpy_int_type,
is_typing,
istensor,
istype,
np,
odict_values,
preserve_rng_state,
tuple_iterator,
tuple_iterator_getitem,
tuple_iterator_len,
wrap_fake_exception,
)
from .base import MutableLocal, typestr
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .dicts import (
ConstDictVariable,
DataClassVariable,
DefaultDictVariable,
HFPretrainedConfigVariable,
)
from .functions import UserFunctionVariable
from .lists import (
ListVariable,
NamedTupleVariable,
RangeVariable,
SizeVariable,
SliceVariable,
TupleIteratorVariable,
TupleVariable,
)
from .misc import (
AutogradFunctionContextVariable,
AutogradFunctionVariable,
ComptimeVariable,
GetAttrVariable,
InspectSignatureVariable,
LambdaVariable,
NumpyVariable,
PythonModuleVariable,
SkipFilesVariable,
TypingVariable,
)
from .nn_module import UnspecializedNNModuleVariable
from .tensor import (
SymNodeVariable,
TensorVariable,
TensorWithTFOverrideVariable,
UnspecializedPythonVariable,
)
from .torch import (
tensor_dunder_fns,
torch_special_class_types,
TorchPyOperator,
TorchVariable,
)
from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
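# Sentinel class: lets callers distinguish "no value supplied" from an explicit None.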
class _missing:
pass
@dataclasses.dataclass
class GraphArg:
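    """A single example input to the FX graph being built, paired with the
    Source it was derived from (used for guard generation and reconstruction)."""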
source: Source
example: Any
is_unspecialized: bool
fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # An UnspecializedPythonVariable often masquerades as a tensor, but we
    # must NOT generate shape-guard code that accesses tensor properties on
    # such values. `is_tensor` records whether this graph arg really is a
    # tensor.
    is_tensor: bool = True
def __post_init__(self):
if isinstance(self.example, torch.Tensor):
assert isinstance(
self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
)
            # Record this input's position(s) so downstream systems can map the
            # fake tensor back to its Dynamo argument position.
if isinstance(self.source, LocalInputSource):
if "graph_arg_pos" not in self.fake_tensor.__dict__:
self.fake_tensor.__dict__["graph_arg_pos"] = []
self.fake_tensor.__dict__["graph_arg_pos"].append(self.source.pos)
if isinstance(self.example, torch._subclasses.fake_tensor.FakeTensor):
raise AssertionError("Fake Tensor observed in TorchDynamo Fx graph inputs")
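    # Emit the bytecode needed to reload this value from its original source.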
def load(self, tx):
return self.source.reconstruct(tx)
def get_examples(self):
return [self.example]
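    # Note: implicitly returns None when `fake_tensor` is unset (non-tensor args).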
def get_fake_examples(self):
if self.fake_tensor is not None:
assert isinstance(
self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
)
return [self.fake_tensor]
def __len__(self):
return 1
def erase(self):
self.example = None
class VariableBuilder:
"""Wrap a python value in a VariableTracker() instance"""
def __init__(
self,
tx,
source: Source,
):
assert source is not None
super().__init__()
self.tx = tx
self.source = source
self.name = source.name()
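    # If this object was already wrapped earlier in the frame, reuse the tracked
    # VariableTracker so aliasing relationships are preserved.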
def __call__(self, value):
if value in self.tx.output.side_effects:
# TODO(jansel): add guard for alias relationship
return self.tx.output.side_effects[value]
return self._wrap(value).clone(**self.options())
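    # Numeric literals that appear frequently in real model code; used when
    # deciding whether an int/float should be specialized as a constant rather
    # than treated as a dynamic input.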
@staticmethod
@functools.lru_cache(None)
def _common_constants():
return set(range(17)).union(
{
20,
30,
40,
32,
64,
96,
128,
144,
240,
256,
672,
1024,
2048,
4096,
0.1,
0.01,
0.001,
0.5,
0.05,
800,
1.873536229133606,
                4.135166556742356,  # workaround for vision_maskrcnn, where torch.clamp args can't be on different devices
}
)
@staticmethod
def list_type(value):
if is_namedtuple(value):
return functools.partial(NamedTupleVariable, tuple_cls=type(value))
return {
tuple: TupleVariable,
list: ListVariable,
odict_values: ListVariable,
torch.nn.ParameterList: ListVariable,
torch.nn.ModuleList: ListVariable,
}[type(value)]
def get_source(self):
return self.source
def options(self):
return {"source": self.get_source()}
def make_guards(self, *guards):
source = self.get_source()
if (
isinstance(source, ConstantSource)
or source.guard_source() == GuardSource.CONSTANT
):
return None
return {source.make_guard(guard) for guard in guards}
def _wrap(self, value):
from ..comptime import comptime
make_guards = self.make_guards
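        # Dispatch on the runtime type of `value`; symbolic scalars and tensors
        # are checked before generic containers.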
if istype(value, (torch.SymInt, torch.SymFloat)):
return self.wrap_sym(value)
if istensor(value):
return self.wrap_tensor(value)
        elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value):
            # A tensor can be indexed with a list/tuple of ints, so such
            # sequences need a stricter value-equality match, not just a
            # length guard.
            if istype(value, (tuple, list)) and all(
                isinstance(x, int) or is_numpy_int_type(x) or x is None
                for x in value
            ):
guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
else:
guards = self.make_guards(GuardBuilder.LIST_LENGTH)
output = [
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
item
).add_guards(guards)
for i, item in enumerate(value)
]
result = self.list_type(value)(output, guards=guards)
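            # Lists are mutable, so register them with side_effects to track
            # in-place mutation during tracing.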
if istype(value, list):
return self.tx.output.side_effects.track_list(
self.source, value, result
)
return result
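        # Tuple iterators are materialized eagerly: guard on the remaining
        # length and wrap each not-yet-consumed item.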
elif istype(value, tuple_iterator):
guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
output = [
VariableBuilder(
self.tx, TupleIteratorGetItemSource(self.get_source(), i)
)(tuple_iterator_getitem(value, i)).add_guards(guards)
for i in range(tuple_iterator_len(value))
]
return TupleIteratorVariable(
output, mutable_local=MutableLocal(), guards=guards
)
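        # slice/range: wrap start/stop/step individually so each component gets
        # its own source and guards.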
elif istype(value, (slice, range)):
items = [
VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
getattr(value, k)
)
for k in ("start", "stop", "step")
]
if isinstance(value, slice):
return SliceVariable(items, guards=make_guards(GuardBuilder.TYPE_MATCH))
else:
return RangeVariable(
items, guards=make_guards(GuardBuilder.EQUALS_MATCH)
)
        elif istype(
            value, (dict, collections.defaultdict, collections.OrderedDict)
        ) and all(
            ConstantVariable.is_literal(k)
            or self.tensor_can_be_dict_key(k)
            or isinstance(k, enum.Enum)
            for k in value.keys()
        ):
guards = self.make_guards(GuardBuilder.DICT_KEYS)
# store key variables in global location for reconstruction
for key in value.keys():
if self.tensor_can_be_dict_key(key):
self.tx.store_dict_key(global_key_name(key), key)
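            # Tensor keys can't be reconstructed by value, so index them via the
            # global weak reference stored above.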
def index_source(key):
if self.tensor_can_be_dict_key(key):
return GlobalWeakRefSource(global_key_name(key))
else:
return key
result = {
k: VariableBuilder(
self.tx, GetItemSource(self.get_source(), index_source(k))
)(value[k]).add_guards(guards)
for k in value.keys()
}
if istype(value, collections.defaultdict):
result = DefaultDictVariable(
result, type(value), value.default_factory, guards=guards
)
else:
result = ConstDictVariable(result, type(value), guards=guards)
return self.tx.output.side_effects.track_dict(self.source, value, result)
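        # torch.nn.Module: special cases below cover RNNs (graph break),
        # dynamically created modules, and FSDP/DDP-managed modules.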
elif isinstance(value, torch.nn.Module):
if (
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
and not config.allow_rnn
):
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
if mutation_guard.is_dynamic_nn_module(value):
# created dynamically, don't specialize on it
result = UnspecializedNNModuleVariable(
value, guards=make_guards(GuardBuilder.TYPE_MATCH)
)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
return self.tx.output.side_effects.track_object_existing(
self.source, value, result
)
elif getattr(value, "_is_fsdp_managed_module", False) or issubclass(
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
):
if getattr(value, "_is_fsdp_managed_module", False):
                    # Note: we can't do this assert inside the FSDP constructor,
                    # since we don't know yet whether Dynamo will be used
assert getattr(
Loading ...