# torch/_dynamo/variables/builtin.py

import functools
import inspect
import itertools
import logging
import math
import operator
import types
from typing import Dict, List

import torch
from torch import sym_float, sym_int

from .. import config, variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, is_constant_source, SuperSource, TypeSource
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    istype,
    proxy_args_kwargs,
    specialize_args_kwargs,
)
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import BaseListVariable, ListVariable, TupleIteratorVariable, TupleVariable
from .tensor import FakeItemVariable, SymNodeVariable, UnspecializedPythonVariable
from .user_defined import UserDefinedVariable

log = logging.getLogger(__name__)


class BuiltinVariable(VariableTracker):
    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_functions():
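        """Builtin and operator functions safe to constant-fold at trace time.

        When every argument is a constant, a call like ``min(3, 5)`` can be
        evaluated eagerly (producing a ConstantVariable) instead of being
        traced into the graph.
        """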
        fns = {
            abs,
            all,
            any,
            bool,
            callable,
            chr,
            dict,
            divmod,
            float,
            int,
            len,
            list,
            max,
            min,
            ord,
            pow,
            repr,
            round,
            set,
            str,
            str.format,
            sum,
            tuple,
            type,
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
            operator.index,
        }
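        # Also fold every builtin function in the math module (math.sqrt,
        # math.sin, ...); type(math.sqrt) is the C-level builtin-function type
        # used as the filter.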
        fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
        return fns

    def can_constant_fold_through(self):
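        """Whether self.fn can be evaluated eagerly when given constant args."""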
        return self.fn in self._constant_fold_functions()

    @staticmethod
    @functools.lru_cache(None)
    def _fx_graph_functions():
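        """Operators that can be lowered directly into the FX graph as call_function nodes."""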
        fns = {
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _reversible_binops():
        # function -> (forward magic method name, reverse magic method name)
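        # e.g. for operator.add, Python tries a.__add__(b) first and falls back
        # to b.__radd__(a) when the left operand can't handle the right's type.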
        fns = {
            operator.add: ("__add__", "__radd__"),
            operator.sub: ("__sub__", "__rsub__"),
            operator.mul: ("__mul__", "__rmul__"),
            operator.truediv: ("__truediv__", "__rtruediv__"),
            operator.floordiv: ("__floordiv__", "__rfloordiv__"),
            operator.mod: ("__mod__", "__rmod__"),
            pow: ("__pow__", "__rpow__"),
            operator.pow: ("__pow__", "__rpow__"),
            # Don't support these for now, since the corresponding reverse magic methods
            # aren't defined on SymInt / SymFloat.
            # operator.matmul: ("__matmul__", "__rmatmul__"),
            # divmod: ("__divmod__", "__rdivmod__"),
            # operator.lshift: ("__lshift__", "__rlshift__"),
            # operator.rshift: ("__rshift__", "__rrshift__"),
            # operator.and_: ("__and__", "__rand__"),
            # operator.or_: ("__or__", "__ror__"),
            # operator.xor: ("__xor__", "__rxor__"),
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _inplace_binops():
        fns = {
            operator.ipow: "__ipow__",
            operator.imul: "__imul__",
            operator.imatmul: "__imatmul__",
            operator.ifloordiv: "__ifloordiv__",
            operator.itruediv: "__itruediv__",
            operator.imod: "__imod__",
            operator.iadd: "__iadd__",
            operator.iconcat: "__iconcat__",
            operator.isub: "__isub__",
            operator.ilshift: "__ilshift__",
            operator.irshift: "__irshift__",
            operator.iand: "__iand__",
            operator.ixor: "__ixor__",
            operator.ior: "__ior__",
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _binop_handlers():
        # Multiple dispatch mechanism defining custom binop behavior for certain type
        # combinations. Handlers are attempted in order, and will be used if the type checks
        # match. They are expected to have the signature:
        # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
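        #
        # For example, a handler registered for (UserDefinedVariable, VariableTracker)
        # matches when isinstance(arg0, UserDefinedVariable) and
        # isinstance(arg1, VariableTracker) both hold; handlers earlier in the
        # list take precedence.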

        # Override table contains: op_fn -> [list of handlers]
        op_handlers = {}
        for op, magic_method_names in itertools.chain(
            BuiltinVariable._inplace_binops().items(),
            BuiltinVariable._reversible_binops().items(),
        ):
            handlers = []

            # User-defined args (highest precedence)
            if isinstance(magic_method_names, tuple):
                # Reversible binary ops have forward / backward magic methods
                forward_name, reverse_name = magic_method_names

                def user_defined_handler(
                    tx,
                    a,
                    b,
                    options,
                    forward_name=forward_name,
                    reverse_name=reverse_name,
                ):
                    # Manually handle reversing logic if needed (e.g. call __radd__)

                    # TODO: If we expand this to handle tensor args, we need to manually
                    # handle cases like this:
                    #
                    # class A(int):
                    #     def __radd__(self, other):
                    #         print("woof")
                    # torch.randn(3) + A(3)
                    #
                    # In this example, A.__radd__() is not called -> nothing is printed, because
                    # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
                    # To be fully correct, we should not call A.__radd__() here, and there may be
                    # other cases to reason about and add exceptions for.
                    if isinstance(a, UserDefinedVariable):
                        return a.call_method(tx, forward_name, [b], {})
                    else:
                        return b.call_method(tx, reverse_name, [a], {})

            else:
                forward_name = magic_method_names

                def user_defined_handler(tx, a, b, options, forward_name=forward_name):
                    return a.call_method(tx, forward_name, [b], {})

            handlers.append(
                ((UserDefinedVariable, VariableTracker), user_defined_handler)
            )
            handlers.append(
                ((VariableTracker, UserDefinedVariable), user_defined_handler)
            )

            # Dynamic shape args
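            # If either operand is a SymNodeVariable (a symbolic int/float from
            # dynamic shapes), emit the op as a call_function node in the FX
            # graph so that e.g. `s0 + 1` stays symbolic instead of specializing.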
            def dynamic_handler(tx, a, b, options, fn=op):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function", fn, *proxy_args_kwargs([a, b], {})
                    ),
                    **options,
                )

            handlers.append(((SymNodeVariable, VariableTracker), dynamic_handler))
            handlers.append(((VariableTracker, SymNodeVariable), dynamic_handler))

            op_handlers[op] = handlers

        # Special cases - lower precedence but still prefer these over constant folding

        # List-like addition (e.g. [1, 2] + [3, 4])
        def tuple_add_handler(tx, a, b, options):
            return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)

        list_like_addition_handlers = [
            # NB: Prefer the tuple-specific logic over base logic because of
            # some SizeVariable weirdness. Specifically, the tuple-specific logic
            # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: TupleVariable(
                    list(a.unpack_var_sequence(tx)) + b.items, **options
                ),
            ),
            (
                (BaseListVariable, BaseListVariable),
                lambda tx, a, b, options: type(a)(a.items + b.items, **options),
            ),
        ]
        op_handlers[operator.add].extend(list_like_addition_handlers)

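        # In-place list addition (e.g. `a += [3, 4]`) mutates the left operand,
        # so the handler must rewrite the tracked ListVariable via
        # tx.replace_all rather than just returning a fresh variable.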
        def list_iadd_handler(tx, a, b, options):
            if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                # Handler doesn't apply
                return None

            return tx.replace_all(
                a,
                ListVariable(
                    list(a.items) + list(b.unpack_var_sequence(tx)),
                    regen_guards=False,
                    **options,
                ),
            )

        list_like_iadd_handlers = [
            (
                (ListVariable, VariableTracker),
                list_iadd_handler,
            ),
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
        ]
        op_handlers[operator.iadd].extend(list_like_iadd_handlers)

        # List-like expansion (e.g. [1, 2, 3] * 3)
        def expand_list_like(tx, lst, const, options):
            return lst.__class__(
                items=lst.items * const.as_python_constant(),
                mutable_local=MutableLocal(),
                **options,
            )

        list_like_expansion_handlers = [
            ((ListVariable, ConstantVariable), expand_list_like),
            ((TupleVariable, ConstantVariable), expand_list_like),
            (
                (ConstantVariable, ListVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
        # ... (remainder of file truncated)