import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
getValueT,
isValueType,
LazyArgument,
LazyIrProperties,
LazyIrSchema,
tensorListValueT,
)
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
deviceT,
DispatcherSignature,
kernel_signature,
NativeSignature,
OptionalCType,
VectorCType,
)
from torchgen.context import method_with_native_function
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
from torchgen.model import (
Argument,
BackendIndex,
BackendMetadata,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
)
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
    """
    Return the C++ rvalue expression used to pass ``arg`` into a lazy Node
    constructor.

    Value-typed arguments (per ``isValueType``) are turned into IR values:
    wrapped scalars become ``node_<name>``, tensor lists ``lazy_<name>_tensorlist``,
    SymInts go through ``GetSymIntValue`` and tensors through
    ``lazy_<name>->GetIrValue()``; optional variants are wrapped in
    ``c10::make_optional(...)`` / ``c10::nullopt``.

    Non-value arguments are passed through, with list-like types copied into a
    ``std::vector`` (or ``torch::lazy::ToOptionalVector`` for optional lists).

    Raises:
        AssertionError: for a value type that is neither ``BaseCType`` nor
            ``OptionalCType``.
    """
    name = arg.name
    ltype = arg.lazy_type

    # TODO: Matching on CType seems wrong; should be matching on Type
    if not isValueType(ltype):
        # NB: this is here because right now we aren't treating SymInt[] as a
        # value type; when we do this needs to move above
        # NB: we cannot test arg.lazy_type as we've already specified it is an
        # int64_t and so we cannot distinguish between SymInt and int64_t
        orig = arg.orig_type
        if isinstance(orig, ListType) and orig.elem == BaseType(BaseTy.SymInt):
            if arg.symint:
                return f"GetSymIntArrayRefValue({name})"
            return f"std::vector<int64_t>({name}.begin(), {name}.end())"
        if isinstance(ltype, VectorCType) and isinstance(ltype.elem, BaseCType):
            return f"std::vector<{ltype.elem.type}>({name}.begin(), {name}.end())"
        is_optional_vector = (
            isinstance(ltype, OptionalCType)
            and isinstance(ltype.elem, VectorCType)
            and isinstance(ltype.elem.elem, BaseCType)
        )
        if is_optional_vector:
            return f"torch::lazy::ToOptionalVector<{ltype.elem.elem.type}>({name})"
        return f"{name}"

    if isinstance(ltype, BaseCType):
        if arg.is_wrapped_scalar:
            return f"node_{name}"
        if ltype.type is tensorListValueT:
            return f"lazy_{name}_tensorlist"
        if arg.is_symint_or_list:
            return f"GetSymIntValue({name})"
        return f"lazy_{name}->GetIrValue()"

    if isinstance(ltype, OptionalCType):
        if arg.is_symint_or_list:
            # TODO: I don't understand when you should put lazy_ in the name
            # or not
            return f"{name} ? c10::make_optional(GetSymIntValue(*{name})) : c10::nullopt"
        if arg.is_wrapped_scalar:
            return f"node_{name}"
        return f"lazy_{name} ? c10::make_optional(lazy_{name}->GetIrValue()) : c10::nullopt"

    raise AssertionError(
        f"TODO not sure if there are other valid types to handle here ({ltype})"
    )
def node_ctor_inputs(schema: LazyIrSchema) -> str:
    """
    Render the comma-separated C++ argument list passed to a node class
    constructor.

    Each entry is ``node_ctor_arg_rvalue_string(arg)`` for one argument of
    ``schema.filtered_args()``, kept in that order and joined with ``", "``.
    """
    rvalues = map(node_ctor_arg_rvalue_string, schema.filtered_args())
    return ", ".join(rvalues)
