# _dynamo/variables/constant.py

import operator
from typing import Dict, List

import torch

from .. import variables
from ..exc import unimplemented
from ..utils import HAS_NUMPY, istype, np
from .base import typestr, VariableTracker


class ConstantVariable(VariableTracker):
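    """VariableTracker for a Python constant value: int, float, bool, str, None,
    or a container composed entirely of such constants."""
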
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        assert not isinstance(value, torch.Tensor)
        assert not isinstance(value, torch.SymInt)
        assert not isinstance(value, torch.SymFloat)

        if HAS_NUMPY and isinstance(value, np.number):
            self.value = value.item()
        else:
            self.value = value

    def as_proxy(self):
        return self.value

    def __str__(self):
        # return f"ConstantVariable({self.value})"
        return f"ConstantVariable({type(self.value).__name__})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @property
    def items(self):
        """
        Need this when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

    def getitem_const(self, arg: VariableTracker):
        return ConstantVariable(
            self.value[arg.as_python_constant()],
            **VariableTracker.propagate([self, arg]),
        )

    @staticmethod
    def is_literal(obj):
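        """Return True if obj is a primitive literal or a container of literals."""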
        if type(obj) in (int, float, bool, type(None), str):
            return True
        if type(obj) in (list, tuple, set, frozenset):
            return all(ConstantVariable.is_literal(x) for x in obj)
        return False

    def unpack_var_sequence(self, tx):
        try:
            options = VariableTracker.propagate([self])
            return [ConstantVariable(x, **options) for x in self.as_python_constant()]
        except TypeError as e:
            raise NotImplementedError from e

    def const_getattr(self, tx, name):
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError()
        return member

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .tensor import SymNodeVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if istype(self.value, tuple):
            # Tuple constants (e.g. the empty tuple): delegate to TupleVariable.
            return variables.TupleVariable(
                items=self.unpack_var_sequence(tx), source=self.source, **options
            ).call_method(tx, name, args, kwargs)

        if any(isinstance(x, SymNodeVariable) for x in args):
            # Promote to SymNodeVariable for operations involving dynamic shapes.
            return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
                tx, name, args, kwargs
            )

        try:
            const_args = [a.as_python_constant() for a in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            return super().call_method(tx, name, args, kwargs)

        def has_arith_binop(num_ty):
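            # A constant-foldable binary arithmetic op: self.value has type
            # num_ty, `name` names an operator-module function, and the single
            # argument is itself a Python constant.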
            return (
                isinstance(self.value, num_ty)
                and hasattr(operator, name)
                and len(args) == 1
                and args[0].is_python_constant()
            )

        if isinstance(self.value, str) and name in str.__dict__:
            assert not kwargs
            method = getattr(self.value, name)
            return ConstantVariable(method(*const_args, **const_kwargs), **options)
        elif has_arith_binop(int) or has_arith_binop(float):
            op = getattr(operator, name)
            add_target = const_args[0]
            if isinstance(add_target, (torch.SymInt, torch.SymFloat)):
                from .tensor import SymNodeVariable

                # A binop between a constant and a symbolic value yields a symbolic value.
                # sym_num = tx.output.register_attr_or_module(
                #     add_target, f"sym_shape_{add_target}", source=None
                # )
                proxy = tx.output.create_proxy(
                    "call_function", op, (self.value, add_target), {}
                )
                return SymNodeVariable.create(tx, proxy, add_target, **options)
            return ConstantVariable(op(self.value, add_target), **options)
        elif name == "__len__" and not (args or kwargs):
            return ConstantVariable(len(self.value), **options)
        elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
            assert not kwargs
            search = args[0].as_python_constant()
            result = search in self.value
            return ConstantVariable(result, **options)

        unimplemented(f"const method call {typestr(self.value)}.{name}")


class EnumVariable(VariableTracker):
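    """VariableTracker for an enum.Enum member; the member itself serves as both
    proxy and Python constant."""
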
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def as_proxy(self):
        return self.value

    def __str__(self):
        return f"EnumVariable({type(self.value)})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def const_getattr(self, tx, name):
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError()
        return member