import logging
import math
import re
import types
from typing import Dict, List
import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import SymNodeVariable
from torch._guards import GuardsCheckpointState
from .. import config, variables
from ..allowed_functions import torch_get_name
from ..exc import unimplemented
from ..source import GetItemSource, NNModuleSource
from ..utils import (
check_constant_args,
check_unspec_python_args,
HAS_NUMPY,
istype,
np,
product,
proxy_args_kwargs,
specialize_args_kwargs,
tensortype_to_dtype,
)
from .base import VariableTracker
from .lists import ListVariable, TupleVariable
from .misc import AutocastModeVariable, NullContextVariable
from .tensor import TensorWithTFOverrideVariable
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# TODO(voz): Maybe rename these later
# Reflected ("r"-prefixed) operator methods. The first group is looked up on
# torch.Tensor, the second on the C-level base class torch._C._TensorBase.
tensor_dunder_fns = [
    torch.Tensor.__rmatmul__,
    torch.Tensor.__rmod__,
    torch.Tensor.__rpow__,
    torch.Tensor.__rsub__,
] + [
    torch._C._TensorBase.__radd__,
    torch._C._TensorBase.__rmul__,
    torch._C._TensorBase.__ror__,
    torch._C._TensorBase.__rxor__,
    torch._C._TensorBase.__rand__,
]

# Types that TorchVariable.__init__ tolerates as the ``__self__`` of a
# wrapped callable.
torch_special_class_types = (torch._C.Generator,)

# Calls to these ops are rewritten by TorchVariable.call_function into a
# ``size`` method call on their single tensor argument.
REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
    torch.onnx.operators.shape_as_tensor,
    torch._shape_as_tensor,
]
# Callables that TorchVariable.can_constant_fold_through() accepts.
# torch.distributed.is_initialized is included only when the distributed
# package reports itself as available.
constant_fold_functions = [
    torch._assert,
    torch._utils._get_device_index,
    torch.cuda.is_available,
    torch.device,
    torch.distributed.is_available,
    torch.finfo,
    torch.get_default_dtype,
    torch.iinfo,
    torch.is_floating_point,
    torch.nn.functional._Reduction.get_enum,
    *(
        (torch.distributed.is_initialized,)
        if torch.distributed.is_available()
        else ()
    ),
]
# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
def remap_as_fn___radd__(*args):
    """Forward ``args`` to ``torch._C._TensorBase.__radd__`` and return its result.

    This plain Python function is the value that ``tensor_dunder_fns_remap``
    associates with the raw ``__radd__`` method.
    """
    reflected_add = torch._C._TensorBase.__radd__
    return reflected_add(*args)
def remap_as_fn___rmul__(*args):
    """Forward ``args`` to ``torch._C._TensorBase.__rmul__`` and return its result.

    This plain Python function is the value that ``tensor_dunder_fns_remap``
    associates with the raw ``__rmul__`` method.
    """
    reflected_mul = torch._C._TensorBase.__rmul__
    return reflected_mul(*args)
def remap_as_fn___ror__(*args):
    """Forward ``args`` to ``torch._C._TensorBase.__ror__`` and return its result.

    This plain Python function is the value that ``tensor_dunder_fns_remap``
    associates with the raw ``__ror__`` method.
    """
    reflected_or = torch._C._TensorBase.__ror__
    return reflected_or(*args)
def remap_as_fn___rxor__(*args):
    """Forward ``args`` to ``torch._C._TensorBase.__rxor__`` and return its result.

    This plain Python function is the value that ``tensor_dunder_fns_remap``
    associates with the raw ``__rxor__`` method.
    """
    reflected_xor = torch._C._TensorBase.__rxor__
    return reflected_xor(*args)
def remap_as_fn___rand__(*args):
    """Forward ``args`` to ``torch._C._TensorBase.__rand__`` and return its result.

    This plain Python function is the value that ``tensor_dunder_fns_remap``
    associates with the raw ``__rand__`` method.
    """
    reflected_and = torch._C._TensorBase.__rand__
    return reflected_and(*args)
