import functools
import operator
from typing import Dict, List, Optional
import torch
import torch.fx
from .. import config, variables
from ..bytecode_transformation import create_instruction
from ..exc import unimplemented
from ..source import GetItemSource
from ..utils import namedtuple_fields, proxy_args_kwargs
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
class BaseListVariable(VariableTracker):
    """Common base for sequence-shaped variable trackers.

    The element trackers live in ``self.items``. This class implements the
    operations shared by list-like variables: constant folding, proxy
    construction, constant-index / slice access, unpacking, ``in`` tests on
    constant contents, and ``==`` / ``!=`` comparison.
    """

    @staticmethod
    def cls_for(obj):
        """Return the variable-tracker class that models the Python type ``obj``.

        Raises ``KeyError`` if ``obj`` is not one of the supported types.
        """
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
        }[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        recursively_contains=None,
        regen_guards=True,
        **kwargs,
    ):
        """
        Args:
            items: element trackers; must be a real ``list`` of VariableTracker.
            recursively_contains: forwarded unchanged to VariableTracker.__init__.
            regen_guards: when True, merge the guards of every item into this
                variable's guards. Callers pass False when they already passed
                those guards in through ``kwargs``.
            **kwargs: forwarded to VariableTracker.__init__.
        """
        super().__init__(recursively_contains=recursively_contains, **kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        # Sometimes, we know that we have passed in the guards from the items in the list
        if regen_guards:
            self.guards.update(VariableTracker.propagate(items)["guards"])
        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        """Return a plain Python list of each item's proxy."""
        return [x.as_proxy() for x in self.items]

    def as_python_constant(self):
        """Rebuild the container (via ``python_type()``) from constant items."""
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        """Return a container of this variable's Python type holding item proxies.

        torch.Size cannot hold proxy objects (see SizeVariable.as_proxy), so it
        must never reach this generic path.
        """
        # python_type() returns a Python type (e.g. torch.Size), never a
        # VariableTracker subclass, so the guard must compare against
        # torch.Size; comparing against SizeVariable could never fail.
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, arg: VariableTracker):
        """Model ``self[arg]`` where ``arg`` is a constant int or slice.

        For a slice, return a clone holding the selected items. The clone gets a
        fresh MutableLocal when this variable is mutable, and a GetItemSource
        when this variable has a source. For an int, return the selected item.
        Guards from ``arg`` and ``self`` are attached to the result either way.
        """
        index = arg.as_python_constant()
        if isinstance(index, slice):
            clone_kwargs = {
                "items": self.items[index],
                "mutable_local": MutableLocal() if self.mutable_local else None,
            }
            if self.source is not None:
                clone_kwargs["source"] = GetItemSource(self.source, index)
            return self.clone(**clone_kwargs).add_options(arg, self)
        assert isinstance(index, int)
        return self.items[index].add_options(arg, self)

    def unpack_var_sequence(self, tx):
        """Return the items, each carrying this variable's guards."""
        return [x.add_options(self) for x in self.items]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` and constant ``__contains__``; defer the rest."""
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            return self.getitem_const(args[0])
        elif (
            name == "__contains__"
            and len(args) == 1
            and args[0].is_python_constant()
            and all(x.is_python_constant() for x in self.items)
        ):
            assert not kwargs
            search = args[0].as_python_constant()
            result = any(x.as_python_constant() == search for x in self.items)
            return variables.ConstantVariable(result, **options)
        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx, op, left, right):
        """Model ``left op right`` for list-like trackers.

        Only ``operator.eq`` and ``operator.ne`` are supported; any other
        operator triggers ``unimplemented``.
        """
        from .builtin import BuiltinVariable

        eq_result = BaseListVariable.list_eq(tx, left, right)
        if op is operator.eq:
            return eq_result
        elif op is operator.ne:
            return BuiltinVariable(operator.not_).call_function(tx, [eq_result], {})
        else:
            unimplemented(f"list_compare {left} {op} {right}")

    @staticmethod
    def list_eq(tx, left, right):
        """Model element-wise ``left == right`` for two list-like trackers.

        Returns a constant False on length mismatch and a constant True when
        both are empty. Otherwise returns the ``and`` of per-element ``==``
        results, short-circuiting on the first element that is constant False.
        """
        from .builtin import BuiltinVariable

        options = VariableTracker.propagate(left, right)
        # Most list-like variables implement comparison ops the same way,
        # so they can re-use this helper.
        # There are quirks though, like how `tuple([2]) == torch.Size([2])`,
        # but `tuple([2]) != list([2])`
        if len(left.items) != len(right.items):
            return ConstantVariable(False, **options)
        if len(left.items) == 0:
            return ConstantVariable(True, **options)

        # Lengths match and are non-zero: compare pairwise with the traced
        # `==` so non-constant elements (e.g. tensors) are handled too.
        comps = []
        for l, r in zip(left.items, right.items):
            comp = BuiltinVariable(operator.eq).call_function(tx, [l, r], {})
            if comp.is_python_constant() and not comp.as_python_constant():
                # early exit in false case
                return comp.add_options(options)
            comps.append(comp)

        return functools.reduce(
            lambda a, b: BuiltinVariable(operator.and_).call_function(tx, [a, b], {}),
            comps,
        ).add_options(options)
class RangeVariable(BaseListVariable):
    """Tracks a Python ``range`` object.

    ``items`` always holds exactly three trackers, ``[start, stop, step]``.
    These are the range's constructor arguments, not its elements, so
    element-wise operations (indexing, ``in``) are computed on the
    materialized range rather than on ``items``.
    """

    def __init__(self, items, **kwargs):
        """Normalize 1-3 ``range(...)`` arguments into ``[start, stop, step]``.

        ``start`` defaults to 0 and ``step`` defaults to 1, as in ``range``.
        Any other argument count raises AssertionError.
        """
        start = variables.ConstantVariable(0)
        stop = None
        step = variables.ConstantVariable(1)
        if len(items) == 1:
            (stop,) = items
        elif len(items) == 2:
            start, stop = items
        elif len(items) == 3:
            start, stop, step = items
        else:
            raise AssertionError()
        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def python_type(self):
        return range

    def as_python_constant(self):
        """Materialize the range; requires start/stop/step to be constants."""
        return range(*[x.as_python_constant() for x in self.items])

    def as_proxy(self):
        """Return ``range(*proxies)`` built from the start/stop/step proxies."""
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx):
        """Return one ConstantVariable per element of the materialized range."""
        return [
            variables.ConstantVariable(x).add_options(self)
            for x in self.as_python_constant()
        ]

    def getitem_const(self, arg: VariableTracker):
        """Model ``range(...)[arg]`` for a constant int index or slice.

        The inherited implementation would index ``items`` (the constructor
        arguments) and return e.g. ``stop`` for ``range(5)[1]``. Instead, index
        the materialized range, which requires start/stop/step to be Python
        constants (``unpack_var_sequence`` makes the same requirement).
        """
        index = arg.as_python_constant()
        value = self.as_python_constant()[index]
        if isinstance(index, slice):
            # Slicing a range yields another range.
            return RangeVariable(
                [
                    variables.ConstantVariable(value.start),
                    variables.ConstantVariable(value.stop),
                    variables.ConstantVariable(value.step),
                ]
            ).add_options(arg, self)
        return variables.ConstantVariable(value).add_options(arg, self)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle constant ``in`` tests on the range's elements; defer the rest."""
        if (
            name == "__contains__"
            and len(args) == 1
            and args[0].is_python_constant()
            and all(x.is_python_constant() for x in self.items)
        ):
            assert not kwargs
            options = VariableTracker.propagate(self, args, kwargs.values())
            # The inherited check compares against start/stop/step, so
            # `5 in range(5)` would be True; test the real range instead.
            result = args[0].as_python_constant() in self.as_python_constant()
            return variables.ConstantVariable(result, **options)
        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        """Emit bytecode calling ``range(start, stop, step)``."""
        assert "range" not in codegen.tx.f_globals
        codegen.append_output(codegen.create_load_python_module(range))
        codegen.foreach(self.items)
        return [create_instruction("CALL_FUNCTION", 3)]

    def var_getattr(self, tx, name):
        """Support ``.start``, ``.stop`` and ``.step``; anything else is unimplemented."""
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"range.{name}")
        return self.items[fields.index(name)].add_options(self)
class ListVariable(BaseListVariable):
    """Tracks a Python ``list``.

    Mutating methods are modelled only when the list is tracked as mutable
    (``self.mutable_local``). Each one builds a new ListVariable holding the
    updated items and installs it with ``tx.replace_all(self, ...)``.
    """

    def python_type(self):
        return list

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds the list from its items."""
        codegen.foreach(self.items)
        return [create_instruction("BUILD_LIST", len(self.items))]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model in-place list methods.

        Handles ``append``, ``extend``, ``insert``, ``pop``, ``clear`` and
        ``__setitem__`` (constant key) on mutable lists. Everything else is
        deferred to BaseListVariable.
        """
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "append" and self.mutable_local:
            assert not kwargs
            (arg,) = args
            new_rec_contains = self.recursively_contains.union(arg.recursively_contains)
            if arg.mutable_local is not None:
                new_rec_contains.add(arg.mutable_local)
            tx.replace_all(
                self,
                ListVariable(
                    self.items + [arg],
                    recursively_contains=new_rec_contains,
                    regen_guards=False,
                    **options,
                ),
            )
            return ConstantVariable(None)
        elif (
            name == "extend"
            and self.mutable_local
            and args
            and args[0].has_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (arg,) = args
            return tx.replace_all(
                self,
                ListVariable(
                    list(self.items) + list(arg.unpack_var_sequence(tx)),
                    regen_guards=False,
                    **options,
                ),
            )
        elif name == "insert" and self.mutable_local:
            assert not kwargs
            idx, value = args
            items = list(self.items)
            items.insert(idx.as_python_constant(), value)
            return tx.replace_all(
                self,
                ListVariable(items, regen_guards=False, **options),
            )
        elif name == "pop" and self.mutable_local:
            assert not kwargs
            items = list(self.items)
            result = items.pop(*[a.as_python_constant() for a in args])
            tx.replace_all(
                self,
                ListVariable(items, regen_guards=False, **options),
            )
            return result
        elif name == "clear" and self.mutable_local:
            assert not kwargs and not args
            return tx.replace_all(
                self,
                ListVariable([], regen_guards=False, **options),
            )
        elif (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            items = list(self.items)
            if isinstance(key, SliceVariable):
                # Slice assignment takes any iterable. Unpack it the same way
                # `extend` does instead of reading `value.items`. That attribute
                # is only defined on list-like trackers, and for a RangeVariable
                # it holds [start, stop, step] rather than the elements.
                if not value.has_unpack_var_sequence(tx):
                    unimplemented(f"list slice assignment from {value}")
                items[key.as_python_constant()] = list(value.unpack_var_sequence(tx))
            else:
                items[key.as_python_constant()] = value
            result = ListVariable(items, regen_guards=False, **options)
            return tx.replace_all(self, result)
        else:
            return super().call_method(tx, name, args, kwargs)
class TupleVariable(BaseListVariable):
    """Tracks a Python ``tuple``.

    Method calls (``__getitem__``, ``__contains__``, ...) are handled entirely
    by the inherited BaseListVariable.call_method.
    """

    def python_type(self):
        return tuple

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds the tuple from its items."""
        codegen.foreach(self.items)
        return [create_instruction("BUILD_TUPLE", len(self.items))]
class SizeVariable(TupleVariable):
"""torch.Size(...)"""
def __init__(
self,
items: List[VariableTracker],
proxy: Optional[torch.fx.Proxy] = None,
**kwargs,
):
self.proxy = proxy
super().__init__(items, **kwargs)
def python_type(self):
return torch.Size
def as_proxy(self):
if self.proxy is not None:
return self.proxy
# torch.Size needs special handling. Normally, we pun a list-like
# container to directly contain Proxy/Node objects from FX, and FX
# knows to look inside containers (via map_aggregate). But torch.Size
# is weird; although it subclasses from tuple, it doesn't allow
# members which aren't int-like (rejecting Proxy and Node). This
# means we can't use the normal representation trick
# torch.Size([proxy0, proxy1]). I looked into seeing if I could
# relax torch.Size in PyTorch proper, but if torch.Size constructor
# sees a type that it doesn't recognize, it will try to call
# __index__() on it, so there is no BC way to actually change this
# behavior (though it occurs to me that I could have just added a
# YOLO no checking alternate constructor.)
#
# To work around this problem, I represent a torch.Size proxy as
# a straight up proxy, that would have been constructed by taking
# the constituent proxies as arguments. This trick can be generally
# used for any construct that we need a proxy for but we can't
# directly represent as an aggregate; I don't see very many examples
# of this in torchdynamo though!
# Look for a proxy. If there are none, do the legacy behavior
tracer = None
proxies = self._as_proxy()
for proxy in proxies:
if isinstance(proxy, torch.fx.Proxy):
tracer = proxy.tracer
break
Loading ...