import functools
import inspect
import itertools
import logging
import math
import operator
import types
from typing import Dict, List
import torch
from torch import sym_float, sym_int
from .. import config, variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, is_constant_source, SuperSource, TypeSource
from ..utils import (
check_constant_args,
check_unspec_python_args,
istype,
proxy_args_kwargs,
specialize_args_kwargs,
)
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import BaseListVariable, ListVariable, TupleIteratorVariable, TupleVariable
from .tensor import FakeItemVariable, SymNodeVariable, UnspecializedPythonVariable
from .user_defined import UserDefinedVariable
# Logger for this module, named after the module's import path.
log = logging.getLogger(name=__name__)
class BuiltinVariable(VariableTracker):
@staticmethod
@functools.lru_cache(None)
def _constant_fold_functions():
fns = {
abs,
all,
any,
bool,
callable,
chr,
dict,
divmod,
float,
int,
len,
list,
max,
min,
ord,
pow,
repr,
round,
set,
str,
str.format,
sum,
tuple,
type,
operator.pos,
operator.neg,
operator.not_,
operator.invert,
operator.pow,
operator.mul,
operator.matmul,
operator.floordiv,
operator.truediv,
operator.mod,
operator.add,
operator.sub,
operator.getitem,
operator.lshift,
operator.rshift,
operator.and_,
operator.or_,
operator.xor,
operator.ipow,
operator.imul,
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
operator.imod,
operator.iadd,
operator.isub,
operator.ilshift,
operator.irshift,
operator.iand,
operator.ixor,
operator.ior,
operator.index,
}
fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
return fns
def can_constant_fold_through(self):
return self.fn in self._constant_fold_functions()
@staticmethod
@functools.lru_cache(None)
def _fx_graph_functions():
fns = {
operator.pos,
operator.neg,
operator.not_,
operator.invert,
operator.pow,
operator.mul,
operator.matmul,
operator.floordiv,
operator.truediv,
operator.mod,
operator.add,
operator.sub,
operator.getitem,
operator.lshift,
operator.rshift,
operator.and_,
operator.or_,
operator.xor,
operator.ipow,
operator.imul,
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
operator.imod,
operator.iadd,
operator.isub,
operator.ilshift,
operator.irshift,
operator.iand,
operator.ixor,
operator.ior,
}
return fns
@staticmethod
@functools.lru_cache(None)
def _reversible_binops():
# function -> (forward magic method name, reverse magic method name)
fns = {
operator.add: ("__add__", "__radd__"),
operator.sub: ("__sub__", "__rsub__"),
operator.mul: ("__mul__", "__rmul__"),
operator.truediv: ("__truediv__", "__rtruediv__"),
operator.floordiv: ("__floordiv__", "__rfloordiv__"),
operator.mod: ("__mod__", "__rmod__"),
pow: ("__pow__", "__rpow__"),
operator.pow: ("__pow__", "__rpow__"),
# Don't support these for now, since the corresponding reverse magic methods
# aren't defined on SymInt / SymFloat.
# operator.matmul: ("__matmul__", "__rmatmul__"),
# divmod: ("__divmod__", "__rdivmod__"),
# operator.lshift: ("__lshift__", "__rlshift__"),
# operator.rshift: ("__rshift__", "__rrshift__"),
# operator.and_: ("__and__", "__rand__"),
# operator.or_: ("__or__", "__ror__"),
# operator.xor: ("__xor__", "__rxor__"),
}
return fns
@staticmethod
@functools.lru_cache(None)
def _inplace_binops():
fns = {
operator.ipow: "__ipow__",
operator.imul: "__imul__",
operator.imatmul: "__imatmul__",
operator.ifloordiv: "__ifloordiv__",
operator.itruediv: "__itruediv__",
operator.imod: "__imod__",
operator.iadd: "__iadd__",
operator.iconcat: "__iconcat__",
operator.isub: "__isub__",
operator.ilshift: "__ilshift__",
operator.irshift: "__irshift__",
operator.iand: "__iand__",
operator.ixor: "__ixor__",
operator.ior: "__ior__",
}
return fns
@staticmethod
@functools.lru_cache(None)
def _binop_handlers():
# Multiple dispatch mechanism defining custom binop behavior for certain type
# combinations. Handlers are attempted in order, and will be used if the type checks
# match. They are expected to have the signature:
# fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
# Override table contains: op_fn -> [list of handlers]
op_handlers = {}
for op, magic_method_names in itertools.chain(
BuiltinVariable._inplace_binops().items(),
BuiltinVariable._reversible_binops().items(),
):
handlers = []
# User-defined args (highest precedence)
if isinstance(magic_method_names, tuple):
# Reversible binary ops have forward / backward magic methods
forward_name, reverse_name = magic_method_names
def user_defined_handler(
tx,
a,
b,
options,
forward_name=forward_name,
reverse_name=reverse_name,
):
# Manually handle reversing logic if needed (e.g. call __radd__)
# TODO: If we expand this to handle tensor args, we need to manually
# handle cases like this:
#
# class A(int):
# def __radd__(self, other):
# print("woof")
# torch.randn(3) + A(3)
#
# In this example, A.__radd__() is not called -> nothing is printed, because
# Tensor.__add__ only does a subtype test against int, ignoring the subclass.
# To be fully correct, we should not call A.__radd__() here, and there may be
# other cases to reason about and add exceptions for.
if isinstance(a, UserDefinedVariable):
return a.call_method(tx, forward_name, [b], {})
else:
return b.call_method(tx, reverse_name, [a], {})
else:
forward_name = magic_method_names
def user_defined_handler(tx, a, b, options, forward_name=forward_name):
return a.call_method(tx, forward_name, [b], {})
handlers.append(
((UserDefinedVariable, VariableTracker), user_defined_handler)
)
handlers.append(
((VariableTracker, UserDefinedVariable), user_defined_handler)
)
# Dynamic shape args
def dynamic_handler(tx, a, b, options, fn=op):
from .builder import wrap_fx_proxy
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_function", fn, *proxy_args_kwargs([a, b], {})
),
**options,
)
handlers.append(((SymNodeVariable, VariableTracker), dynamic_handler))
handlers.append(((VariableTracker, SymNodeVariable), dynamic_handler))
op_handlers[op] = handlers
# Special cases - lower precedence but still prefer these over constant folding
# List-like addition (e.g. [1, 2] + [3, 4])
def tuple_add_handler(tx, a, b, options):
return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
list_like_addition_handlers = [
# NB: Prefer the tuple-specific logic over base logic because of
# some SizeVariable weirdness. Specifically, the tuple-specific logic
# drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
(
(TupleVariable, TupleVariable),
tuple_add_handler,
),
(
(TupleVariable, ConstantVariable),
tuple_add_handler,
),
(
(ConstantVariable, TupleVariable),
lambda tx, a, b, options: TupleVariable(
list(a.unpack_var_sequence(tx)) + b.items, **options
),
),
(
(BaseListVariable, BaseListVariable),
lambda tx, a, b, options: type(a)(a.items + b.items, **options),
),
]
op_handlers[operator.add].extend(list_like_addition_handlers)
def list_iadd_handler(tx, a, b, options):
if not a.mutable_local or not b.has_unpack_var_sequence(tx):
# Handler doesn't apply
return None
return tx.replace_all(
a,
ListVariable(
list(a.items) + list(b.unpack_var_sequence(tx)),
regen_guards=False,
**options,
),
)
list_like_iadd_handlers = [
(
(ListVariable, VariableTracker),
list_iadd_handler,
),
(
(TupleVariable, TupleVariable),
tuple_add_handler,
),
(
(TupleVariable, ConstantVariable),
tuple_add_handler,
),
]
op_handlers[operator.iadd].extend(list_like_iadd_handlers)
# List-like expansion (e.g. [1, 2, 3] * 3)
def expand_list_like(tx, lst, const, options):
return lst.__class__(
items=lst.items * const.as_python_constant(),
mutable_local=MutableLocal(),
**options,
)
list_like_expansion_handlers = [
((ListVariable, ConstantVariable), expand_list_like),
((TupleVariable, ConstantVariable), expand_list_like),
(
(ConstantVariable, ListVariable),
lambda tx, a, b, options: expand_list_like(tx, b, a, options),
),
(
(ConstantVariable, TupleVariable),
lambda tx, a, b, options: expand_list_like(tx, b, a, options),
),
Loading ...