from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from torchgen import dest
# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature # isort:skip
from torchgen.context import method_with_native_function
from torchgen.model import BackendIndex, DispatchKey, NativeFunction, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import concatMap, Target
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
# model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
if Variant.function not in f.variants:
return None
sig = DispatcherSignature.from_schema(
f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
)
assert sig is not None
if len(f.func.returns) == 0:
ret_name = ""
elif len(f.func.returns) == 1:
if f.func.arguments.out:
ret_name = f.func.arguments.out[0].name
else:
ret_name = next(
(
a.name
for a in f.func.arguments.flat_non_out
if a.type == f.func.returns[0].type
),
"",
)
if not ret_name:
raise Exception(f"Can't handle this return type {f.func}")
else:
assert len(f.func.arguments.out) == len(f.func.returns), (
"Out variant number of returns need to match the number of out arguments."
f" Got outs {str(f.func.arguments.out)} but returns {str(f.func.returns)}"
)
# returns a tuple of out arguments
tensor_type = "at::Tensor &"
comma = ", "
ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
{comma.join([r.name for r in f.func.arguments.out])}
)"""
ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
return f"""
{sig.defn()} {{
{ret_str}
}}
"""
def gen_custom_ops_registration(
*,
native_functions: Sequence[NativeFunction],
selector: SelectiveBuilder,
backend_index: BackendIndex,
rocm: bool,
) -> Tuple[str, str]:
"""
Generate custom ops registration code for dest.RegisterDispatchKey.
:param native_functions: a sequence of `NativeFunction`
:param selector: for selective build.
:param backend_index: kernels for all the ops.
:param rocm: bool for dest.RegisterDispatchKey.
:return: generated C++ code to register custom operators into PyTorch
"""
dispatch_key = DispatchKey.CPU
static_init_dispatch_registrations = ""
ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
for native_function in native_functions:
ns_grouped_native_functions[native_function.namespace].append(native_function)
for namespace, functions in ns_grouped_native_functions.items():
if len(functions) == 0:
continue
dispatch_registrations_body = "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.REGISTRATION,
selector,
rocm=rocm,
symint=False,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
functions,
)
)
)
static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
anonymous_definition = "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
symint=False,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
native_functions,
)
)
)
return anonymous_definition, static_init_dispatch_registrations