Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ _dynamo / source.py

import collections
import dataclasses
import enum
from typing import Any, Optional, Union

from torch._guards import GuardSource, Source

from . import utils
from .bytecode_transformation import create_instruction
from .utils import enum_repr, rename_implicit

_GUARD_SOURCE_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
}

_GUARD_SOURCE_NOT_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL,
    GuardSource.GLOBAL: GuardSource.GLOBAL,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
}


def is_constant_source(source):
    if isinstance(source, ConstantSource):
        return True
    try:
        if source.guard_source() == GuardSource.CONSTANT:
            return True
    except NotImplementedError:
        pass

    return False


def is_input_source(source):
    return source.guard_source() in [
        GuardSource.LOCAL,
        GuardSource.GLOBAL,
        GuardSource.LOCAL_NN_MODULE,
        GuardSource.GLOBAL_NN_MODULE,
    ]


@dataclasses.dataclass
class LocalSource(Source):
    local_name: str

    def reconstruct(self, codegen):
        return [codegen.create_load(self.local_name)]

    def guard_source(self):
        return GuardSource.LOCAL

    def name(self):
        return rename_implicit(self.local_name)


@dataclasses.dataclass
class LocalInputSource(LocalSource):
    pos: int


@dataclasses.dataclass
class RandomValueSource(Source):
    random_call_index: int

    def guard_source(self):
        return GuardSource.RANDOM_VALUE

    def reconstruct(self, codegen):
        return [
            codegen.create_load(codegen.tx.output.random_values_var),
            codegen.create_load_const(self.random_call_index),
            create_instruction("BINARY_SUBSCR"),
        ]

    def name(self):
        return rename_implicit(f"random_value_{self.random_call_index}")


@dataclasses.dataclass
class GlobalSource(Source):
    global_name: str

    def reconstruct(self, codegen):
        return [codegen.create_load_global(self.global_name, add=True)]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return self.global_name


@dataclasses.dataclass
class GlobalWeakRefSource(Source):
    global_name: str

    def reconstruct(self, codegen):
        return [
            codegen.create_load_global(self.global_name, add=True),
            create_instruction("CALL_FUNCTION", 0),
        ]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return f"{self.global_name}()"


@dataclasses.dataclass
class AttrSource(Source):
    base: Source
    member: str

    def __init__(self, base, member):
        super().__init__()
        assert base, "Can't construct an AttrSource without a valid base source"
        if "." in member:
            member_parts = member.split(".")
            self.base = AttrSource(base, ".".join(member_parts[:-1]))
            self.member = member_parts[-1]
        else:
            self.base = base
            self.member = member
        assert self.base is not None

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if self.member.isnumeric():
            return f"getattr({self.base.name()}, {self.member!r})"
        return f"{self.base.name()}.{self.member}"


class TensorProperty(enum.Enum):
    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2


@dataclasses.dataclass
class TensorPropertySource(Source):
    base: Source
    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
        else:
            assert self.idx is not None

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if self.prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        elif self.prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        elif self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        else:
            raise AssertionError(f"unhandled {self.prop}")


@dataclasses.dataclass
class NegateSource(Source):
    base: Source

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        # NB: use method call so that function stripping regexes work
        return f"{self.base.name()}.__neg__()"


@dataclasses.dataclass
class DefaultsSource(Source):
    base: Source
    idx_key: Union[int, str]
    is_kw: bool
    field: str

    def __init__(self, base, idx_key, is_kw=False):
        super().__init__()
        assert (
            base
        ), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
        self.base = base
        self.idx_key = idx_key
        self.is_kw = is_kw
        if self.is_kw:
            assert isinstance(idx_key, str)
            self.field = "__kwdefaults__"
            self._name = f"{self.base.name()}.{self.field}['{self.idx_key}']"
        else:
            assert isinstance(idx_key, int)
            self.field = "__defaults__"
            self._name = f"{self.base.name()}.{self.field}[{self.idx_key}]"

    def reconstruct(self, codegen):
        instrs = self.base.reconstruct(codegen)
        instrs.extend(codegen.create_load_attrs(self.field))
        instrs.extend(
            [
                codegen.create_load_const(self.idx_key),
                create_instruction("BINARY_SUBSCR"),
            ]
        )
        return instrs

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return self._name


@dataclasses.dataclass
class GetItemSource(Source):
    base: Source
    index: Any

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        instrs = self.base.reconstruct(codegen)

        if isinstance(self.index, Source):
            instrs.extend(self.index.reconstruct(codegen))
        else:
            instrs.append(codegen.create_load_const(self.index))
        instrs.append(create_instruction("BINARY_SUBSCR"))

        return instrs

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if isinstance(self.index, Source):
            return f"{self.base.name()}[{self.index.name()}]"
        else:
            if isinstance(self.index, enum.Enum):
                return f"{self.base.name()}[{enum_repr(self.index)}]"
            else:
                return f"{self.base.name()}[{self.index!r}]"


@dataclasses.dataclass
class TupleIteratorGetItemSource(GetItemSource):
    def reconstruct(self, codegen):
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        return self.base.reconstruct(codegen) + [
            codegen.create_load_const(self.index),
            create_instruction("CALL_FUNCTION", 2),
        ]

    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"


@dataclasses.dataclass
class TypeSource(Source):
    base: Source

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        codegen.load_import_from("builtins", "type")
        return self.base.reconstruct(codegen) + [create_instruction("CALL_FUNCTION", 1)]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"type({self.base.name()})"


@dataclasses.dataclass
class SuperSource(Source):
    type: Source
    obj: Source

    def __post_init__(self):
        assert self.type is not None
        assert self.obj is not None

    def reconstruct(self, codegen):
        codegen.load_import_from("builtins", "super")
        return (
            self.type.reconstruct(codegen)
            + self.obj.reconstruct(codegen)
            + [create_instruction("CALL_FUNCTION", 2)]
        )

    def guard_source(self):
        return self.obj.guard_source()

    def name(self):
        return f"super({self.type.name()}, {self.obj.name()})"


@dataclasses.dataclass
class ODictGetItemSource(Source):
    base: Source
    index: Any

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        return (
            [codegen._create_load_const(collections.OrderedDict.__getitem__)]
            + self.base.reconstruct(codegen)
            + [
                codegen.create_load_const(self.index),
                create_instruction("CALL_FUNCTION", 2),
            ]
        )
Loading ...