import argparse
import os
import pathlib
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
import yaml
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
from torchgen import dest
from torchgen.api import cpp as aten_cpp
from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
from torchgen.context import method_with_native_function, with_native_function_and_index
from torchgen.executorch.api import et_cpp
from torchgen.executorch.api.custom_ops import (
ComputeNativeFunctionStub,
gen_custom_ops_registration,
)
from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.executorch.api.unboxing import Unboxing
from torchgen.gen import (
get_custom_build_selector,
get_native_function_declarations,
get_native_function_schema_registrations,
LineLoader,
parse_native_yaml,
ParsedYaml,
)
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
is_cuda_dispatch_key,
Location,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
Variant,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
context,
FileManager,
make_file_manager,
mapMaybe,
NamespaceHelper,
)
def static_dispatch(
    sig: Union[CppSignature, ExecutorchCppSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
    native function exists, emit a runtime assertion instead of a call. A simplified version of register_dispatch_key.py

    Arguments:
        sig: A CppSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);".
        An empty string when there are no backends at all or when `f` is registered manually.
    """
    if not backend_indices or f.manual_kernel_registration:
        return ""

    backends = [b for b in backend_indices if b.has_kernel(f)]
    # Only an unambiguous match is dispatched; anything else falls through to the assertion below.
    backend_metadata = backends[0].get_kernel(f) if len(backends) == 1 else None

    if backend_metadata is not None:
        args = ", ".join(a.name for a in sig.arguments())
        # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch.
        static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});"
    else:
        # Also taken when the single matching backend yields no kernel metadata, so that the
        # generated function body never contains the text "None".
        static_block = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}.");
"""
    return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
{static_block}
}}
"""
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    """
    Callable that turns one `NativeFunction` into the C++ text of its inline
    functional wrapper, or `None` when no wrapper should be emitted.
    """

    # Backends handed to `static_dispatch` when a wrapper is statically dispatched.
    static_dispatch_backend_indices: List[BackendIndex]
    # Operators that are not root operators according to this selector are skipped.
    selector: SelectiveBuilder
    # True: use the most faithful signature from CppSignatureGroup and forward to `at::`.
    # False: use ExecutorchCppSignature and always dispatch statically.
    use_aten_lib: bool
    # Predicate consulted only when `use_aten_lib` is set; ops it accepts are
    # statically dispatched instead of being forwarded to `at::`.
    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the wrapper definition for `f`, or None if `f` is filtered out."""
        qualified_name = f"{f.namespace}::{f.func.name}"
        if not self.selector.is_root_operator(qualified_name):
            return None
        if Variant.function not in f.variants:
            return None

        sig: Union[CppSignature, ExecutorchCppSignature]
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
        else:
            sig = ExecutorchCppSignature.from_native_function(f)

        # Everything except a non-custom op in ATen mode goes through static dispatch.
        if not self.use_aten_lib or self.is_custom_op(f):
            return static_dispatch(
                sig,
                f,
                backend_indices=self.static_dispatch_backend_indices,
            )

        forwarded_args = ", ".join(arg.name for arg in sig.arguments())
        return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
return at::{sig.name()}({forwarded_args});
}}
"""
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    """
    Callable that renders the `Operator(...)` registration entry for one
    `NativeFunction`: a lambda that unboxes the `EValue** stack`, calls the
    `torch::executor::<namespace>::<name>` function and stores the result.
    """

    # Operators that are not root operators according to this selector yield "".
    selector: SelectiveBuilder
    # Chooses between the ATen and the Executorch signature / type generators.
    use_aten_lib: bool

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """
        Return the C++ registration snippet for `f`, or "" if `f` is not selected.

        Raises:
            Exception: if `f` has neither returns nor out arguments.
        """
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return ""

        sig: Union[CppSignature, ExecutorchCppSignature]
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        if not self.use_aten_lib:
            sig = ExecutorchCppSignature.from_native_function(f)
            argument_type_gen = et_cpp.argumenttype_type
            return_type_gen = et_cpp.returns_type
        else:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            argument_type_gen = aten_cpp.argumenttype_type
            return_type_gen = aten_cpp.returns_type

        # parse arguments into C++ code: one binding plus conversion code per argument
        unboxer = Unboxing(argument_type_gen=argument_type_gen)
        binding_list, code_list = unboxer.convert_arguments(sig.arguments())
        args_str = ", ".join(binding.name for binding in binding_list)

        # The result is written to the stack slot right after the unboxed arguments.
        result_slot = len(binding_list)
        out_args = f.func.arguments.out

        # Default: the op has returns and out arguments, so nothing extra is emitted.
        ret_prefix = ""
        return_assignment = ""
        if not f.func.returns:
            if not out_args:
                raise Exception(
                    f"Can't handle native function {f.func} with no returns and no out yet."
                )
            return_assignment = f"""stack[{result_slot}] = &{out_args[0].name};"""
        elif not out_args:
            return_assignment = f"""*stack[{result_slot}] = EValue(result_);"""
            ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "

        unboxing_code = "\n\t".join(code_list)
        return f"""
