from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple
from torchgen.api.types import Binding, CType, NamedCType
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Type,
)
# Separator (newline + tab) used to join the per-element unboxing statements
# when they are embedded in the body of a generated C++ loop.
connector = "\n\t"
def name(f: NativeFunction) -> str:
    """Return the unboxing function name for a NativeFunction."""
    operator_name = f.func.name
    return operator_name.unambiguous_name()
@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues into these Bindings.
    Returns generated C++ code that performs the unboxing.

    A sample of the generated code:

    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);
    }
    """

    # A callable that converts a JIT argument into its C++ type.
    # It is invoked as `argument_type_gen(t, mutable=..., binds=...)` and must
    # return a NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[..., NamedCType]

    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> Tuple[List[Binding], List[str]]:
        """
        Convert all the arguments of a NativeFunction to C++ unboxing code.

        :param args: the bindings to unbox, in stack order
        :return: a tuple of (bindings renamed to their unboxed variables,
            generated C++ statements)
        :raises Exception: if a binding does not wrap an `Argument`
        """
        # Bind every stack slot to an EValue reference named after its argument.
        code_list = [
            f"EValue& {arg.name} = *stack[{index}];" for index, arg in enumerate(args)
        ]
        binding_list = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Declarations are emitted ahead of the statements that use them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> Tuple[str, CType, List[str], List[str]]:
        """
        Generate the C++ code that unboxes a single argument from its EValue.

        :param t: a `Type` of an argument
        :param arg_name: argument name
        :param mutable: boolean for whether this argument type is mutable
        :return: a tuple of
            (1) the name of the newly created unboxed variable,
            (2) the CType produced by `argument_type_gen` for the argument,
            (3) the C++ statements that perform the unboxing,
            (4) C++ declarations that callers emit before the statements in (3)
        :raises Exception: if `t` is not a BaseType, OptionalType or ListType
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Generate the unboxing statement for a base (non-optional, non-list) type.

        :return: a tuple of (statements, declarations); declarations is empty
        """
        declared_type = ctype.cpp_type()
        value_type = ctype.cpp_type(strip_ref=True)
        statement = f"{declared_type} {out_name} = {arg_name}.to<{value_type}>();"
        return [statement], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Generate the unboxing statement for an optional type.

        Only the CType and the declarations of the wrapped element type are
        used; the element's own unboxing statements are not emitted.

        :return: a tuple of (statements, declarations)
        """
        in_name = f"{arg_name}_opt_in"
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        statement = (
            f"{ctype.cpp_type(strip_ref=True)} {out_name} = "
            f"{arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();"
        )
        # The statement is surrounded by empty entries (blank lines in the
        # generated output).
        return ["", statement, ""], decl

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Generate the unboxing statements for a list type.

        Lists of Tensor, int/SymInt, float and bool map onto a dedicated EValue
        accessor. `Tensor?[]` gets a dedicated code path. Any other element
        type falls back to filling a std::vector element by element.

        :return: a tuple of (statements, declarations)
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code: List[str] = []
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )

        elem = t.elem
        # Element types that EValue can hand back as a list in a single call.
        list_accessors = {
            BaseTy.Tensor: "toTensorList",
            BaseTy.int: "toIntList",
            BaseTy.SymInt: "toIntList",
            BaseTy.float: "toDoubleList",
            BaseTy.bool: "toBoolList",
        }
        accessor = (
            list_accessors.get(elem.name) if isinstance(elem, BaseType) else None
        )

        if accessor is not None:
            code.extend(
                [
                    "",
                    f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.{accessor}();",
                    "",
                ]
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
        elif (
            isinstance(elem, OptionalType)
            and isinstance(elem.elem, BaseType)
            and elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                [
                    "",
                    "#ifdef USE_ATEN_LIB",
                    f"at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();",
                    f"c10::List<c10::optional<at::Tensor>> {out_name};",
                    f"for (auto {elem_name}: {in_name}) {{",
                    f"{out_name}.push_back({elem_name});",
                    "}",
                    "#else",
                    f"torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();",
                    "#endif",
                    "",
                ]
            )
        else:
            # use ArrayRef as default.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            # NOTE(review): `{in_name}` is iterated below but no statement
            # generated by this method declares it -- presumably the list has
            # to be extracted from `{arg_name}` first; confirm before relying
            # on this fallback.
            fallback = ["", f"for (EValue {elem_name}: {in_name}) {{"]
            # The element statements are joined with `connector` and then split
            # on newlines again, so each one lands on its own output line.
            fallback.extend(connector.join(res_code).split("\n"))
            fallback.extend(
                [
                    f"{vec_name}.push_back({res_name});",
                    "}",
                    f"{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});",
                    "",
                ]
            )
            code.extend(fallback)
        return code, decl