Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ _dynamo / variables / lists.py

import functools
import operator
from typing import Dict, List, Optional

import torch
import torch.fx

from .. import config, variables
from ..bytecode_transformation import create_instruction
from ..exc import unimplemented
from ..source import GetItemSource
from ..utils import namedtuple_fields, proxy_args_kwargs
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable


class BaseListVariable(VariableTracker):
    """Shared behaviour for trackers that model ordered, list-like containers.

    The tracked elements live in ``self.items`` as a plain ``list`` of
    ``VariableTracker`` objects.  Subclasses provide ``python_type()`` to name
    the concrete Python container they stand for.
    """

    @staticmethod
    def cls_for(obj):
        """Map a Python type (or the ``iter`` builtin) to its tracker class.

        Raises ``KeyError`` for anything that is not a key of the table below.
        """
        tracker_by_type = {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
        }
        return tracker_by_type[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        recursively_contains=None,
        regen_guards=True,
        **kwargs,
    ):
        """Store ``items`` and, unless told otherwise, absorb their guards.

        Args:
            items: The tracked elements; must be a ``list`` whose entries are
                all ``VariableTracker`` instances (both facts are asserted).
            recursively_contains: Forwarded to the parent constructor.
            regen_guards: When true, the guards propagated from ``items`` are
                merged into ``self.guards``.  Callers that already supply
                those guards pass ``False`` to skip the merge.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(recursively_contains=recursively_contains, **kwargs)
        assert isinstance(items, list)
        assert all(isinstance(element, VariableTracker) for element in items)

        if regen_guards:
            item_guards = VariableTracker.propagate(items)["guards"]
            self.guards.update(item_guards)

        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        """Return a plain list holding ``as_proxy()`` of every item."""
        return [element.as_proxy() for element in self.items]

    def as_python_constant(self):
        """Build the concrete container from the items' constant values."""
        container_type = self.python_type()
        return container_type(
            [element.as_python_constant() for element in self.items]
        )

    def as_proxy(self):
        """Build the concrete container from the items' proxies."""
        container_type = self.python_type()
        # NOTE(review): the visible subclasses return list / tuple /
        # torch.Size / range from python_type(), never the SizeVariable
        # tracker class, so this check looks like it can never fail.  It
        # presumably meant ``torch.Size`` -- confirm before tightening it.
        assert container_type is not SizeVariable
        return container_type(self._as_proxy())

    def getitem_const(self, arg: VariableTracker):
        """Index or slice this container with the constant held by ``arg``.

        A slice yields a clone of this tracker restricted to the selected
        items (and, when this tracker has a source, a ``GetItemSource``
        derived from it).  An integer yields the selected item itself.  In
        both cases the options of ``arg`` and ``self`` are added to the
        result.
        """
        index = arg.as_python_constant()

        if not isinstance(index, slice):
            assert isinstance(index, int)
            return self.items[index].add_options(arg, self)

        clone_kwargs = {"items": self.items[index]}
        if self.source is not None:
            clone_kwargs["source"] = GetItemSource(self.source, index)
        # A slice of a mutable container gets its own, fresh mutable marker.
        clone_kwargs["mutable_local"] = (
            MutableLocal() if self.mutable_local else None
        )
        return self.clone(**clone_kwargs).add_options(arg, self)

    def unpack_var_sequence(self, tx):
        """Return the items, each with this tracker's options added."""
        return [element.add_options(self) for element in self.items]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` and constant-only ``__contains__``.

        ``__contains__`` is answered here only when the searched value and
        every item are python constants; every other call is passed on to the
        parent class.
        """
        options = VariableTracker.propagate(self, args, kwargs.values())

        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            return self.getitem_const(args[0])

        is_constant_membership_test = (
            name == "__contains__"
            and len(args) == 1
            and args[0].is_python_constant()
            and all(element.is_python_constant() for element in self.items)
        )
        if is_constant_membership_test:
            assert not kwargs
            needle = args[0].as_python_constant()
            found = any(
                element.as_python_constant() == needle for element in self.items
            )
            return variables.ConstantVariable(found, **options)

        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx, op, left, right):
        """Evaluate ``left op right`` for ``operator.eq`` / ``operator.ne``.

        The equality result is always computed first; ``ne`` negates it via
        the ``operator.not_`` builtin.  Any other operator is reported through
        ``unimplemented``.
        """
        from .builtin import BuiltinVariable

        eq_result = BaseListVariable.list_eq(tx, left, right)

        if op is operator.eq:
            return eq_result
        if op is operator.ne:
            negate = BuiltinVariable(operator.not_)
            return negate.call_function(tx, [eq_result], {})
        unimplemented(f"list_compare {left} {op} {right}")

    @staticmethod
    def list_eq(tx, left, right):
        """Compare two list-like trackers element by element.

        Most list-like variables implement comparison the same way, so they
        share this helper.  There are quirks it does not model, e.g.
        ``tuple([2]) == torch.Size([2])`` while ``tuple([2]) != list([2])``.
        """
        from .builtin import BuiltinVariable

        options = VariableTracker.propagate(left, right)

        # Different lengths can never be equal; two empty sequences always are.
        if len(left.items) != len(right.items):
            return ConstantVariable(False, **options)
        if not left.items:
            return ConstantVariable(True, **options)

        # Same non-zero length: compare pairwise.  A pair that is known to be
        # unequal settles the answer immediately.
        pairwise = []
        for lhs, rhs in zip(left.items, right.items):
            outcome = BuiltinVariable(operator.eq).call_function(tx, [lhs, rhs], {})
            if outcome.is_python_constant() and not outcome.as_python_constant():
                return outcome.add_options(options)
            pairwise.append(outcome)

        # Fold the per-element results together with ``and_``, left to right.
        combined = pairwise[0]
        for outcome in pairwise[1:]:
            combined = BuiltinVariable(operator.and_).call_function(
                tx, [combined, outcome], {}
            )
        return combined.add_options(options)


class RangeVariable(BaseListVariable):
    """Tracks a ``range`` object as its ``[start, stop, step]`` triple.

    Whatever arity the range was built with, ``self.items`` always holds
    exactly three trackers in that order.
    """

    def __init__(self, items, **kwargs):
        """Normalize 1, 2 or 3 range arguments into ``[start, stop, step]``.

        Args:
            items: The trackers for the arguments given to ``range``:
                ``(stop,)``, ``(start, stop)`` or ``(start, stop, step)``.
                Missing ``start`` / ``step`` default to constants 0 and 1.
            **kwargs: Forwarded to ``BaseListVariable.__init__``.

        Raises:
            AssertionError: If ``items`` does not hold 1, 2 or 3 entries.
        """
        start = variables.ConstantVariable(0)
        step = variables.ConstantVariable(1)

        if len(items) == 1:
            (stop,) = items
        elif len(items) == 2:
            start, stop = items
        elif len(items) == 3:
            start, stop, step = items
        else:
            # Same exception type as before, now with a diagnostic message.
            raise AssertionError(
                f"range expects 1 to 3 arguments, got {len(items)}"
            )

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def python_type(self):
        """Return ``range``, the Python type this tracker models."""
        return range

    def as_python_constant(self):
        """Build a real ``range`` from the constant values of the items."""
        return range(*[x.as_python_constant() for x in self.items])

    def as_proxy(self):
        """Build a ``range`` by unpacking the item proxies as arguments."""
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx):
        """Return one ``ConstantVariable`` per value the range produces.

        Goes through ``as_python_constant()``, so start, stop and step must
        all be python constants.
        """
        return [
            variables.ConstantVariable(x).add_options(self)
            for x in self.as_python_constant()
        ]

    def reconstruct(self, codegen):
        """Emit code that rebuilds the range by calling ``range`` on the items.

        Returns the trailing ``CALL_FUNCTION`` instruction with three
        arguments, matching the three entries always present in
        ``self.items``.
        """
        # NOTE(review): presumably guards against user code having its own
        # global named "range" -- confirm the intent.
        assert "range" not in codegen.tx.f_globals
        codegen.append_output(codegen.create_load_python_module(range))
        codegen.foreach(self.items)
        return [create_instruction("CALL_FUNCTION", 3)]

    def var_getattr(self, tx, name):
        """Return the tracker for ``start``, ``stop`` or ``step``.

        Any other attribute name is reported through ``unimplemented``.
        """
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"range.{name}")
        return self.items[fields.index(name)].add_options(self)


class ListVariable(BaseListVariable):
    """Tracks a Python ``list``, including in-place mutation when allowed."""

    def python_type(self):
        """Return ``list``, the Python type this tracker models."""
        return list

    def reconstruct(self, codegen):
        """Emit the items followed by a ``BUILD_LIST`` of matching length."""
        codegen.foreach(self.items)
        build_list = create_instruction("BUILD_LIST", len(self.items))
        return [build_list]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model the mutating list methods on a mutable list.

        ``append``, ``extend``, ``insert``, ``pop``, ``clear`` and
        ``__setitem__`` are handled here by building a new ``ListVariable``
        and installing it through ``tx.replace_all``; ``self.items`` itself is
        never modified.  Everything else, and every call on a tracker without
        ``mutable_local``, is passed on to the parent class.

        NOTE(review): ``append`` returns a ``None`` constant and ``pop``
        returns the removed item, but ``extend`` / ``insert`` / ``clear`` /
        ``__setitem__`` return whatever ``tx.replace_all`` returns.  That
        differs from the Python methods, which return ``None``; callers may
        rely on it, so verify before changing.
        """
        options = VariableTracker.propagate(self, args, kwargs.values())

        # None of the operations below applies to an immutable tracker.
        if not self.mutable_local:
            return super().call_method(tx, name, args, kwargs)

        def install(new_items, **extra):
            # Replace this tracker everywhere with one holding ``new_items``.
            # Guards already come in through ``options``.
            return tx.replace_all(
                self,
                ListVariable(new_items, regen_guards=False, **extra, **options),
            )

        if name == "append":
            assert not kwargs
            (arg,) = args
            # The new list contains everything both operands contained, plus
            # the appended value itself when that value is mutable.
            contained = self.recursively_contains.union(arg.recursively_contains)
            if arg.mutable_local is not None:
                contained.add(arg.mutable_local)
            install([*self.items, arg], recursively_contains=contained)
            return ConstantVariable(None)

        if name == "extend" and args and args[0].has_unpack_var_sequence(tx):
            assert not kwargs
            (arg,) = args
            return install([*self.items, *arg.unpack_var_sequence(tx)])

        if name == "insert":
            assert not kwargs
            idx, value = args
            updated = list(self.items)
            updated.insert(idx.as_python_constant(), value)
            return install(updated)

        if name == "pop":
            assert not kwargs
            updated = list(self.items)
            popped = updated.pop(*(a.as_python_constant() for a in args))
            install(updated)
            return popped

        if name == "clear":
            assert not kwargs and not args
            return install([])

        if name == "__setitem__" and args and args[0].is_python_constant():
            assert not kwargs
            key, value = args
            updated = list(self.items)
            # Slice assignment splices in the items of ``value``; plain index
            # assignment stores the tracker itself.
            replacement = list(value.items) if isinstance(key, SliceVariable) else value
            updated[key.as_python_constant()] = replacement
            return install(updated)

        return super().call_method(tx, name, args, kwargs)


class TupleVariable(BaseListVariable):
    """Tracks a Python ``tuple`` of tracked items."""

    def python_type(self):
        """Return ``tuple``, the Python type this tracker models."""
        return tuple

    def reconstruct(self, codegen):
        """Emit the items followed by a ``BUILD_TUPLE`` of matching length."""
        codegen.foreach(self.items)
        build_tuple = create_instruction("BUILD_TUPLE", len(self.items))
        return [build_tuple]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Pass every method call on to the parent class unchanged."""
        handled = super().call_method(tx, name, args, kwargs)
        return handled


class SizeVariable(TupleVariable):
    """torch.Size(...)"""

    def __init__(
        self,
        items: List[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ):
        """Create a tracker for a ``torch.Size``.

        Args:
            items: Trackers for the elements of the size.
            proxy: Optional pre-built FX proxy.  When it is not ``None``,
                ``as_proxy()`` returns it directly.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # The proxy is stored before the parent initializer runs.
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def python_type(self):
        """Return ``torch.Size``, the Python type this tracker models."""
        return torch.Size

    def as_proxy(self):
        if self.proxy is not None:
            return self.proxy

        # torch.Size needs special handling.  Normally, we pun a list-like
        # container to directly contain Proxy/Node objects from FX, and FX
        # knows to look inside containers (via map_aggregate).  But torch.Size
        # is weird; although it subclasses from tuple, it doesn't allow
        # members which aren't int-like (rejecting Proxy and Node).  This
        # means we can't use the normal representation trick
        # torch.Size([proxy0, proxy1]).  I looked into seeing if I could
        # relax torch.Size in PyTorch proper, but if torch.Size constructor
        # sees a type that it doesn't recognize, it will try to call
        # __index__() on it, so there is no BC way to actually change this
        # behavior (though it occurs to me that I could have just added a
        # YOLO no checking alternate constructor.)
        #
        # To work around this problem, I represent a torch.Size proxy as
        # a straight up proxy, that would have been constructed by taking
        # the constituent proxies as arguments.  This trick can be generally
        # used for any construct that we need a proxy for but we can't
        # directly represent as an aggregate; I don't see very many examples
        # of this in torchdynamo though!

        # Look for a proxy.  If there are none, do the legacy behavior
        tracer = None
        proxies = self._as_proxy()
        for proxy in proxies:
            if isinstance(proxy, torch.fx.Proxy):
                tracer = proxy.tracer
                break
Loading ...