# Maps each raw ``torch._C._TensorBase`` reflected-operator method to the
# plain Python wrapper that forwards to it. TorchVariable.__init__ swaps a
# wrapped value for its entry here when one exists.
tensor_dunder_fns_remap = {
    raw_method: python_wrapper
    for raw_method, python_wrapper in (
        (torch._C._TensorBase.__radd__, remap_as_fn___radd__),
        (torch._C._TensorBase.__rmul__, remap_as_fn___rmul__),
        (torch._C._TensorBase.__ror__, remap_as_fn___ror__),
        (torch._C._TensorBase.__rxor__, remap_as_fn___rxor__),
        (torch._C._TensorBase.__rand__, remap_as_fn___rand__),
    )
}
try:
    # We need to monkeypatch transformers here, sadly.
    # TODO(voz): Upstream to transformers lib
    import transformers

    def _dynamo_overriden_transformers_eq(self, other):
        """Compare two objects by their instance ``__dict__``.

        Returns False when ``other`` has no ``__dict__`` attribute;
        otherwise returns whether the two ``__dict__`` mappings are equal.
        """
        return hasattr(other, "__dict__") and self.__dict__ == other.__dict__

    # Replace PretrainedConfig.__eq__ with the __dict__-based comparison.
    _pretrained_config_cls = transformers.configuration_utils.PretrainedConfig
    _pretrained_config_cls.__eq__ = _dynamo_overriden_transformers_eq
    del _pretrained_config_cls
except ImportError:
    # transformers is optional; without it there is nothing to patch.
    pass
class TorchVariable(VariableTracker):
"""Points to a module or method in torch.*"""
def __init__(self, value, **kwargs):
    """Wrap ``value`` and run optional sanity checks on its ``__self__``.

    ``kwargs`` are passed through to the base-class constructor. If
    ``value`` is a key of ``tensor_dunder_fns_remap`` it is replaced by the
    mapped function before being stored on ``self.value``.

    Raises:
        AssertionError: if the stored value has a ``__self__`` that is not
            None, not a ``torch``/``math`` module, not the null-capsule
            type, and not one of ``torch_special_class_types``.
    """
    super().__init__(**kwargs)
    value = tensor_dunder_fns_remap.get(value, value)
    self.value = value

    # the remainder of this is just optional debug checks
    try:
        bound_self = getattr(self.value, "__self__", None)
    except RuntimeError as e:
        # Only the "No such operator" failure is expected from this lookup.
        assert "No such operator" in str(e), str(e)
        bound_self = None

    if bound_self is None:
        return

    if isinstance(bound_self, types.ModuleType):
        # weird ones like torch.nn.functional.avg_pool2d have __self__
        owner_name = bound_self.__name__
        assert re.match(
            r"^(torch|math)([.]|$)", owner_name
        ), f"__self__ set to {owner_name}"
        return

    # some _C functions have __self__ as a null capsule
    is_null_capsule = isinstance(
        bound_self, type(torch._C._get_tracing_state.__self__)
    )
    if is_null_capsule or isinstance(bound_self, torch_special_class_types):
        return

    raise AssertionError(f"{value} found with __self__ set")
def __repr__(self):
return f"TorchVariable({self.value})"
def unique_var_name(self):
    """Build an identifier-safe name for the wrapped value.

    The name comes from ``torch_get_name``, with an ``id``-based fallback
    passed as its second argument. Every run of characters outside
    ``[a-zA-Z0-9_]`` is collapsed to a single underscore and the result is
    prefixed with ``"__"``.
    """
    fallback = f"allowed_fn_{id(self.value)}"
    raw_name = torch_get_name(self.value, fallback)
    sanitized = re.sub(r"[^a-zA-Z0-9_]+", "_", raw_name)
    return "__" + sanitized
