Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ api / ufunc.py

from dataclasses import dataclass
from typing import List, Optional

import torchgen.api.types as api_types

from torchgen.api import cpp, structured
from torchgen.api.types import (
    ArgName,
    BaseCppType,
    BaseCType,
    Binding,
    ConstRefCType,
    CType,
    NamedCType,
    scalarT,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    DispatchKey,
    FunctionSchema,
    NativeFunctionsGroup,
    Type,
)


def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Return the generated ufunc kernel name for an out= schema.

    The name has the form ``ufunc_<operator name>_<dispatch key>``.

    Raises:
        AssertionError: if ``func`` is not an out= schema.
    """
    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
    op_name = func.name.name
    return f"ufunc_{op_name}_{dispatch_key}"


def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for the out= variant of group ``g``."""
    out_schema = g.out.func
    return schema_kernel_name(out_schema, dispatch_key)


# Tensors are omitted (they are stored in the TensorIterator); everything else is
# passed along. (Technically we could pass tensors along too; it would just
# waste argument registers.)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    """Return the C++ type used for ``t`` in a CPU DispatchStub signature.

    Value types are translated by ``cpp.valuetype_type`` (with SymInt disabled).
    A ``Scalar`` becomes a const reference to ``scalarT``. A ``Tensor`` yields
    ``None`` because tensors are not passed to the stub.

    Raises:
        AssertionError: for any other type.
    """
    # Dispatch stubs only ever take plain ints, never SymInts.
    value_type = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_type is not None:
        return value_type

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if t == BaseType(BaseTy.Tensor):
        # Omitted: the tensor is reached through the TensorIterator instead.
        return None
    raise AssertionError(f"unrecognized type {repr(t)}")


def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    """Map a scalar C++ type to the type arithmetic is carried out in.

    Only ``api_types.scalar_t`` is supported. It maps to ``api_types.opmath_t``.

    Raises:
        NotImplementedError: for any other ``scalar_t``. The message names the
            offending type so codegen failures are diagnosable.
    """
    if scalar_t == api_types.scalar_t:
        return api_types.opmath_t
    raise NotImplementedError(f"opmath_type is not implemented for {scalar_t!r}")


# NB: Tensors in the constructor are stored as opmath_t, not scalar_t, because a
# tensor in the constructor is a partially applied scalar tensor. It can be held
# in higher precision, and we want to compute in that higher precision.
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    """Return the C++ type of ``t`` as a ufunctor constructor parameter (CUDA).

    Value types are translated by ``cpp.valuetype_type`` (with SymInt disabled).
    Both ``Scalar`` and ``Tensor`` arguments are stored as
    ``opmath_type(scalar_t)``.

    Raises:
        AssertionError: for any other type.
        NotImplementedError: propagated from ``opmath_type`` for an unsupported
            ``scalar_t``.
    """
    plain = cpp.valuetype_type(t, binds=binds, symint=False)
    if plain is not None:
        return plain

    # Scalars and partially applied scalar tensors share the same storage type.
    is_scalar = t == BaseType(BaseTy.Scalar)
    if is_scalar or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    raise AssertionError(f"unrecognized type {repr(t)}")


# Only Tensors ever get passed directly to operator()
#
# NB: CUDA only
# (Actually, this works for CPU too)
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    """Return the C++ type of ``t`` as a parameter of the ufunctor's operator().

    Only ``Tensor`` arguments are accepted. They are passed as ``scalar_t``.

    Raises:
        AssertionError: for any non-Tensor type.
    """
    if not t == BaseType(BaseTy.Tensor):
        raise AssertionError(f"unrecognized type {repr(t)}")
    return NamedCType(binds, BaseCType(scalar_t))


# The actual ufunc template function the user writes.  Everything here
# is done in the computation type.  compute_t is opmath_t in CUDA and scalar_t
# in CPU
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """Return the C++ type of ``t`` as a parameter of the user-written ufunc.

    Value types are translated by ``cpp.valuetype_type`` (with SymInt disabled).
    ``Scalar`` and ``Tensor`` arguments both take the caller-supplied
    ``compute_t``.

    Raises:
        AssertionError: for any other type.
    """
    plain = cpp.valuetype_type(t, binds=binds, symint=False)
    if plain is not None:
        return plain

    is_scalar = t == BaseType(BaseTy.Scalar)
    if is_scalar or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    raise AssertionError(f"unrecognized type {repr(t)}")


def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Build the ufunctor-constructor binding for argument ``a`` (no default)."""
    ctor_nctype = ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(name=a.name, nctype=ctor_nctype, argument=a, default=None)


