import itertools
import textwrap
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing_extensions import Literal # Python 3.8+
import torchgen.api.cpp as cpp
import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
ConstRefCType,
CppSignature,
CppSignatureGroup,
DispatcherSignature,
Expr,
kernel_signature,
MutRefCType,
NamedCType,
NativeSignature,
tensorT,
)
from torchgen.context import method_with_native_function, native_function_manager
from torchgen.model import (
Argument,
BackendIndex,
DeviceCheckType,
DispatchKey,
gets_generated_out_inplace_wrapper,
is_cuda_dispatch_key,
NativeFunction,
NativeFunctionsGroup,
SchemaKind,
TensorOptionsArguments,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import assert_never, mapMaybe, Target
def gen_registration_headers(
backend_index: BackendIndex,
per_operator_headers: bool,
rocm: bool,
) -> List[str]:
if per_operator_headers:
headers = ["#include <ATen/ops/as_strided_native.h>"]
else:
headers = ["#include <ATen/NativeFunctions.h>"]
if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
headers.append("#include <ATen/EmptyTensor.h>")
elif backend_index.dispatch_key == DispatchKey.CUDA:
if rocm:
headers.append("#include <ATen/hip/EmptyTensor.h>")
else:
headers.append("#include <ATen/cuda/EmptyTensor.h>")
elif backend_index.dispatch_key == DispatchKey.MPS:
headers.append("#include <ATen/mps/EmptyTensor.h>")
elif per_operator_headers:
headers += [
"#include <ATen/ops/empty.h>",
"#include <ATen/ops/empty_strided.h>",
"#include <ATen/ops/_copy_from_and_resize.h>",
"#include <ATen/ops/_copy_from.h>",
]
else:
headers.append("#include <ATen/Functions.h>")
return headers
def gen_empty_impl_names(
backend_index: BackendIndex,
) -> Tuple[Optional[str], Optional[str]]:
empty_impl = None
empty_strided_impl = None
if backend_index.dispatch_key in (
DispatchKey.Meta,
DispatchKey.CPU,
DispatchKey.CUDA,
DispatchKey.MPS,
):
dispatch = str(backend_index.dispatch_key).lower()
empty_impl = f"at::detail::empty_{dispatch}"
empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
elif backend_index.dispatch_key in (
DispatchKey.CompositeExplicitAutogradNonFunctional,
DispatchKey.QuantizedCPU,
DispatchKey.QuantizedCUDA,
):
empty_impl = "at::empty"
empty_strided_impl = "at::empty_strided"
return empty_impl, empty_strided_impl
def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
if backend_index.dispatch_key == DispatchKey.Meta:
empty_options = "options.device(at::kMeta)"
else:
empty_options = "options"
empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
if empty_impl is None:
return []
return [
f"""
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
if (strides.empty()) {{
return {empty_impl}(sizes, {empty_options});
}} else {{
return {empty_strided_impl}(sizes, strides, {empty_options});
}}
}}
"""
]
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
_, empty_strided_impl = gen_empty_impl_names(backend_index)
return (
[]
if empty_strided_impl is None
else [
f"""
c10::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
if (out.strides() != strides) {{
return {empty_strided_impl}(sizes, strides, options);
}}
return c10::nullopt;
}}
"""
]
)
def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
        # This function isn't used by this key (since only functional ops have a kernel
        # for this key), so we omit it to avoid a defined-but-not-used error.
return []
return [
"""
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
TORCH_CHECK(options.dtype() == out.dtype(),
"Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
TORCH_CHECK(options.device() == out.device(),
"Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
const bool resized = at::native::resize_output(out, sizes);
// Only restride if a resize occurred; otherwise we ignore the (advisory)
// strides from the meta function and directly use the output tensor's
// preexisting strides
if (resized) {
if (!strides.empty()) {
TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
// TODO: avoid the redispatch here
out.as_strided_(sizes, strides);
} else if (options.memory_format_opt().has_value()) {
out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
}
}
}
"""
]
def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
return [
"""
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
// These checks are needed on those operators that:
// 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
// 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
// For other operators (e.g. 'add'), 'TensorIterator' already checks
// these things separately.
TORCH_CHECK(options.dtype() == self.dtype(),
"Bad in-place call: ",
"input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
TORCH_CHECK(options.device() == self.device(),
"Bad in-place call: ",
"input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
TORCH_CHECK(sizes == self.sizes(),
"Bad in-place call: ",
"input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
"""
]
def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
return [
*gen_create_out_helper(backend_index),
*gen_resize_out_helper(backend_index),
*gen_check_inplace_helper(backend_index),
*gen_maybe_create_proxy_helper(backend_index),
]
# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
#
# - The primary function of this file is to register all of the
# implementations for the given dispatch key to the dispatcher,
# so they are available for use in PyTorch. If dispatch is
# None, we generate schema (def) registrations and catchall
# registrations.
# - The secondary function of this file is to generate a wrapper
# around functions. In CPUType these wrappers do nothing
# (and should be removed), but in other cases they handle
# DeviceGuard. A small extra benefit of wrappers is they
# are not overloaded, so they can be used in the registration
# API without having to disambiguate which overload you want
# (as would be the case if you directly registered native::
# functions).
# - The tertiary function of this file is to generate *static*
# cpp API bindings which can be used to bypass the dispatcher
# and call kernels directly, but with a user-friendly cpp-style API.
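#
# For orientation, here is a hedged sketch of roughly what the generated file
# contains for the CPU key (the operator, wrapper names, and kernel call are
# illustrative; real wrapper names come from wrapper_kernel_sig and the
# dispatch key, and structured ops expand differently):
#
#   namespace at {
#   namespace {
#   at::Tensor wrapper_CPU_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
#     // device guard / device check boilerplate, then call into the backend kernel
#     return at::native::add(self, other, alpha);
#   }
#   } // anonymous namespace
#
#   TORCH_LIBRARY_IMPL(aten, CPU, m) {
#     m.impl("add.Tensor", TORCH_FN(wrapper_CPU_add_Tensor));
#   }
#   } // namespace at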
@dataclass(frozen=True)
class RegisterDispatchKey:
backend_index: BackendIndex
target: Union[
Literal[Target.ANONYMOUS_DEFINITION],
Literal[Target.NAMESPACED_DEFINITION],
Literal[Target.NAMESPACED_DECLARATION],
Literal[Target.REGISTRATION],
]
# Selector object to determine which operators to generate
# registration code for.
selector: SelectiveBuilder
# Whether or not we are actually code-genning for ROCm
rocm: bool
    # Whether or not to generate symint registrations. External users of codegen
    # who don't care about symints can set this to false to get non-SymInt codegen.
symint: bool
# The class that all unstructured native functions live under. This is used to improve
# compiler error messages when a kernel writer adds a native function with the wrong signature.
# This is only used in unstructured kernels, since structured kernels already live in a class.
# Finally, this field is currently Optional because it is only used by external backends.
    # It would be nice if we could add the same logic to in-tree kernels too, but that requires updating
# all of the existing kernel signatures scattered across aten/src/ATen/native.
class_method_name: Optional[str]
    # Only set to true in lightweight dispatch. If lightweight dispatch is enabled,
    # we register operators into the JIT op registry, so we need to avoid generating
    # code that registers them into the dispatcher.
skip_dispatcher_op_registration: bool
@staticmethod
def gen_device_check(
type: DeviceCheckType, args: List[Argument], method_name: str
) -> str:
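        # Emits a C++ snippet that checks that all tensor-like arguments are on a
        # common device. For an op with tensor arguments `self` and `other`, the
        # emitted code looks roughly like this (the method name is illustrative):
        #
        #   c10::optional<Device> common_device = nullopt;
        #   (void)common_device; // Suppress unused variable warning
        #   c10::impl::check_and_update_common_device(common_device, self, "wrapper_CPU_add_Tensor", "self");
        #   c10::impl::check_and_update_common_device(common_device, other, "wrapper_CPU_add_Tensor", "other");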
if type == DeviceCheckType.NoCheck:
return " // No device check\n"
device_check = "c10::optional<Device> common_device = nullopt;\n"
device_check += "(void)common_device; // Suppress unused variable warning\n"
for arg in args:
            # Only tensor-like arguments are eligible
if arg.type.is_tensor_like():
device_check += f"""
c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
return device_check
@method_with_native_function
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
if isinstance(f, NativeFunctionsGroup):
g: NativeFunctionsGroup = f
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
# gen_structured() has special logic to handle auto-generated kernels.
if g.structured:
return self.gen_structured(g)
else:
return list(
mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
)
elif isinstance(f, NativeFunction):
r = self.gen_unstructured(f)
return [] if r is None else [r]
else:
assert_never(f)
def wrapper_kernel_sig(
self, f: NativeFunction
) -> Union[NativeSignature, DispatcherSignature]:
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
return DispatcherSignature.from_schema(
f.func,
prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
symint=self.symint,
)
def gen_out_inplace_wrapper(
self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
) -> Optional[str]:
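        # Roughly: when a backend provides a kernel only for the functional variant
        # of an op, generate an out/inplace wrapper that calls the functional kernel
        # and copies the result back into the output (or self) tensor. A hedged
        # sketch of the generated wrapper for an inplace op (names are illustrative):
        #
        #   at::Tensor & wrapper_XLA_add__Tensor(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
        #     auto wrapper_XLA_add__Tensor_tmp = wrapper_XLA_add_Tensor(self, other, alpha);
        #     at::_copy_from(wrapper_XLA_add__Tensor_tmp, self);
        #     return self;
        #   }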
if g is None:
return None
k = f.func.kind()
if k is SchemaKind.inplace:
copy_op = "at::_copy_from"
elif k is SchemaKind.out:
copy_op = "at::_copy_from_and_resize"
else:
raise AssertionError("gen_out_inplace_wrapper called on a functional op")
sig = self.wrapper_kernel_sig(f)
name = sig.name()
func_res = f"{name}_tmp"
return_names = cpp.return_names(f)
if len(return_names) > 1:
updates = "\n ".join(
f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
for i, ret_name in enumerate(return_names)
)
returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
else:
ret_name = return_names[0]
updates = f"{copy_op}({func_res}, {ret_name});"
returns = ret_name
functional_sig = self.wrapper_kernel_sig(g.functional)
wrapper_name = sig.name()
return f"""\
{sig.defn(name=wrapper_name)} {{
auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
{updates}
return {returns};
}}
"""
def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
metadata = self.backend_index.get_kernel(g)
if self.backend_index.dispatch_key == DispatchKey.Meta:
assert not self.backend_index.has_kernel(g.out), (