Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ model.py

import dataclasses
import itertools
import re

from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

from torchgen.utils import assert_never, NamespaceHelper, OrderedSet

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           DATA MODEL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Some general principles for our data model.
#
# - Stop using C++ data types as the internal data representation
#   format.  Instead, the internal data structures are centered
#   around JIT schema representation.  This avoids a big problem
#   with the old codegen where we read in all the types from
#   native_functions.yaml and then immediately had to retranslate
#   them into C++ types.
#
# - More semantic data representation.  Instead of representing
#   everything as dicts and strings, we define dataclasses for
#   every interesting entity the code generation has to deal with.
#   These dataclasses have strong semantic invariants: for example,
#   we generally require them to roundtrip losslessly into the
#   form they were parsed from.  These structures are immutable
#   and you're expected to populate information once during
#   construction.

@dataclass(frozen=True)
class Location:
    """A position (file name plus line number) in a source file.

    Used to produce better error reporting; ``str(loc)`` renders as
    ``file:line``.
    """

    file: str
    line: int

    def __str__(self) -> str:
        return f"{self.file}:{self.line}"


class Variant(Enum):
    """Valid values of the 'variants' field in native_functions.yaml."""

    # Values spelled out explicitly; they match what auto() assigns (1, 2).
    function = 1
    method = 2


# Namespace that kernels are placed in unless otherwise specified.
DEFAULT_KERNEL_NAMESPACE = "at::native"

# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = [
    "CPU",
    "CUDA",
    "HIP",
    "XLA",
    "MPS",
    "IPU",
    "XPU",
    "HPU",
    "VE",
    "Lazy",
    "Meta",
    "PrivateUse1",
    "PrivateUse2",
    "PrivateUse3",
]
# Functionality prefixes combined with every backend component to form
# per-backend dispatch key names ("" yields the bare backend name).
FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"]

# Dispatch keys that may be used in derivatives.yaml.
# For now we omit AutogradFunctionality and AutogradOther.
AUTOGRAD_KEYS = [
    "AutogradNestedTensor",
    *(f"Autograd{backend}" for backend in BACKEND_COMPONENTS),
]

FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}

# This doesn't have to be in sync with the header; it only needs to contain
# entries that we actually use in the codegen or want pyi entries for.
class DispatchKey(Enum):
    """Dispatch keys known to the code generator.

    Members are looked up by name via ``DispatchKey.parse``; ``str(key)``
    yields the member name.
    """

    Undefined = 0
    # Alias: CatchAll is the same member as Undefined.
    CatchAll = Undefined

    FPGA = auto()
    ORT = auto()
    Vulkan = auto()
    Metal = auto()
    MKLDNN = auto()
    OpenGL = auto()
    OpenCL = auto()
    IDEEP = auto()
    CustomRNGKeyId = auto()
    MkldnnCPU = auto()
    Sparse = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()

    Python = auto()
    FuncTorchDynamicLayerBackMode = auto()
    ZeroTensor = auto()
    BackendSelect = auto()
    Named = auto()
    AutogradOther = auto()
    AutogradFunctionality = auto()
    AutogradNestedTensor = auto()
    Tracer = auto()
    Autocast = auto()
    Batched = auto()
    VmapMode = auto()
    FuncTorchDynamicLayerFrontMode = auto()
    Functionalize = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()

    ADInplaceOrView = auto()
    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()

    # BEGIN autogenerated
    CPU = auto()
    CUDA = auto()
    HIP = auto()
    XLA = auto()
    MPS = auto()
    IPU = auto()
    XPU = auto()
    HPU = auto()
    VE = auto()
    Lazy = auto()
    Meta = auto()
    PrivateUse1 = auto()
    PrivateUse2 = auto()
    PrivateUse3 = auto()
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedHIP = auto()
    QuantizedXLA = auto()
    QuantizedMPS = auto()
    QuantizedIPU = auto()
    QuantizedXPU = auto()
    QuantizedHPU = auto()
    QuantizedVE = auto()
    QuantizedLazy = auto()
    QuantizedMeta = auto()
    QuantizedPrivateUse1 = auto()
    QuantizedPrivateUse2 = auto()
    QuantizedPrivateUse3 = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseHIP = auto()
    SparseXLA = auto()
    SparseMPS = auto()
    SparseIPU = auto()
    SparseXPU = auto()
    SparseHPU = auto()
    SparseVE = auto()
    SparseLazy = auto()
    SparseMeta = auto()
    SparsePrivateUse1 = auto()
    SparsePrivateUse2 = auto()
    SparsePrivateUse3 = auto()
    NestedTensorCPU = auto()
    NestedTensorCUDA = auto()
    NestedTensorHIP = auto()
    NestedTensorXLA = auto()
    NestedTensorMPS = auto()
    NestedTensorIPU = auto()
    NestedTensorXPU = auto()
    NestedTensorHPU = auto()
    NestedTensorVE = auto()
    NestedTensorLazy = auto()
    NestedTensorMeta = auto()
    NestedTensorPrivateUse1 = auto()
    NestedTensorPrivateUse2 = auto()
    NestedTensorPrivateUse3 = auto()
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradHIP = auto()
    AutogradXLA = auto()
    AutogradMPS = auto()
    AutogradIPU = auto()
    AutogradXPU = auto()
    AutogradHPU = auto()
    AutogradVE = auto()
    AutogradLazy = auto()
    AutogradMeta = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()
    # END autogenerated

    def __str__(self) -> str:
        return self.name

    def lower(self) -> str:
        """Return the member name in lowercase (e.g. ``"cpu"``)."""
        return str(self).lower()

    @staticmethod
    def parse(value: str) -> "DispatchKey":
        """Return the member whose name is exactly ``value``.

        Aliases are honoured (``"CatchAll"`` resolves to ``Undefined``)
        because ``__members__`` includes them.

        Raises:
            AssertionError: if ``value`` is not the name of any member.
        """
        # Direct mapping lookup instead of a linear scan over __members__.
        try:
            return DispatchKey.__members__[value]
        except KeyError:
            raise AssertionError(f"unknown dispatch key {value}") from None