def gen_fallback_code(
    schema: LazyIrSchema,
    sig: Union[DispatcherSignature, NativeSignature],
    overload_name: str,
) -> str:
    """
    Generate code that falls back to eager conditioned on a predicate.

    The emitted C++ ``if`` returns
    ``at::native::call_fallback_fn_symint<...>::call(...)`` when
    ``force_eager_fallback(<aten symbol>)`` holds or, for ops that take a
    generator argument, when a defined generator was passed.

    Args:
        schema: lazy IR schema of the op being generated.
        sig: signature whose argument names are in scope in the generated
            code; they are translated to the dispatcher signature's arguments.
        overload_name: overload name of the op; when non-empty the op is
            referenced as ``ATEN_OP2(name, overload)``, otherwise as
            ``ATEN_OP(name)``.

    Returns:
        The C++ snippet as a string.
    """
    dispatcher_sig = DispatcherSignature.from_schema(schema.func)
    exprs = translate(sig.arguments(), dispatcher_sig.arguments())
    fallback_args = ",\n ".join([a.expr for a in exprs])
    if overload_name:
        aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
    else:
        aten_op_str = f"ATEN_OP({schema.aten_name})"

    or_has_generator = ""
    if schema.generator_arg:
        # generators are always optional and there is never more than one, at least currently
        gen_name = schema.generator_arg.name
        or_has_generator = f" || ({gen_name}.has_value() && {gen_name}->defined())"

    # The first template argument is a function pointer: the address-of
    # operator `&` applied to `ltc_eager_fallback`. Emitting
    # `<<c_eager_fallback` instead would not parse as a template argument list.
    return f"""
if (force_eager_fallback({aten_symbol(schema)}){or_has_generator}) {{
return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
{fallback_args}
);
}}
"""
def aten_symbol(schema: LazyIrSchema) -> str:
    """
    Return a C++ expression naming the aten symbol of ``schema.aten_name``.

    - Names in the small hard-coded set below become a runtime lookup,
      ``c10::Symbol::fromQualString("aten::<name>")``.
    - Names that already start with ``at::`` are returned unchanged.
    - Every other name is prefixed with ``at::aten::``.
    """
    name = schema.aten_name
    # NOTE(review): these appear to lack an interned `at::aten::` constant,
    # hence the string-based lookup — confirm when extending this set.
    missing_interned = {"sigmoid_backward"}
    if name in missing_interned:
        return f'c10::Symbol::fromQualString("aten::{name}")'
    if name.startswith("at::"):
        return name
    return f"at::aten::{name}"
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    """
    Convert every tensor-like argument of ``sig`` into a meta tensor.

    Returns a pair:
      (1) a string with all of the C++ conversion statements, one
          ``auto <name>_meta = to_meta(<name>);`` per tensor-like argument;
      (2) a context for ``translate()``: the signature's bindings in order,
          with each tensor-like binding renamed to ``<name>_meta``.
    """
    conversions: List[str] = []
    bindings: List[Binding] = []
    for binding in sig.arguments():
        is_tensor_like = (
            isinstance(binding.argument, Argument)
            and binding.argument.type.is_tensor_like()
        )
        if not is_tensor_like:
            bindings.append(binding)
            continue
        meta_name = f"{binding.name}_meta"
        conversions.append(f"auto {meta_name} = to_meta({binding.name});")
        bindings.append(binding.with_name(meta_name))
    return "\n ".join(conversions), bindings