def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Build the ufunctor operator() binding for argument ``a`` (no default)."""
    apply_nctype = ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(name=a.name, nctype=apply_nctype, argument=a, default=None)


def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    """Build the binding for argument ``a`` of the user-written ufunc (no default)."""
    ufunc_nctype = ufunc_type(a.type, binds=a.name, compute_t=compute_t)
    return Binding(name=a.name, nctype=ufunc_nctype, argument=a, default=None)


@dataclass(frozen=True)
class UfunctorBindings:
    """Bindings of a ufunctor, split between its constructor and operator().

    Produced by ``ufunctor_arguments``. The field order defines the generated
    positional ``__init__``, so it must stay stable.
    """

    # Parameters taken by the host-side constructor.
    ctor: List[Binding]
    # Parameters taken by operator() on each application.
    apply: List[Binding]


# ufunctors are a CUDA-only concept representing functors that take some of
# their arguments on a host-side constructor, and the rest in the device-side
# apply.  E.g.,
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
#   using opmath_t = at::opmath_type<scalar_t>;
#   opmath_t other_;
#   opmath_t alpha_;
#   CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
#   __device__ scalar_t operator()(scalar_t self) {
#     return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#   }
# };
#
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
) -> UfunctorBindings:
    """Split the functional variant's non-out arguments into ctor/apply bindings.

    Non-tensor-like arguments always go to the constructor. Tensor-like
    arguments go to operator(), except the one at 0-based position
    ``scalar_tensor_idx`` among the tensor-like arguments, which is moved into
    the constructor instead.

    Raises:
        AssertionError: if ``scalar_tensor_idx`` is not None and does not select
            one of the tensor-like arguments.
    """
    ctor_bindings: List[Binding] = []
    apply_bindings: List[Binding] = []
    # Countdown to the tensor that belongs in the ctor. It becomes None once
    # that tensor has been placed (or if no tensor is to be moved).
    tensors_until_scalar = scalar_tensor_idx

    for arg in g.functional.func.arguments.flat_non_out:
        if not arg.type.is_tensor_like():
            ctor_bindings.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            continue
        if tensors_until_scalar == 0:
            # Put the selected tensor in the ctor anyway.
            ctor_bindings.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            tensors_until_scalar = None
            continue
        if tensors_until_scalar is not None:
            tensors_until_scalar -= 1
        apply_bindings.append(ufunctor_apply_argument(arg, scalar_t=scalar_t))

    assert tensors_until_scalar is None
    return UfunctorBindings(ctor=ctor_bindings, apply=apply_bindings)


# ufuncs are the inner loop template functions that you wrote in ufunc/add.h
# which do the actual computation in question.  E.g.,
#
# template <typename T>
# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
#   return self + alpha * other;
# }
#
# In this file, we refer to T as compute_t which is bound by caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
    """Return one ufunc binding per non-out argument of the functional variant."""
    non_out_args = g.functional.func.arguments.flat_non_out
    return [ufunc_argument(arg, compute_t=compute_t) for arg in non_out_args]


# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
# vectorized versions.  E.g.,
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Return the DispatchStub bindings for the out= variant of group ``g``.

    Every non-out, non-tensor-like argument contributes the bindings produced
    by ``structured.argument``. Tensor-like arguments are skipped.
    """
    bindings: List[Binding] = []
    for arg in g.out.func.arguments.flat_non_out:
        if arg.type.is_tensor_like():
            # Tensors are implicit in the TensorIterator argument.
            continue
        bindings.extend(structured.argument(arg))
    return bindings