Operator(
"{f.namespace}::{f.func.name}",
[](EValue** stack) {{
{unboxing_code}
EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
{ret_prefix}torch::executor::{f.namespace}::{sig.name()}({args_str});
{return_assignment}
}}
),
"""
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> None:
    """
    Write RegisterCodegenUnboxedKernels.cpp through `cpu_fm` as a single shard.

    Each native function contributes one entry to the "unboxed_ops" key, produced
    by `ComputeCodegenUnboxedKernels(selector, use_aten_lib)`.
    """
    kernel_gen = ComputeCodegenUnboxedKernels(selector, use_aten_lib)

    def shard_key(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    def shard_env(fn: NativeFunction) -> Dict[str, List[str]]:
        return {"unboxed_ops": [kernel_gen(fn)]}

    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=shard_key,
        env_callable=shard_env,
        num_shards=1,
        sharded_keys={"unboxed_ops"},
    )
@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    """
    Return the C++ declaration of the kernel that `backend_index` binds to `g`.

    The result is a one-element list, or an empty list when the backend has no
    kernel for `g`. Only plain `NativeFunction`s are accepted (asserted).
    The declaration is prefixed with `static` for external backends and with
    `TORCH_API` otherwise.
    """
    assert isinstance(g, NativeFunction)
    et_sig = ExecutorchCppSignature.from_native_function(f=g)
    kernel_meta = backend_index.get_kernel(g)
    if kernel_meta is None:
        return []

    if backend_index.external:
        linkage = "static"
    else:
        linkage = "TORCH_API"
    declaration = f"{linkage} {et_sig.decl(name=kernel_meta.kernel)};"
    return [declaration]
def gen_functions_declarations(
    *,
    native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None,
) -> str:
    """
    Generates namespace separated C++ function API inline declaration/definitions.
    Native functions are grouped by namespaces and the generated code is wrapped inside
    namespace blocks.

    E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol
    in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when
    the other `custom_2::foo.out` is available.

    Arguments:
        native_functions: Functions to emit; namespaces appear in first-seen order.
        static_dispatch_idx: Backends forwarded to `ComputeFunction` for static dispatch.
        selector: Forwarded to `ComputeFunction` to filter operators.
        use_aten_lib: Forwarded to `ComputeFunction`.
        custom_ops_native_functions: Functions treated as custom ops; `None` means none.
    Return:
        The concatenated namespace blocks, or "" when `native_functions` is empty.
    """
    if not native_functions:
        # No namespaces means no blocks; skip building the generator entirely.
        return ""

    ns_grouped_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_functions[native_function.namespace].append(native_function)

    def is_custom_op(f: NativeFunction) -> bool:
        return (
            custom_ops_native_functions is not None
            and f in custom_ops_native_functions
        )

    # The generator does not depend on the namespace, so build it once.
    compute_function = ComputeFunction(
        static_dispatch_backend_indices=static_dispatch_idx,
        selector=selector,
        use_aten_lib=use_aten_lib,
        is_custom_op=is_custom_op,
    )

    namespace_blocks: List[str] = []
    for namespace, functions in ns_grouped_functions.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        # mapMaybe drops the functions for which the generator returned None.
        declarations = "\n".join(mapMaybe(compute_function, functions))
        namespace_blocks.append(
            f"""
{ns_helper.prologue}
{declarations}
{ns_helper.epilogue}
"""
        )
    return "".join(namespace_blocks)
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    custom_ops_native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    """
    Write the generated headers through `cpu_fm`, in this order:

    1. CustomOpsNativeFunctions.h (from the NativeFunctions.h template), only when
       `custom_ops_native_functions` is non-empty.
    2. Functions.h, with the inline function API from `gen_functions_declarations`.
    3. NativeFunctions.h, with the kernel declarations for `native_functions`.

    The template environments are built lazily inside the callables handed to `cpu_fm`.
    """
    aten_headers = ["#include <ATen/Functions.h>"]

    if custom_ops_native_functions:

        def custom_ops_env() -> Dict[str, object]:
            return {
                "nativeFunctions_declarations": get_native_function_declarations(
                    grouped_native_functions=custom_ops_native_functions,
                    backend_indices=backend_indices,
                    native_function_decl_gen=dest.compute_native_function_declaration,
                ),
            }

        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
            custom_ops_env,
        )
        # Functions.h must also see the custom op kernels when it uses the ATen headers.
        aten_headers.append('#include "CustomOpsNativeFunctions.h"')

    def functions_env() -> Dict[str, object]:
        if use_aten_lib:
            extra_headers = aten_headers
        else:
            extra_headers = ['#include "NativeFunctions.h"']
        return {
            "static_dispatch_extra_headers": extra_headers,
            "Functions_declarations": gen_functions_declarations(
                native_functions=native_functions,
                static_dispatch_idx=static_dispatch_idx,
                selector=selector,
                use_aten_lib=use_aten_lib,
                custom_ops_native_functions=custom_ops_native_functions,
            ),
        }

    cpu_fm.write("Functions.h", functions_env)

    def native_functions_env() -> Dict[str, object]:
        # ATen mode reuses torchgen's generator; otherwise use the one defined in this file.
        if use_aten_lib:
            decl_gen = dest.compute_native_function_declaration
        else:
            decl_gen = compute_native_function_declaration
        return {
            "nativeFunctions_declarations": get_native_function_declarations(
                grouped_native_functions=native_functions,
                backend_indices=backend_indices,
                native_function_decl_gen=decl_gen,
            ),
        }

    cpu_fm.write("NativeFunctions.h", native_functions_env)
def gen_custom_ops(
*,
native_functions: Sequence[NativeFunction],
selector: SelectiveBuilder,
backend_indices: Dict[DispatchKey, BackendIndex],
cpu_fm: FileManager,
rocm: bool,
) -> None:
dispatch_key = DispatchKey.CPU
Loading ...