def codegen_per_backend_entries() -> str:
    """Render the per-backend ``DispatchKey`` member lines that must exist.

    Emits one ``    <functionality><backend> = auto()`` line for every
    combination of FUNCTIONALITY_KEYS and BACKEND_COMPONENTS, iterating
    functionality keys in the outer position, joined by newlines.  The
    format mirrors the "autogenerated" section of the DispatchKey enum.
    """
    return "\n".join(
        f"    {functionality}{backend} = auto()"
        for functionality in FUNCTIONALITY_KEYS
        for backend in BACKEND_COMPONENTS
    )


# Import-time sanity check: every <functionality><backend> combination must be
# a DispatchKey member.  On the first missing one, print the expected
# autogenerated member list and abort with a RuntimeError that repeats it.
for fk in FUNCTIONALITY_KEYS:
    for bc in BACKEND_COMPONENTS:
        if hasattr(DispatchKey, fk + bc):
            continue
        r = codegen_per_backend_entries()
        print(r)
        raise RuntimeError(
            f"Missing {fk}{bc} from DispatchKey enum.  Here is the autogenerated list we expect to have:\n\n{r}"
        )


# Dispatch keys for which structured kernels are generated; consulted by
# is_structured_dispatch_key.
STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU}
# Dispatch keys for which ufunc kernels are generated; consulted by
# is_ufunc_dispatch_key.  Maintained separately from STRUCTURED_DISPATCH_KEYS
# (at the time of writing it omits MPS).
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}

# Supported dispatch keys, kept as an ordered list.
# NOTE(review): the order may influence the order of generated output --
# confirm before reordering.
dispatch_keys = [
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    # Meta is a magic key: it is automatically generated for structured
    # kernels
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]

# Dispatch keys that "support all backends".  These codegen slightly differently
# than backend-specific keys.  Built once at import time (like
# STRUCTURED_DISPATCH_KEYS) rather than on every call.
_GENERIC_DISPATCH_KEYS = frozenset(
    [
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    ]
)


def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is one of the backend-agnostic Composite* keys."""
    return dk in _GENERIC_DISPATCH_KEYS


# CUDA-specific dispatch keys.  Built once at import time (like
# STRUCTURED_DISPATCH_KEYS) rather than on every call.
_CUDA_DISPATCH_KEYS = frozenset(
    [
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.AutogradCUDA,
    ]
)


def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is one of the CUDA-specific dispatch keys."""
    return dk in _CUDA_DISPATCH_KEYS


def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if structured kernels can be generated for ``dk``.

    Structured kernel generation is only supported for the keys in
    STRUCTURED_DISPATCH_KEYS; every other key falls back to old-style codegen.
    """
    supported_keys = STRUCTURED_DISPATCH_KEYS
    return dk in supported_keys


def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ufunc kernels are generated for ``dk``.

    Membership is tested against UFUNC_DISPATCH_KEYS, which is maintained
    separately from STRUCTURED_DISPATCH_KEYS; the two sets are not identical
    (at the time of writing the ufunc set omits MPS).
    """
    ufunc_keys = UFUNC_DISPATCH_KEYS
    return dk in ufunc_keys


# This is oddly named ScalarType and not DType for symmetry with C++
class ScalarType(Enum):
    """Scalar dtypes that may appear in native_functions.yaml.

    ``str(dtype)`` yields the member name.
    """

    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> Optional["ScalarType"]:
        """Return the member named exactly ``value``, or None if there is none."""
        return ScalarType.__members__.get(value)

    @staticmethod
    def parse(value: str) -> "ScalarType":
        """Return the member named exactly ``value``.

        Raises:
            AssertionError: if ``value`` is not a known dtype name.  Raised
                explicitly rather than via an ``assert`` statement so the
                check is not stripped under ``python -O`` (which would make
                this silently return None).
        """
        mb_r = ScalarType.maybe_parse(value)
        if mb_r is None:
            raise AssertionError(f"unknown dtype {value}")
        return mb_r

    @staticmethod
    def parse_set(values: str) -> OrderedSet["ScalarType"]:
        """Parse a ``", "``-separated list of dtype names and dtype classes.

        Each entry is either a key of DTYPE_CLASSES (expanded to every dtype
        in that class) or a single ScalarType name.  Results are accumulated
        into an OrderedSet in the order the entries appear.

        Raises:
            AssertionError: if an entry is neither a dtype class nor a dtype.
        """
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for value in values.split(", "):
            if value in DTYPE_CLASSES:
                dtypes.update(DTYPE_CLASSES[value])
            else:
                dtypes.add(ScalarType.parse(value))
        return dtypes


# Named groups of dtypes accepted by ScalarType.parse_set.  The base classes
# are defined directly; the derived classes below are unions of them.
DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {
    # NB: Integral doesn't include boolean
    "Integral": OrderedSet(
        [
            ScalarType.Byte,
            ScalarType.Char,
            ScalarType.Int,
            ScalarType.Long,
            ScalarType.Short,
        ]
    ),
    # NB: Floating doesn't include low precision types
    "Floating": OrderedSet([ScalarType.Float, ScalarType.Double]),
    "Complex": OrderedSet([ScalarType.ComplexFloat, ScalarType.ComplexDouble]),
}
# "All" is Integral plus Float/Double only: no Bool, Half, BFloat16 or complex.
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = (
    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
)


# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
Loading ...