import inspect
import itertools
import operator
import types
from typing import Dict, List

import torch.fx
import torch.random
from torch.fx.experimental.symbolic_shapes import guard_scalar

from .. import config, variables
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
fqn,
get_fake_value,
get_real_value,
HAS_NUMPY,
np,
product,
proxy_args_kwargs,
tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable

supported_tensor_comparison_ops = {
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
"<=": operator.le,
"==": operator.eq,
"!=": operator.ne,
}
supported_const_comparison_ops = {
"is": operator.is_,
"is not": operator.is_not,
"==": operator.eq,
"!=": operator.ne,
}


class TensorVariable(VariableTracker):
"""A torch.Tensor input or an intermediate value in the FX graph"""
_nonvar_fields = [
"proxy",
"dtype",
"device",
"layout",
"ndim",
"size",
"stride",
"requires_grad",
"is_quantized",
"is_contiguous",
]

    def get_real_value(self):
"""
Get the actual value represented by this variable if computation is run
using the user-provided inputs.
NOTE: this runs actual tensor computation and may be
slow and memory-intensive.
"""
return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
self,
proxy: torch.fx.Proxy,
dtype=None,
device=None,
layout=None,
ndim=None,
size=None,
stride=None,
requires_grad=None,
is_quantized=None,
is_contiguous=None,
is_sparse=None,
class_type=torch.Tensor,
specialized_value=None,
**kwargs,
):
super().__init__(**kwargs)
self.proxy = proxy
self.dtype = dtype
self.device = device
self.layout = layout
self.ndim = ndim
self.size = size
self.stride = stride
self.requires_grad = requires_grad
self.is_quantized = is_quantized
self.is_contiguous = is_contiguous
self.is_sparse = is_sparse
self.class_type = class_type
self.specialized_value = specialized_value

    def as_proxy(self):
return self.proxy

    def python_type(self):
return self.class_type

    def call_isinstance(self, tensor_type):
def check_type(ty):
if ty not in tensortype_to_dtype:
return issubclass(self.python_type(), ty)
dtypes = tensortype_to_dtype[ty]
return self.dtype in dtypes

        if type(tensor_type) is tuple:
            return any(check_type(ty) for ty in tensor_type)
else:
return check_type(tensor_type)

    @staticmethod
def specialize(value: torch.Tensor):
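        # Collect the static properties of a real tensor that are used to construct a
        # TensorVariable. Size, stride, and contiguity are only recorded when dynamic
        # shapes are disabled, so that they can be treated as compile-time constants.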
props = {
"dtype": value.dtype,
"device": value.device,
"layout": value.layout,
"ndim": int(value.ndim),
"requires_grad": value.requires_grad,
"is_quantized": value.is_quantized,
"is_sparse": value.is_sparse,
"class_type": type(value),
}
if not config.dynamic_shapes:
props["size"] = tuple(value.size())
props["stride"] = tuple(value.stride())
props["is_contiguous"] = tuple(
[
x
for x in torch._prims_common._memory_formats
if value.is_contiguous(memory_format=x)
]
)
return props

    def var_getattr(self, tx, name):
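        # Resolve attribute reads (e.g. x.shape, x.dtype) from the metadata recorded on
        # this variable where possible; otherwise fall back to emitting a method call or
        # a generic getattr proxy into the graph.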
from . import ConstantVariable, TorchVariable

        result = None
options = VariableTracker.propagate(self)
if name == "ndim" and self.ndim is not None:
result = ConstantVariable(self.ndim, **options)
elif name == "dtype" and self.dtype is not None:
result = TorchVariable(self.dtype, **options)
elif name == "device" and self.device is not None:
result = TorchVariable(self.device, **options)
elif name == "layout" and self.layout is not None:
result = TorchVariable(self.layout, **options)
elif name == "is_cuda" and self.device is not None:
result = ConstantVariable(self.device.type == "cuda", **options)
elif name == "shape" and self.size is not None:
sizes = [variables.ConstantVariable(x) for x in self.size]
result = ShapeVariable(sizes, **options)
elif name == "requires_grad" and self.requires_grad is not None:
result = ConstantVariable(self.requires_grad, **options)
elif name == "is_quantized" and self.is_quantized is not None:
result = ConstantVariable(self.is_quantized, **options)
elif name == "is_sparse" and self.is_sparse is not None:
result = ConstantVariable(self.is_sparse, **options)
elif name == "shape" and self.size is None:
result = self.call_method(tx, "size", [], {})
elif name == "ndim" and self.ndim is None:
result = self.call_method(tx, "dim", [], {})
elif name == "data":
result = self.call_method(tx, "detach", [], {})
if name == "__class__":
return TorchVariable(self.python_type(), **options)
        # Add a guard for type matching; these guards are checked before tensor guards.
        # In some cases, a <tensor>.<attr> guard can be evaluated first and break if
        # <tensor> is later changed to another type.
if result is not None and self.source is not None:
result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
        # For attributes (not methods) that were not caught in the special handling above
        # (e.g. tensor.real), we handle them generically, assuming that the output type is
        # a tensor.
if result is None:
def try_generic_attr_handling():
from .builder import wrap_fx_proxy
from .misc import GetAttrVariable
try:
static_attr = inspect.getattr_static(torch.Tensor, name)
except AttributeError:
return None
                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor".
                # This is because of the CPython implementation; see THPVariableType:
                # these attributes are implemented under tp_getset, which appear
                # as `getset_descriptor`s (compared to, say, methods, which appear
                # as `method_descriptor`s).
if type(static_attr) != types.GetSetDescriptorType:
return None
return wrap_fx_proxy(
tx=tx,
proxy=GetAttrVariable.create_getattr_proxy(self.as_proxy(), name),
**options,
)
result = try_generic_attr_handling()
if result is None:
raise NotImplementedError()
return result

    def has_unpack_var_sequence(self, tx):
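        # Unpacking (e.g. `a, b = tensor`) is possible when the leading dimension is
        # statically known, or when dynamic shapes let us query it at trace time.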
return (self.size is not None and len(self.size) > 0) or (
self.size is None and config.dynamic_shapes
)

    def unpack_var_sequence(self, tx, idxes=None):
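        # Produce one variable per element along dim 0 by indexing the underlying proxy.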
from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self)
if idxes is None:
if self.size:
length = self.size[0]
else:
dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
assert isinstance(dyn_length, SymNodeVariable)
length = dyn_length.evaluate_expr(tx.output)
idxes = range(length)
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]

    def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from . import ConstantVariable, TorchVariable, TupleVariable
from .builder import wrap_fx_proxy

        kwargs = dict(kwargs)
options = VariableTracker.propagate(self, args, kwargs.values())
if name == "stride" and self.stride is not None:
constant_result = ConstantVariable(self.stride, **options)
elif name == "size" and self.size is not None:
sizes = [variables.ConstantVariable(x) for x in self.size]
constant_result = SizeVariable(sizes, **options)
elif name == "size" and self.size is None and config.dynamic_shapes:
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif name in ("numel", "nelement") and self.size is not None:
constant_result = ConstantVariable(product(self.size), **options)
elif name in ("ndimension", "dim") and self.ndim is not None:
constant_result = ConstantVariable(self.ndim, **options)
elif name == "is_floating_point" and self.dtype is not None:
constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
elif name == "is_contiguous" and self.is_contiguous is not None:
if "memory_format" in kwargs:
memory_format = kwargs.pop("memory_format").as_python_constant()
else:
memory_format = torch.contiguous_format
constant_result = ConstantVariable(
memory_format in self.is_contiguous, **options
)
elif (
name == "type"
and self.dtype is not None
and len(args) == 0
and isinstance(self.device, torch.device)
):
tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
0
]
if self.device.type == "cuda":
constant_result = ConstantVariable(
f"torch.cuda.{tensortype.__name__}", **options
)
else:
constant_result = ConstantVariable(
f"torch.{tensortype.__name__}", **options
)
elif (
name == "type"
and len(args) == 1
and fqn(type(args[0].as_python_constant())) == "torch.tensortype"
):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types because it doesn't support arguments of
            # torch.tensortype type, so we pass the type in as a string instead (which is
            # also supported; see the .type() implementation with 0 args above).
tensor_type = args[0].as_python_constant()
tensor_type_const = ConstantVariable(fqn(tensor_type), **options)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self, tensor_type_const], kwargs),
),
**options,
)
elif name == "get_device" and isinstance(self.device, torch.device):
index = self.device.index if self.device.type != "cpu" else -1
constant_result = ConstantVariable(index, **options)
else:
constant_result = None
if constant_result:
assert not kwargs, f"Tensor.{name}() unhandled kwargs"
if len(args) == 1:
return constant_result.getitem_const(args[0])
elif args:
return TupleVariable(
[constant_result.getitem_const(a) for a in args], **options
)
return constant_result
elif (
name == "repeat"
and not all(
x.is_python_constant() for x in itertools.chain(args, kwargs.values())
)
and not config.dynamic_shapes
):
unimplemented("dynamic Tensor.repeat")
elif name in ("tolist", "numpy", "backward", "data_ptr"):
unimplemented(f"Tensor.{name}")
elif name == "nonzero" and not config.dynamic_shapes:
unimplemented(f"Tensor.{name}")
elif name == "item" and not config.capture_scalar_outputs:
unimplemented(f"Tensor.{name}")