def reconstruct(self, codegen):
return codegen.setup_globally_cached(self.unique_var_name(), self.value)
def as_proxy(self):
return self.value
def python_type(self):
if isinstance(self.value, (torch.Tensor, torch.nn.Module)):
return type(self.value)
return super().python_type()
def as_python_constant(self):
return self.value
def can_constant_fold_through(self):
    """Report whether calls to the wrapped value may be constant-folded.

    True when the value is listed in ``constant_fold_functions``; otherwise
    the result of comparing the value's ``__module__`` attribute (None if
    absent) with ``"math"``.
    """
    fn = self.value
    return fn in constant_fold_functions or (
        getattr(fn, "__module__", None) == "math"
    )
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import (
ConstantVariable,
CUDAStreamContextVariable,
CUDAStreamVariable,
GradModeVariable,
SymNodeVariable,
TensorVariable,
UserDefinedObjectVariable,
)
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
constant_args = check_constant_args(args, kwargs)
unspec_python_args = check_unspec_python_args(args, kwargs)
options = VariableTracker.propagate(self, args, kwargs.values())
if self.value in config.constant_functions:
assert not args and not kwargs
return ConstantVariable(config.constant_functions[self.value], **options)
elif self.can_constant_fold_through() and (constant_args or unspec_python_args):
args, kwargs = specialize_args_kwargs(tx, args, kwargs)
# constant fold
return ConstantVariable(
self.as_python_constant()(
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
),
**options,
)
elif istype(self.value, type) and issubclass(self.value, torch.nn.Module):
if self.value is torch.nn.Softmax:
return self._call_softmax(tx, args, kwargs, options)
if self.value is torch.nn.CrossEntropyLoss:
return self._call_cross_entropy_loss(tx, args, kwargs, options)
else:
unimplemented(f"construct nn.Module: {self.value.__name__}")
elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
assert len(args) == 1
if isinstance(args[0], TensorVariable) or (
self.value is torch.overrides.is_tensor_like
and isinstance(args[0], UserDefinedObjectVariable)
and hasattr(args[0].value, "__torch_function__")
):
return ConstantVariable(True, **options)
else:
return ConstantVariable(False, **options)
elif (
self.value
in (
torch.is_floating_point,
torch.is_complex,
)
and isinstance(args[0], TensorVariable)
and args[0].dtype is not None
):
if self.value is torch.is_floating_point:
return ConstantVariable(args[0].dtype.is_floating_point, **options)
elif self.value is torch.is_complex:
return ConstantVariable(args[0].dtype.is_complex, **options)
else:
raise AssertionError()
elif (
self.value is torch.numel
and isinstance(args[0], TensorVariable)
and args[0].size is not None
):
return ConstantVariable(product(args[0].size), **options)
elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD:
assert len(args) == 1
assert isinstance(args[0], TensorVariable)
return args[0].call_method(tx, "size", [], {})
elif self.value in (
torch.nn.modules.utils._single,
torch.nn.modules.utils._pair,
torch.nn.modules.utils._triple,
torch.nn.modules.utils._quadruple,
torch.nn.modules.utils._ntuple,
):
return self._call_ntuple(tx, args, kwargs, options)
elif self.value is torch.no_grad:
return GradModeVariable.create(tx, False, **options)
elif self.value is torch.enable_grad:
return GradModeVariable.create(tx, True, **options)
elif self.value is torch.set_grad_enabled and len(args) == 1:
return GradModeVariable.create(tx, args[0].as_python_constant(), **options)
elif self.value is torch.is_grad_enabled:
assert not (args or kwargs)
return ConstantVariable(torch.is_grad_enabled(), **options).add_guards(
GradModeVariable._guards_singleton
)
elif self.value is torch.cuda.stream:
log.warning(
"torch.cuda.stream() not fully supported, streams may be ignored"
)
assert len(args) == 1
return CUDAStreamContextVariable.create(tx, args[0], **options)
elif self.value is torch.cuda.streams.Stream:
return wrap_fx_proxy_cls(
CUDAStreamVariable,
tx,
tx.output.create_proxy(
"call_function",
torch.cuda.streams.Stream,
(),
{},
),
**options,
)
elif not config.dynamic_shapes and self.is_dynamic_shapes(args, kwargs):
unimplemented(f"dynamic shapes: {self.value.__name__}")
elif len(args) > 0 and isinstance(args[0], TensorWithTFOverrideVariable):
# This code block implements inlining the __torch_function__
# override of a tensor.
tensor_with_tf_override = args[0]
# TODO(future PR): make this implement the full __torch_function__ API
# instead of assuming the relevant override is in the first argument.
args[0] = args[0].tensor_variable
unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
tx,
self,
tensor_with_tf_override.orig_tensor_variable_source,
tensor_with_tf_override.subclass_torch_function__func,
tensor_with_tf_override.subclass_type,
options,
args,
kwargs,
)
# The wrapping here follows the logic in
# `torch.Tensor.__torch_function__`.
if self.value in torch.overrides.get_default_nowrap_functions():
return unwrapped
return TensorWithTFOverrideVariable(
unwrapped,
tensor_with_tf_override.orig_tensor_variable_source,
tensor_with_tf_override.subclass_torch_function__func,
tensor_with_tf_override.subclass_type,
)
elif self.value is torch.amp.autocast_mode.autocast:
return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
elif self.value in (
torch.profiler.profile,
torch.profiler.record_function,
torch.autograd.profiler.profile,
torch.autograd.profiler.record_function,
):
log.warning("Profiler will be ignored")
return NullContextVariable(**options)
elif self.value is torch.autograd._profiler_enabled:
unimplemented("torch.autograd._profiler_enabled not supported yet")
elif self.value is torch.jit.annotate:
assert len(args) == 2
return args[1]
elif self.value is torch.backends.cudnn.is_acceptable:
# is_acceptable(tensor) returns true if
# (a) tensor dtype/device are supported by cudnn
# (b) cudnn is available
# (c) some initialization has completed
# technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
Loading ...