@dataclass(frozen=True)
class GenLazyIR(ABC):
    """
    Code generator for lazy IR node classes.

    Calling an instance on a native function (or group) builds a
    ``LazyIrSchema`` and returns the C++ declaration of a node class deriving
    from ``node_base`` (see ``gen``). Backends customize the output by
    subclassing and overriding ``lowering_function``, ``create_function``,
    ``can_be_reused_function`` and/or ``node_base_ctor_call``.
    """

    # Consulted in __call__ for the op's kernel metadata, which decides
    # whether the schema is built with SymInt support.
    backend_index: BackendIndex
    # Not read by the methods of this class.
    # NOTE(review): presumably used by subclasses or callers — confirm.
    backend_name: str
    # Name of the C++ base class of every generated node class.
    node_base: str
    # When true and the schema has the ShapePrecompute property, the generated
    # constructor takes an extra `std::vector<torch::lazy::Shape>&& shapes`.
    use_lazy_shape: bool

    @method_with_native_function
    def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
        """
        Generate the IR node class for ``f``.

        For a ``NativeFunctionsGroup`` the functional variant is used. The
        schema is built with ``symint=True`` only when the backend has kernel
        metadata for that function and the metadata reports
        ``supports_symint()``.
        """
        native_fn = f.functional if isinstance(f, NativeFunctionsGroup) else f
        metadata = self.backend_index.get_kernel(native_fn)
        schema = LazyIrSchema(
            native_fn.func, symint=metadata is not None and metadata.supports_symint()
        )
        return self.gen(schema)

    def lowering_function(self, schema: LazyIrSchema) -> str:
        """
        Return the lowering member of the node class; empty here.

        There is no lowering functionality generated unless this IR base class
        is subclassed and implemented as a backend-specific node.
        """
        return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Return a node-creation member for the class; empty unless overridden."""
        return ""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """
        Return a ``CanBeReused`` member taking ``node_ctor_args``.

        The generated member always returns ``false``.
        """
        return f"""bool CanBeReused({node_ctor_args}) const {{
return false;
}}"""

    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
        """
        Return the C++ call of the ``node_base`` constructor used in the node
        class's member initializer list.

        It receives the op kind, an ``OpList`` of the value arguments (optional
        values replaced by ``kNullValue`` when absent), a shape argument chosen
        by the schema's shape property (ShapePrecompute / ShapeCompute /
        ShapeCache / none), the number of outputs, and an ``MHash`` of the
        scalar arguments.

        Backends can customize the way the node base class constructor is
        called, as long as all of its arguments can be generated from
        information available from the schema.

        Raises:
            AssertionError: for a value argument whose lazy type is not a
                ``BaseCType``, ``VectorCType`` or ``OptionalCType``.
        """
        value_args = schema.filtered_args(values=True, scalars=False)
        base_ctor_value_args_list = []
        for arg in value_args:
            if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
                base_ctor_value_args_list.append(f"{arg.name}")
            elif isinstance(arg.lazy_type, OptionalCType):
                base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
            else:
                raise AssertionError(
                    f"Unsupported type ({arg.lazy_type}) - add support if necessary"
                )
        base_ctor_value_args = ", ".join(base_ctor_value_args_list)

        scalar_args = schema.filtered_args(values=False, scalars=True)

        # Shape construction.
        # Conditionally build shape depending on specified shape property
        if schema.properties.ShapePrecompute:
            # Shapes are handed in by the caller via the `shapes` ctor argument.
            shape_ctor_arg = "std::move(shapes),"
        elif schema.properties.ShapeCompute:
            shape_args = [a.name for a in value_args]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
        elif schema.properties.ShapeCache:
            # Passes a lambda that refers to the node's operands by index.
            shape_args = [f"operand({i})" for i in range(len(value_args))]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
        else:
            shape_ctor_arg = ""

        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)

        return f"""{self.node_base}(
{schema.node_name}::ClassOpKind(),
OpList{{{base_ctor_value_args}}},
{shape_ctor_arg}
/* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))"""

    def gen(self, schema: LazyIrSchema) -> List[str]:
        """
        Return a one-element list holding the C++ declaration of the IR node
        class for ``schema``.

        The class has a static ``ClassOpKind()``, a constructor forwarding to
        ``node_base_ctor_call``, a ``ToString()`` override that appends every
        scalar member, the members produced by ``create_function``,
        ``can_be_reused_function`` and ``lowering_function``, one data member
        per scalar argument (string_view scalars are stored as
        ``std::string``), and a ``has_<name>`` bit-field per optional value
        argument.
        """
        opkind = schema.opkind or aten_symbol(schema)

        # for now, we just want one IR class decl and soon after also the method defs
        # and we use the functional version not out/inplace.
        all_args = schema.filtered_args()
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)

        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        # CanBeReused()/Create() take the argument list without the `shapes` parameter.
        reuse_ctor_args = ", ".join(ctor_args)
        if self.use_lazy_shape and schema.properties.ShapePrecompute:
            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n ".join(
            [
                # This code is just special casing the mapping from string_view -> strings
                f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)"
                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
                else f"{a.name}({a.name})"
                for a in scalar_args
            ]
        )
        if scalar_initializers:
            # Continue the member initializer list after the base-class ctor call.
            scalar_initializers = f",\n {scalar_initializers}"

        # string_view scalars are stored as owning std::string members.
        scalar_decls = "\n ".join(
            [
                f"std::string {a.name};"
                if a.lazy_type.cpp_type() == "c10::string_view"
                else f"c10::optional<std::string> {a.name};"
                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
                else f"{a.lazy_type.cpp_type()} {a.name};"
                for a in scalar_args
            ]
        )

        optional_values = [
            arg.name for arg in value_args if isinstance(arg.lazy_type, OptionalCType)
        ]
        has_optional_decls = "\n ".join(
            [f"bool has_{value}: 1;" for value in optional_values]
        )
        has_optional_defs = "\n ".join(
            [f"has_{value} = !!{value};" for value in optional_values]
        )

        members_to_string = []
        for arg in scalar_args:
            if isinstance(arg.lazy_type, OptionalCType):
                members_to_string.append(
                    f"""if ({arg.name}.has_value()) {{
ss << ", {arg.name}=" << {arg.name}.value();
}} else {{
ss << ", {arg.name}=null";
}}"""
                )
            else:
                members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
        members_to_string_str = "\n ".join(members_to_string)

        return [
            f"""\
class {schema.node_name} : public {self.node_base} {{
public:
static torch::lazy::OpKind ClassOpKind() {{
return torch::lazy::OpKind({opkind});
}}
{schema.node_name}({node_ctor_args})
: {self.node_base_ctor_call(schema)}{scalar_initializers}
{{
{has_optional_defs}
}}
std::string ToString() const override {{
std::stringstream ss;
ss << {self.node_base}::ToString();
{members_to_string_str}
return ss.str();
}}
{self.create_function(schema, reuse_ctor_args)}
{self.can_be_reused_function(schema, reuse_ctor_args)}
{self.lowering_function(schema)}
{scalar_decls}
{has_optional_decls}
}};
""",
        ]
@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
def lowering_function(self, schema: LazyIrSchema) -> str:
signature = """
Loading ...