# mypy: allow-untyped-defs
import functools
import math
import os
import sys
from itertools import count
from typing import Dict, List, Optional, Tuple
import sympy
from sympy import Expr
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch._ops
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey
from .. import config, ir
from ..codecache import CudaKernelParamCache
from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer
from .cpp_utils import (
cexpr,
CppPrinter,
DEVICE_TO_ATEN,
DTYPE_TO_ATEN,
DTYPE_TO_CPP,
LAYOUT_TO_ATEN,
)
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen
class CppWrapperCpu(WrapperCodeGen):
"""
Generates cpp wrapper for running on CPU and calls cpp kernels
"""
def __init__(self):
if not hasattr(self, "device"):
self.device = "cpu"
super().__init__()
self.declare = "auto "
self.declare_maybe_reference = "decltype(auto) "
self.ending = ";"
self.open_bracket = "{"
self.closed_bracket = "}"
self.comment = "//"
self.namespace = "at::"
self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()"
self.extern_call_ops = set()
self.size = "sizes()"
self.stride = "strides()"
self.cuda = False
self.supports_intermediate_hooks = False
self.outputs_need_copy = set()
self.kernel_callsite_id = count()
self.var_array_id = (
count()
) # for different types of local array variable declarations
self.declared_var_array_vars = set()
self.int_array_id = count() # for int array local variable declarations
self.declared_int_array_vars = set()
self.tmp_tensor_id = count() # for tmp tensor local variable declarations
self.arg_var_id = count()
self.used_cached_devices = set()
self.used_cached_dtypes = set()
self.used_cached_layouts = set()
self.cached_output_id = count()
self.scalar_to_tensor_id = count()
self.custom_op_wrapper_loaded = False
self.expr_printer = cexpr
# CppPrinter sometimes calls at::native functions, which causes problems in
# the ABI-compatible mode. Currently we are hitting this problem when codegenning
# grid computation expressions, but we may need to fix other size computations
# as well.
class GridExprCppPrinter(CppPrinter):
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
assert expr.is_integer, "Expect integers in GridExprCppPrinter"
return f"({x}/{div})"
self.grid_expr_printer = GridExprCppPrinter().doprint
def generate_kernel_call(
self,
name,
call_args,
grid=None,
device_index=None,
cuda=True,
triton=True,
arg_types=None,
grid_fn: str = "grid",
triton_meta=None,
):
"""
Generates kernel call code.
cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.
triton: Defines whether the GPU backend uses Triton for codegen.
Otherwise it uses the CUDA language for codegen.
Only valid when cuda == True.
"""
if cuda:
return super().generate_kernel_call(
name,
call_args,
grid,
device_index,
cuda,
triton,
arg_types,
grid_fn,
)
else:
if config.abi_compatible:
assert arg_types is not None and len(call_args) == len(
arg_types
), "Mismatch call_args and arg_types in generate_kernel_call"
new_args = []
for idx, arg in enumerate(call_args):
if "*" in arg_types[idx]:
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(
f"auto* {var_name} = get_data_ptr_wrapper({arg});"
)
new_args.append(f"({arg_types[idx]})({var_name})")
else:
# arg is a scalar
new_args.append(arg)
self.writeline(self.wrap_kernel_call(name, new_args))
else:
self.writeline(self.wrap_kernel_call(name, call_args))
def write_constant(self, name, hashed):
# include a hash so our code cache gives different constants different files
self.header.writeline(f"// {name} {hashed}")
def write_header(self):
if V.graph.is_const_graph:
# We do not write header for constant graph, it will be written by main module.
return
if V.graph.aot_mode:
for header_cpp_file in ("interface.cpp", "implementation.cpp"):
with open(
os.path.join(
os.path.dirname(__file__), "aoti_runtime", header_cpp_file
)
) as f:
self.header.splice(f.read())
else:
self.header.splice(
"""
import torch
from torch._inductor.codecache import CppWrapperCodeCache
cpp_wrapper_src = (
'''
"""
)
if config.abi_compatible:
if config.c_shim_version == "1":
self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
else:
self.header.splice(
f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>"
)
self.header.splice(
"""
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
"""
)
if V.graph.aot_mode:
self.header.splice(
"""
#include <torch/csrc/inductor/aoti_runtime/model.h>
"""
)
else:
self.header.splice(
"""
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/BinaryOps.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/types.h>
#include <ATen/ops/bernoulli_native.h>
#define reinterpret_tensor torch::inductor::_reinterpret_tensor
#define alloc_from_pool torch::inductor::_alloc_from_pool
"""
)
self.header.splice("#include <c10/util/generic_math.h>")
if not V.graph.aot_mode:
self.header.splice(
"""
#include <pybind11/pybind11.h>
namespace py = pybind11;
using namespace torch::aot_inductor;
class RAIIPyObject {
public:
RAIIPyObject() : obj_(nullptr) {}
RAIIPyObject(PyObject* obj) : obj_(obj) {}
~RAIIPyObject() {
Py_XDECREF(obj_);
}
RAIIPyObject& operator=(const RAIIPyObject& other) {
if (this != &other) {
Py_XDECREF(obj_);
obj_ = other.obj_;
Py_XINCREF(obj_);
}
return *this;
}
operator PyObject*() {
return obj_;
}
PyObject* get() {
return obj_;
}
private:
PyObject* obj_;
};
"""
)
# Round up to the nearest multiple of ALIGN_BYTES
# ALIGN_BYTES must be a power of 2
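# e.g., assuming ALIGN_BYTES == 64: align(70) == (70 + 63) & -64 == 128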
self.header.splice(
f"""
[[maybe_unused]] static int64_t align(int64_t nbytes) {{
return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES};
}}
"""
)
def mark_output_type(self):
# mark output type to unwrap tensor back to python scalar
from ..ir import ShapeAsConstantBuffer
output_is_tensor = dict()
for idx, x in enumerate(V.graph.graph_outputs):
if isinstance(x, ShapeAsConstantBuffer):
output_is_tensor[idx] = False
else:
output_is_tensor[idx] = True
self.output_is_tensor = output_is_tensor
def write_prefix(self):
if V.graph.is_const_graph:
# We do not write prefix for constant graph, it will be written by main module.
return
if V.graph.aot_mode:
self.prefix.writeline("namespace torch {")
self.prefix.writeline("namespace aot_inductor {")
def write_input_output_info(
self,
info_kind: str,
idx: int,
name: str,
):
self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""")
@staticmethod
def get_input_cpp_type(input):
assert config.use_minimal_arrayref_interface
if isinstance(input, sympy.Expr):
from ..graph import may_get_constant_buffer_dtype
dtype = may_get_constant_buffer_dtype(input)
assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
return DTYPE_TO_CPP[dtype]
return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
def generate_input_output_runtime_checks(self):
# In debug_compile mode, we generate checks to ensure that the dtype/shape/stride of each
# real input/output tensor matches the one provided at compile time via sample
# inputs/outputs.
def gen_check(handle_kind, idx, name, tensor):
self.prefix.writeline(f"auto {name} = {handle_kind}[{idx}];")
self.codegen_tensor_dtype_var_decl(self.prefix, name)
expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype]
dtype_str = str(tensor.dtype).split(".")[-1]
self.prefix.splice(
f"""
int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}();
if ({name}_expected_dtype != {name}_dtype) {{
std::stringstream ss;
ss << "{handle_kind}[{idx}]: unmatched dtype, "
<< "expected: " << {name}_expected_dtype << "({expected_dtype_name}), "
<< "but got: " << {name}_dtype << "\\n";
throw std::runtime_error(ss.str());
}}
"""
)
self.codegen_input_size_var_decl(self.prefix, name)
for dim_idx, d in enumerate(tensor.get_size()):
if isinstance(d, (int, sympy.Integer)):
self.prefix.splice(
f"""
if ({d} != {name}_size[{dim_idx}]) {{
std::stringstream ss;
ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, "
<< "expected: {d}, " << "but got: " << {name}_size[{dim_idx}]
<< "\\n";
throw std::runtime_error(ss.str());
}}
"""
)
else:
assert isinstance(
d, sympy.Symbol
), f"dimention at {dim_idx=} for tensor {name=} must be a sympy.Symbol"
sym_range = V.graph.sizevars.shape_env.var_to_range.get(d, None)
if sym_range is None:
continue
if not math.isinf(sym_range.lower):
self.prefix.splice(
f"""
if ({name}_size[{dim_idx}] < {sym_range.lower}) {{
std::stringstream ss;
ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, "
<< "expected it to be >= {sym_range.lower}, " << "but got: "
<< {name}_size[{dim_idx}] << "\\n";
throw std::runtime_error(ss.str());
}}
"""
)
if not math.isinf(sym_range.upper):
self.prefix.splice(
f"""
if ({name}_size[{dim_idx}] > {sym_range.upper}) {{
std::stringstream ss;
ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, "
<< "expected to be <= {sym_range.upper}, " << "but got: "
<< {name}_size[{dim_idx}] << "\\n";
throw std::runtime_error(ss.str());
}}
"""
)
self.codegen_input_stride_var_decl(self.prefix, name)
for stride_idx, s in enumerate(tensor.get_stride()):
if not isinstance(s, (int, sympy.Integer)):
continue
self.prefix.splice(
f"""
if ({s} != {name}_stride[{stride_idx}]) {{
std::stringstream ss;
ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, "
<< "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}]
<< "\\n";
throw std::runtime_error(ss.str());
}}
"""
)
# force noinline to avoid any potential compilation slowdown due to aggressive
# inlining done by the host compiler
self.prefix.splice(
"""
AOTI_NOINLINE static void __check_inputs_outputs(
AtenTensorHandle* input_handles,
AtenTensorHandle* output_handles) {
"""
)
with self.prefix.indent():
for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()):
gen_check("input_handles", idx, name, tensor)
self.prefix.writeline("}")
def write_wrapper_decl(self):
inputs_len = len(V.graph.graph_inputs.keys())
if V.graph.aot_mode:
if config.use_minimal_arrayref_interface and not V.graph.is_const_graph:
input_cpp_types = ", ".join(
f"{CppWrapperCpu.get_input_cpp_type(x)}"
for x in V.graph.graph_inputs.values()
)
output_arrayref_types = ", ".join(
f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
for x in V.graph.graph_outputs
)
self.prefix.splice(
f"""
using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
"""
)
if V.graph.const_module:
self.header.splice(V.graph.const_module.wrapper_code.header)
self.prefix.splice(V.graph.const_code)
if V.graph.is_const_graph:
self.prefix.splice(
"""
void AOTInductorModel::_const_run_impl(
std::vector<AtenTensorHandle>& output_handles,
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {
"""
)
else:
if not config.aot_inductor.use_runtime_constant_folding:
# If we do not split the constant graph, we'll just create
# an empty implementation when wrapping the main module.
self.prefix.splice(
"""
void AOTInductorModel::_const_run_impl(
std::vector<AtenTensorHandle>& output_handles,
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {}
"""
)
run_impl_proto = """
void AOTInductorModel::run_impl(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {
"""
# Since we are removing non-abi-compatible mode, let's generate
# runtime checks only for abi_compatible mode to avoid extra branches.
if config.aot_inductor.debug_compile and config.abi_compatible:
self.generate_input_output_runtime_checks()
run_impl_proto += """
__check_inputs_outputs(input_handles, output_handles);
"""
if config.use_minimal_arrayref_interface:
self.prefix.splice(
"""
template <>
AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
AOTInductorModelInputs, AOTInductorModelOutputs>(
const AOTInductorModelInputs& inputs,
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor
) {
"""
)
self.suffix.splice(run_impl_proto)
self.suffix.splice(
"""
AOTInductorModelInputs inputs;
convert_handles_to_inputs(input_handles, inputs);
auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
inputs, stream, proxy_executor);
// NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
// interface to perform well for a DSO using the minimal arrayref interface, all we need
// to do is provide ThreadLocalCachedTensor for each one!
convert_outputs_to_handles(outputs, output_handles);
}
"""
)
self.suffix.splice(
"""
extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
AOTInductorModelHandle model_handle,
const AOTInductorModelInputs& inputs,
AOTInductorModelOutputs& outputs) {
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
inputs,
(torch::aot_inductor::DeviceStreamType)nullptr,
nullptr);
})
}
"""
)
else:
self.prefix.splice(run_impl_proto)
else:
# cpp entry function for JIT with cpp wrapper
self.prefix.splice(
"""
void inductor_entry_impl(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed)
) {
"""
)
with self.prefix.indent():
# assign inputs and outputs in both cases so the later codegen can be simplified
if not config.use_minimal_arrayref_interface:
if not V.graph.is_const_graph:
if V.graph.aot_mode:
num_args = len(V.graph.graph_inputs)
else:
# Weights are promoted in the JIT mode
num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
# release the GIL to support inference with multiple instances (in different threads of the same process)
self.prefix.splice("py::gil_scoped_release release;")
if config.abi_compatible:
self.prefix.splice(
f"""
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
"""
)
else:
# This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime.
self.prefix.splice(
f"""
auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args});
"""
)
if inputs_len != 0:
for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
if config.use_minimal_arrayref_interface:
self.prefix.writeline(
f"auto {input_key} = std::get<{idx}>(inputs);"
)
continue
# unwrap input tensor back to scalar
if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
from ..graph import may_get_constant_buffer_dtype
dtype = may_get_constant_buffer_dtype(
V.graph.graph_inputs[input_key]
)
assert (
dtype is not None
), "Fails to get the dtype of the sympy.Expr"
cpp_dtype = DTYPE_TO_CPP[dtype]
if config.abi_compatible:
self.codegen_tensor_item(
dtype, f"inputs[{idx}]", input_key, self.prefix
)
else:
self.prefix.writeline(
f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();"
)
else:
self.prefix.writeline(
f"auto {input_key} = std::move(inputs[{idx}]);"
)
assert all(
isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
), "Expect all constants to be Tensor"
for idx, constants_key in enumerate(V.graph.constants.keys()):
if V.graph.aot_mode:
# Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
# Don't call std::move here because it will cause constants_ to lose ownership.
if config.abi_compatible:
self.prefix.writeline(
f"""auto {constants_key} = constants_->at({idx});"""
)
else:
self.prefix.writeline(
f"auto {constants_key} = *tensor_handle_to_tensor_pointer("
+ f"""constants_->at({idx}));"""
)
else:
# Append constants as inputs to the graph
constants_idx = inputs_len + idx
if config.abi_compatible:
self.prefix.writeline(
f"auto {constants_key} = std::move(inputs[{constants_idx}]);"
)
else:
self.prefix.writeline(
f"auto {constants_key} = inputs[{constants_idx}];"
)
self.codegen_inputs(self.prefix, V.graph.graph_inputs)
if V.graph.aot_mode:
if not V.graph.is_const_graph:
if config.use_minimal_arrayref_interface:
# TODO: input shape checking for regular tensor interface as well?
self.codegen_input_numel_asserts()
else:
self.prefix.writeline("inputs.clear();")
self.prefix.writeline(
"auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
)
def codegen_input_numel_asserts(self):
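# Illustrative emitted line (input name is hypothetical): assert_numel(arg0_1, 64);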
for name, buf in V.graph.graph_inputs.items():
if isinstance(buf, sympy.Expr):
continue
# comparing strides for 0 size tensor is tricky. Ignore them for now.
if sympy_product(buf.get_size()) == 0:
continue
numel = buf.get_numel()
self.prefix.writeline(f"assert_numel({name}, {numel});")
def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
if config.abi_compatible:
code.writeline(f"int32_t {name}_dtype;")
code.writeline(
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
f"({name}, &{name}_dtype));"
)
else:
# Note that there is no corresponding method in WrapperCodeGen,
# since this method is only used for generating assertions in the
# AOTI cpp wrapper code.
code.writeline(f"auto {name}_dtype = {name}.dtype();")
def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
if config.abi_compatible:
code.writeline(f"int64_t* {name}_size;")
code.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));"
)
else:
super().codegen_input_size_var_decl(code, name)
def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
if config.abi_compatible:
code.writeline(f"int64_t* {name}_stride;")
code.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));"
)
else:
super().codegen_input_stride_var_decl(code, name)
def codegen_model_kernels(self):
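# Illustrative generated code (kernel names depend on the compiled graph):
#   namespace {
#   class AOTInductorModelKernels : public AOTInductorModelKernelsBase {
#    public:
#     CUfunction triton_poi_fused_add_0{nullptr};
#   };
#   } // namespace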
self.prefix.writeline("namespace {")
self.prefix.writeline(
"class AOTInductorModelKernels : public AOTInductorModelKernelsBase {"
)
self.prefix.writeline(" public:")
declare_kernel = set(self.src_to_kernel.values())
declare_kernel.update(
entry[0] for entry in self.user_defined_kernel_cache.values()
)
if V.graph.const_module:
declare_kernel.update(
V.graph.const_module.wrapper_code.src_to_kernel.values()
)
for kernel in sorted(declare_kernel):
self.prefix.writeline(
maybe_hipify_code_wrapper(f" CUfunction {kernel}{{nullptr}};")
)
self.prefix.writeline("};")
self.prefix.writeline("} // namespace")
def codegen_model_constructor(self):
"""
// Generated code example
AOTInductorModel::AOTInductorModel()
: AOTInductorModelBase(4, 1) {
inputs_info_[0].name = "input0";
inputs_info_[0].dtype = "torch.float16";
...
constants_info_[0].name = "L__self___weight";
constants_info_[0].dtype = at::kFloat;
constants_info_[0].offset = 0;
constants_info_[0].data_size = 8192;
constants_info_[0].shape = {64, 32};
constants_info_[0].stride = {32, 1};
...
outputs_info_[0].name = "output0";
outputs_info_[0].dtype = "torch.float16";
}
"""
num_inputs = len(V.graph.graph_inputs)
num_outputs = len(V.graph.graph_outputs)
num_constants = len(V.graph.constants)
self.prefix.splice(
f"""
AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
std::shared_ptr<std::vector<ConstantHandle>> constants_array,
const std::string& device_str,
std::optional<std::string> cubin_dir)
: AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
"""
)
with self.prefix.indent():
for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
assert not isinstance(
inp, sympy.Expr
), f"input {name=} cannot be symbolic"
self.write_input_output_info("inputs_info_", idx, name)
all_cuda = all(
V.graph.get_original_value_of_constant(name).is_cuda
for name in V.graph.constants.keys()
if name not in V.graph.folded_constants
)
for idx, name in enumerate(V.graph.constants.keys()):
tensor = V.graph.get_original_value_of_constant(name)
assert isinstance(tensor, torch.Tensor)
self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
self.prefix.writeline(
f"constants_info_[{idx}].dtype = static_cast<int32_t>({self.codegen_dtype(tensor.dtype)});"
)
self.prefix.writeline(
f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
)
# If the constants to serialize contain cpu tensors, we always align data_size to 64.
# When loading the constants, the valid data depends on the size,
# not the data_size, so there won't be a correctness issue.
data_size = (
torch.ops.mkldnn._nbytes(tensor)
if tensor.is_mkldnn
else tensor.untyped_storage().nbytes()
)
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};"
)
from_folded = "true" if name in V.graph.folded_constants else "false"
self.prefix.writeline(
f"constants_info_[{idx}].from_folded = {from_folded};"
)
size_str = ", ".join([str(s) for s in tensor.size()])
self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")
stride_str = ", ".join([str(s) for s in tensor.stride()])
self.prefix.writeline(
f"constants_info_[{idx}].stride = {{{stride_str}}};"
)
self.prefix.writeline(
f"constants_info_[{idx}].layout = static_cast<int32_t>({self.codegen_layout(tensor.layout)});"
)
if tensor.is_mkldnn:
opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
tensor
)
assert (
opaque_metadata_tensor.dim() == 1
), "Expect opaque_metadata_tensor to be 1-D"
opaque_metadata_list = opaque_metadata_tensor.tolist()
opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
self.prefix.writeline(
f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
)
if name in V.graph.dynamo_flat_name_to_original_fqn:
original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
name, name
)
elif name in V.graph.allocated_constant_name:
original_fqn = V.graph.allocated_constant_name[name]
else:
raise AssertionError("original_fqn must be set for constant")
self.prefix.writeline(
f"""constants_info_[{idx}].original_fqn = "{original_fqn}";"""
)
self.prefix.writeline("update_constants_map(std::move(constants_map));")
self.prefix.writeline("update_constants_array(std::move(constants_array));")
def escape_string(x):
return (
x.replace("\\", "\\\\")
.replace('"', '\\"')
.replace("\n", "\\n")
.replace("\t", "\\t")
)
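# e.g., escape_string('spec with "quotes"') returns 'spec with \"quotes\"',
# so the serialized spec can be embedded in the C++ string literals below.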
self.prefix.writeline(
f'in_spec_ = "{escape_string(config.aot_inductor.serialized_in_spec)}";'
)
self.prefix.writeline(
f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";'
)
for idx, output in enumerate(V.graph.graph_outputs):
assert not isinstance(
output, sympy.Expr
), f"output {name=} cannot be symbolic"
name = f"output{idx}"
self.write_input_output_info("outputs_info_", idx, name)
self.prefix.writeline(
"this->kernels_ = std::make_unique<AOTInductorModelKernels>();"
)
self.prefix.writeline("}")
def codegen_const_run_driver(self):
"""
// Generated code example
std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor,
bool initialization
) {
std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
std::vector<AtenTensorHandle> output_handles;
// build up output_handles over here.
_const_run_impl(output_handles, stream, proxy_executor);
// build up folded_constants_map
return folded_constants_map;
}
"""
self.prefix.splice(
"""
std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor,
bool initialization
) {
"""
)
if not config.aot_inductor.use_runtime_constant_folding:
self.prefix.splice(
"""
if (!initialization) {
std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: "
<< "aot_inductor.use_runtime_constant_folding=False\\n";
}
return {};
}
"""
)
return
with self.prefix.indent():
# This is a mapping to the index of constant folding graph's output
const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len(
V.graph.const_output_index
)
for idx, (name, _) in enumerate(V.graph.constants.items()):
if name in V.graph.const_output_index:
const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload]
assert (
None not in const_index_mapping
), "Not all constant gets mapped for constant folding graph."
self.prefix.writeline(
f"""
std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
folded_constants_map.reserve({len(const_index_mapping)});
std::vector<AtenTensorHandle> output_handles({len(const_index_mapping)});
"""
)
self.prefix.splice(
"""
// The below assignment of output_handles to constants is not used directly.
// It's only used to record the correspondence between handles and constants.
"""
)
for output_idx, (const_idx, _) in enumerate(const_index_mapping): # type: ignore[misc]
self.prefix.writeline(
f"output_handles[{output_idx}] = constants_->at({const_idx});"
)
self.prefix.writeline(
"_const_run_impl(output_handles, stream, proxy_executor);"
)
for output_idx, (_, const_name) in enumerate(const_index_mapping): # type: ignore[misc]
self.prefix.writeline(
f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];'
)
self.prefix.writeline("return folded_constants_map;")
self.prefix.writeline("}")
def generate(self, is_inference):
if V.graph.aot_mode and not V.graph.is_const_graph:
self.codegen_model_kernels()
self.codegen_model_constructor()
self.codegen_const_run_driver()
self.write_wrapper_decl()
return super().generate(is_inference)
def finalize_prefix(self):
cached_dtypes_buffer = IndentedBuffer()
if config.abi_compatible:
for dtype in self.used_cached_dtypes:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
for device in self.used_cached_devices:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
for layout in self.used_cached_layouts:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
cached_dtypes_buffer.splice(self.prefix)
self.prefix = cached_dtypes_buffer
def define_kernel(
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False
):
self.header.splice(f"\n{kernel}\n")
def codegen_scalar_to_tensor(self, output: str):
name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}"
self.wrapper_call.writeline(
f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});"
)
return name
def codegen_tensor_item(
self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
):
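# Illustrative emitted code (names are hypothetical), e.g. with dtype=torch.float32,
# tensor="buf0", scalar="buf0_scalar":
#   float buf0_scalar;
#   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_float32(buf0, &buf0_scalar));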
assert (
config.abi_compatible
), "codegen_tensor_item is only used for the ABI-compatible mode"
dtype_str = str(dtype).split(".")[-1]
writer = indented_buffer or self
writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
writer.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
)
@cache_on_self
def get_output_refs(self):
return [
f"torch::tensor({x.codegen_reference(self.wrapper_call)})"
if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible
else x.codegen_reference(self.wrapper_call)
for x in V.graph.graph_outputs
]
def generate_return(self, output_refs: List[str]):
cst_names = V.graph.constants.keys()
arr_iface = (
not V.graph.is_const_graph and config.use_minimal_arrayref_interface
) # For brevity.
def use_thread_local_cached_output_tensor(idx, output):
cached_output_name = f"cached_output_{next(self.cached_output_id)}"
cache_type = "Array" if arr_iface else "Tensor"
self.wrapper_call.writeline(
f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
f"{cached_output_name}({output});"
)
if arr_iface:
self.wrapper_call.writeline(
f"{cached_output_name}.copy_data_from({output});"
)
output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
self.wrapper_call.writeline(
f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
)
else:
self.wrapper_call.writeline(
f"{cached_output_name}.copy_data_from({output});"
)
self.wrapper_call.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
)
self.wrapper_call.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
f"output_handles[{idx}]));"
)
if arr_iface:
self.wrapper_call.writeline(
"AOTInductorModelOutputs output_arrayref_tensors;"
)
output2idx: Dict[str, int] = {}
for idx, output in enumerate(output_refs):
if output == self.none_str:
continue
is_constant_buffer = output in cst_names
output_buffer = V.graph.graph_outputs[idx]
if isinstance(output_buffer, ir.BaseView):
output_storage = output_buffer.unwrap_view()
if isinstance(output_storage.data, ir.ConstantBuffer):
is_constant_buffer = True
if config.abi_compatible:
if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
# Need to wrap scalar into tensor as the main function returns a vector of tensors
output_tensor = self.codegen_scalar_to_tensor(output)
self.wrapper_call.writeline(
f"output_handles[{idx}] = {output_tensor}.release();"
)
continue
output_is_tensor_handle_expr = (
f"std::is_same_v<std::decay_t<decltype({output})>,"
"RAIIAtenTensorHandle> || "
f"std::is_same_v<std::decay_t<decltype({output})>,"
"AtenTensorHandle> || "
f"std::is_same_v<std::decay_t<decltype({output})>,"
"ConstantHandle>"
)
self.wrapper_call.writeline(
f"if constexpr ({output_is_tensor_handle_expr}) {{"
)
with self.wrapper_call.indent():
if arr_iface:
cached_output_name = (
f"cached_output_{next(self.cached_output_id)}"
)
output_value_type = f"std::decay_t<decltype(std::get<{idx}>(output_arrayref_tensors).data()[0])>"
self.wrapper_call.writeline(
f"thread_local RAIIAtenTensorHandle {cached_output_name};"
)
if is_constant_buffer:
# NOTE(return_constant): In some rare cases where we return
# a constant, we have to return a copy of this constant,
# because (1) constants are not owned by the Model instance and
# (2) constants remain the same across inference runs, assuming
# they are not updated at runtime. Basically, we cannot release
# or transfer the ownership of any original constant to the user.
self.wrapper_call.writeline(
f"AtenTensorHandle {cached_output_name}_tmp;"
)
self.wrapper_call.writeline(
f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
)
self.wrapper_call.writeline(
f"{cached_output_name} = {cached_output_name}_tmp;"
)
else:
self.wrapper_call.writeline(
f"{cached_output_name} = {output}.release();"
)
self.wrapper_call.writeline(
f"convert_handle_to_arrayref_tensor({cached_output_name}, "
f"std::get<{idx}>(output_arrayref_tensors));"
)
else:
if is_constant_buffer:
# See NOTE(return_constant) above.
self.wrapper_call.writeline(
f"aoti_torch_clone({output}, &output_handles[{idx}]);"
)
else:
if output in output2idx:
src_idx = output2idx[output]
self.wrapper_call.writeline(
f"output_handles[{idx}] = output_handles[{src_idx}];"
)
else:
self.wrapper_call.writeline(
f"output_handles[{idx}] = {output}.release();"
)
self.wrapper_call.writeline("} else {")
with self.wrapper_call.indent():
use_thread_local_cached_output_tensor(idx, output)
self.wrapper_call.writeline("}")
else:
assert (
not arr_iface
), "minimal ArrayRef interface is only supported in ABI-compatible mode"
if is_constant_buffer:
output_expr = f"{output}.clone()"
# See NOTE(return_constant) above.
else:
output_expr = output
self.wrapper_call.writeline(
f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
+ f"new at::Tensor({output_expr}));"
)
if output not in output2idx:
output2idx[output] = idx
if arr_iface:
self.wrapper_call.writeline("return output_arrayref_tensors;")
def generate_before_suffix(self, result):
if not V.graph.is_const_graph:
if V.graph.aot_mode:
result.writeline("} // AOTInductorModel::run_impl")
else:
result.writeline("} // inductor_entry_impl")
def generate_end(self, result):
if V.graph.aot_mode:
if V.graph.is_const_graph:
result.writeline("} // AOTInductorModel::_const_run_impl")
else:
result.writeline("} // namespace aot_inductor")
result.writeline("} // namespace torch")
return
# cpp entry function for JIT with cpp wrapper
result.writeline("'''\n)")
result.splice(
f"""
inductor_entry = CppWrapperCodeCache.load_pybinding(
["std::vector<AtenTensorHandle>"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)})
"""
)
wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
if V.graph.constants:
# Append constants to the input args for cpp wrapper.
# Python wrapper directly gets the value inside the wrapper call
# as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__).
# For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly.
assert all(
isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
), "Expect all constants to be Tensor"
constants_str = f"[{', '.join(V.graph.constants.keys())}]"
wrapper_body += f"""
constants_tensor = {constants_str}
input_tensors.extend(constants_tensor)
"""
# Convert vector of at::Tensor to vector of AtenTensorHandle.
# If we pass at::Tensor, the compilation will be too slow.
wrapper_body += """
input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
"""
# unwrap output tensor back to python scalar
if all(x for x in self.output_is_tensor.values()):
# If no ShapeAsConstantBuffer in the output, directly return the output as tensors
outputs_str = "output_tensors"
else:
outputs = [
f"output_tensors[{i}]"
if self.output_is_tensor[i]
else f"output_tensors[{i}].item()"
for i in range(len(V.graph.graph_outputs))
]
outputs_str = f"[{', '.join(outputs)}]"
wrapper_body += f"""
output_handles = f(input_handles)
output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
return {outputs_str}
"""
# Wrap the func to support setting result._boxed_call = True
result.splice(
f"""
def _wrap_func(f):
def g(args):
{wrapper_body}
return g
call = _wrap_func(inductor_entry)
"""
)
def get_c_shim_func_name(self, kernel):
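# Illustrative mapping (assuming self.device == "cpu"): for a kernel name like
# "at::_ops::convolution::call", the suffix is "call", so the preceding token
# "convolution" is used; with c_shim_version == "2" this yields
# "aoti_torch_cpu_convolution".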
if not config.abi_compatible or kernel.startswith("aoti_torch_"):
return kernel
assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'"
kernel_tokens = kernel.split("::")
kernel_suffix = kernel_tokens[-1]
if kernel_suffix == "call":
kernel_suffix = kernel_tokens[-2]
if config.c_shim_version == "1":
# For sdpa, we need the v2 version since v1 didn't consider optional arg
# FIXME: no need to do this after we switch to the torchgen-ed C shim
if kernel_suffix == "_scaled_dot_product_flash_attention":
shim_fn = "aoti_torch__scaled_dot_product_flash_attention_v2"
elif kernel_suffix.startswith("wrapped_fbgemm"):
assert self.device == "cpu", "Using wrapped_fbgemm outside of CPU!"
shim_fn = f"aoti_torch_cpu_{kernel_suffix}"
else:
shim_fn = f"aoti_torch_{kernel_suffix}"
else:
shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}"
return shim_fn
def generate_c_shim_extern_kernel_call(self, kernel, args):
# In the abi_compatible mode, we call fallback aten ops through a C shim layer
# Setting self.allow_stack_allocation to False because the exchange between
# ArrayRefTensor and at::Tensor is still fragile.
self.allow_stack_allocation = False
wrapped_args = []
for x in args:
pieces = x.split(", ")
for piece in pieces:
# We only really *need* convert_arrayref_tensor_to_tensor for
# ArrayRefTensors. The code flowing into here uses `0` for nullptr,
# which convert_arrayref_tensor_to_tensor would blindly coerce to int,
# so just avoid wrapping integers.
# Using name matching to find tensors is hacky, but fixing all the
# ArrayRefTensor issues is not a priority for now.
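# e.g. (illustrative): a piece named "buf0" becomes
# "convert_arrayref_tensor_to_tensor(buf0)", while a literal such as "0"
# is passed through unchanged.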
if isinstance(piece, str) and piece.startswith(
("buf", "arg", "wrap_with_raii_handle_if_needed")
):
piece = f"convert_arrayref_tensor_to_tensor({piece})"
wrapped_args.append(piece)
shim_fn = self.get_c_shim_func_name(kernel)
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));"
)
def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
# registered output buffer name
name = extern_kernel.name
output_handle_name = f"{name}_handle"
self.writeline(f"AtenTensorHandle {output_handle_name};")
output_arg = f"&{output_handle_name}"
self.generate_c_shim_extern_kernel_call(
extern_kernel.get_kernel_name(), args + [output_arg]
)
self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
def generate_extern_kernel_alloc(self, extern_kernel, args):
if config.abi_compatible:
self.generate_c_shim_extern_kernel_alloc(extern_kernel, args)
else:
super().generate_extern_kernel_alloc(extern_kernel, args)
def generate_c_shim_fallback_kernel(self, fallback_kernel, args):
output_args = []
output_raii_handles = []
output_name_base = fallback_kernel.get_name()
for idx, output in enumerate(fallback_kernel.outputs):
if isinstance(output, ir.MultiOutput):
# TODO: handle integer output (e.g., as in attention)
name = f"{output.get_name()}"
output_handle_name = f"{name}_handle"
if output.indices:
assert (
output.indices[0][1] == idx
), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
self.writeline(f"AtenTensorHandle {output_handle_name};")
output_args.append(f"&{output_handle_name}")
output_raii_handles.append(
f"RAIIAtenTensorHandle {name}({output_handle_name});"
)
elif isinstance(output, int):
output_name = f"{output_name_base}_{idx}"
self.writeline(f"int64_t {output_name} = {output};")
output_args.append(f"&{output_name}")
elif isinstance(output, sympy.Symbol):
output_name = f"{output_name_base}_{idx}"
self.writeline(f"auto {output_name} = {output};")
output_args.append(f"&{output_name}")
elif output is None:
output_args.append("nullptr")
else:
raise NotImplementedError(f"unsupported type of {output=}")
args = args + output_args
self.generate_c_shim_extern_kernel_call(fallback_kernel.cpp_kernel_name, args)
for raii_handle in output_raii_handles:
self.writeline(raii_handle)
def generate_fallback_kernel(self, fallback_kernel, args):
if config.abi_compatible:
self.generate_c_shim_fallback_kernel(fallback_kernel, args)
else:
super().generate_fallback_kernel(fallback_kernel, args)
def generate_extern_kernel_out(
self, kernel: str, out: str, out_view: Optional[str], args: List[str]
):
if out_view:
out_name = f"{out}_as_strided"
self.writeline(f"auto {out_name} = {out_view};")
args.insert(0, out_name)
else:
args.insert(0, out)
if config.abi_compatible:
self.generate_c_shim_extern_kernel_call(kernel, args)
else:
self.writeline(self.wrap_kernel_call(kernel, args))
def generate_user_defined_triton_kernel(
self, kernel_name, grid, configs, args, triton_meta, arg_types=None
):
assert len(grid) != 0
if len(grid) == 1:
grid_decision = grid[0]
else:
meta = CudaKernelParamCache.get(kernel_name)
assert meta is not None
grid_decision = None
for i, c in enumerate(configs):
if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()):
grid_decision = grid[i]
break
assert grid_decision is not None
current_device = V.graph.scheduler.get_current_device_or_throw()
self.generate_kernel_call(
kernel_name,
args,
arg_types=arg_types,
grid=grid_decision,
device_index=current_device.index,
cuda=True,
triton=True,
triton_meta=triton_meta,
)
def generate_scatter_fallback(
self,
output,
inputs,
cpp_kernel_name,
python_kernel_name,
src_is_tensor,
reduce,
kwargs,
):
# No stack allocation when there is a fallback op
self.allow_stack_allocation = False
# TODO: needs updates to use C shim v2
if config.abi_compatible:
# call the ABI shim function instead of the ATen one
if config.c_shim_version == "1":
cpp_kernel_name = (
"aoti_torch_scatter_reduce_out"
if python_kernel_name.startswith("aten.scatter_reduce")
else "aoti_torch_scatter_out"
)
else:
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name)
# The C shim only contains the out-variant, not the in-place variant
cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
inputs_wrapped = [
f"convert_arrayref_tensor_to_tensor({x})"
if isinstance(x, str)
else str(x)
for x in inputs
]
line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
else:
line = f"{cpp_kernel_name}({','.join(map(str, inputs))}"
if python_kernel_name.startswith("aten.scatter_reduce"):
line += f", {','.join(kwargs)}"
else:
if src_is_tensor:
if reduce:
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
else:
assert (
reduce is None
), "Expect reduce to be None for aten.scatter_ with scalar src"
line += ");"
self.writeline(line)
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
# No stack allocation when there is a fallback op
self.allow_stack_allocation = False
# TODO: needs updates to use C shim v2
if config.abi_compatible:
# See the comment in codegen_reinterpret_view about why having something like
# RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
# tensor to be prematurely deallocated, hence the std::vector().data() trick here.
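# Illustrative result (buffer names are hypothetical): for indices = ["buf1", "buf2"],
# indices_str becomes
#   std::vector<AtenTensorHandle>{convert_arrayref_tensor_to_tensor(buf1),
#       convert_arrayref_tensor_to_tensor(buf2)}.data()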
indices_str = (
"std::vector<AtenTensorHandle>{"
+ (
", ".join(
[f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
)
)
+ "}.data()"
)
args = [
f"convert_arrayref_tensor_to_tensor({x})",
indices_str,
str(len(indices)),
f"convert_arrayref_tensor_to_tensor({values})",
accumulate,
]
args.insert(
0, f"convert_arrayref_tensor_to_tensor({x})"
) # set x as the output tensor, this fallback mutates x.
else:
indices_str = (
f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
)
args = [x, indices_str, values, accumulate]
args.insert(0, x) # set x as the output tensor, this fallback mutates x.
self.writeline(self.wrap_kernel_call(kernel, args))
def add_benchmark_harness(self, output):
if V.graph.aot_mode:
return
super().add_benchmark_harness(output)
def codegen_sizevar(self, x: Expr) -> str:
return self.expr_printer(V.graph.sizevars.simplify(x))
def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
if config.abi_compatible:
# in the abi_compatible mode, outputs are returned via arguments
return name
else:
return f"std::get<{index}>({basename})"
def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
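# e.g., an empty shape yields "{}", a 1-D shape (8,) yields "{8, }",
# and (2, 3) yields "{2, 3}"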
parts = list(map(self.codegen_sizevar, shape))
if len(parts) == 0:
return "{}"
if len(parts) == 1:
return f"{{{parts[0]}, }}"
return f"{{{', '.join(parts)}}}"
def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
if config.abi_compatible:
self.codegen_tensor_item(
node.inputs[0].get_dtype(), data, f"{node.sym}_raw"
)
else:
convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace(
"at::k", "to"
)
self.writeline(f"auto {node.sym}_raw = {data}.item().{convert_type}();")
if len(node.keypath) == 0:
self.writeline(f"auto {node.sym} = {node.sym}_raw;")
elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
self.writeline(f"int64_t {node.sym} = {node.sym}_raw ? 1 : 0;")
elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
# TODO: assert divisibility here
self.writeline(
f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};"
)
else:
raise AssertionError(f"unrecognized keypath {node.keypath}")
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.sym))
def can_stack_allocate_buffer(self, buffer):
return (
self.allow_stack_allocation
and buffer.get_device().type == "cpu"
and self.can_prove_buffer_has_static_shape(buffer)
and ir.is_contiguous_strides_for_shape(
buffer.get_stride(), buffer.get_size()
)
)
def make_buffer_free(self, buffer):
return (
""
if isinstance(buffer.get_layout(), ir.MultiOutputLayout)
or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
or (
config.use_minimal_arrayref_interface
and V.graph.aot_mode
and buffer.get_name() in V.graph.graph_inputs
)
else f"{buffer.get_name()}.reset();"
)
def make_free_by_names(self, names_to_del: List[str]):
return " ".join(f"{name}.reset();" for name in names_to_del)
def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
if config.abi_compatible:
return f"auto {new_name} = std::move({old_name}); // reuse"
else:
return super().codegen_exact_buffer_reuse(old_name, new_name, del_line)
def generate_profiler_mark_wrapper_call(self, stack):
self.wrapper_call.writeline(
'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
)
def write_triton_header_once(self):
pass
def generate_start_graph(self):
pass
def generate_end_graph(self):
pass
def generate_inf_and_nan_checker(self, nodes):
for buf in nodes.get_names():
# TODO: Add buf name directly into check_inf_and_nan.
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));"
)
def codegen_device(self, device):
if config.abi_compatible:
self.used_cached_devices.add(device.type)
return f"cached_torch_device_type_{device.type}, {device.index if device.index else 0}"
else:
return (
f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})"
if device.index is not None
else f"{DEVICE_TO_ATEN[device.type]}"
)
def codegen_dtype(self, dtype):
if config.abi_compatible:
dtype_str = str(dtype).split(".")[-1]
self.used_cached_dtypes.add(dtype_str)
return f"cached_torch_dtype_{dtype_str}"
else:
return DTYPE_TO_ATEN[dtype]
def codegen_layout(self, layout):
if config.abi_compatible:
layout_str = str(layout).split(".")[-1]
self.used_cached_layouts.add(layout_str)
return f"cached_torch_layout_{layout_str}"
else:
return LAYOUT_TO_ATEN[layout]
@functools.lru_cache(None) # noqa: B019
def codegen_int_array_var(
self,
int_array: str,
writer=None,
known_statically=False,
graph=None, # for per-graph caching
):
# This is used for size/stride declaration
# Because the memory planning is done in two passes (see the implementation
# of self.generate), the writeline behavior is different in the two passes.
# As a result, the emitted int array declarations may appear in a later
# position of the generated code, so the second pass codegen should not
# reuse int array declarations generated in the first pass
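# Illustrative emitted declarations (variable names are generated):
#   static constexpr int64_t int_array_0[] = {8, 16};  // known_statically=True
#   const int64_t int_array_1[] = {s0, 16};            // known_statically=False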
if writer is None:
# The first pass codegen uses `self` as the writer
writer = self
var = f"int_array_{next(self.int_array_id)}"
ctype = "int64_t"
if var not in self.declared_int_array_vars:
self.declared_int_array_vars.add(var)
if known_statically:
writer.writeline(f"static constexpr {ctype} {var}[] = {int_array};")
else:
writer.writeline(f"const {ctype} {var}[] = {int_array};")
return var
def make_buffer_allocation(self, buffer):
return self.make_allocation(
buffer.get_name(),
buffer.get_device(),
buffer.get_dtype(),
buffer.get_size(),
buffer.get_stride(),
buffer if self.can_stack_allocate_buffer(buffer) else None,
)
def make_allocation(
self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None
):
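# Illustrative emitted code in ABI-compatible AOT mode (buffer name is hypothetical):
#   AtenTensorHandle buf0_handle;
#   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
#       2, int_array_0, int_array_1, cached_torch_dtype_float32,
#       cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
#   RAIIAtenTensorHandle buf0(buf0_handle);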
orig_stride = stride
device_str = self.codegen_device(device)
dtype_code = self.codegen_dtype(dtype)
size = self.codegen_shape_tuple(shape)
stride = self.codegen_shape_tuple(orig_stride)
if config.abi_compatible:
size_array_var = self.codegen_int_array_var(
size,
self.wrapper_call,
known_statically=self.is_statically_known_list_of_ints(shape),
graph=self.get_codegened_graph(),
)
stride_array_var = self.codegen_int_array_var(
stride,
self.wrapper_call,
known_statically=self.is_statically_known_list_of_ints(orig_stride),
graph=self.get_codegened_graph(),
)
device_type, device_id = device_str.split(",")
device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
if buffer_if_can_stack_allocate is not None:
self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
cpp_type = DTYPE_TO_CPP[dtype]
numel = buffer_if_can_stack_allocate.get_numel()
# Note: we don't zero storage because empty_strided doesn't zero either.
self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
args = [
f"{name}_storage",
size_array_var,
stride_array_var,
device_type,
device_idx,
]
return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"
args = [
str(len(shape)),
size_array_var,
stride_array_var,
dtype_code,
device_type,
device_idx,
f"&{name}_handle",
]
self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
self.wrapper_call.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));"
)
return f"RAIIAtenTensorHandle {name}({name}_handle);"
if V.graph.aot_mode and device_str.startswith("c10::Device("):
tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)"
else:
tensor_device = device_str
if device.type == "cpu":
return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});"
if device.type == "cuda":
return (
f"at::Tensor {name} = at::detail::empty_strided_cuda("
f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);"
)
return (
f"{self.declare}{name} = {self.namespace}empty_strided("
f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}"
)
def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
if config.abi_compatible:
size = self.codegen_shape_tuple(shape)
stride = self.codegen_shape_tuple(stride)
tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
args = [
name,
self.expr_printer(offset), # bytes not numel
self.codegen_dtype(dtype),
str(len(shape)),
self.codegen_int_array_var(
size, self.wrapper_call, graph=self.get_codegened_graph()
),
self.codegen_int_array_var(
stride, self.wrapper_call, graph=self.get_codegened_graph()
),
f"&{tmp_name}",
]
self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
self.wrapper_call.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));"
)
return f"RAIIAtenTensorHandle({tmp_name})"
return "alloc_from_pool({})".format(
", ".join(
[
name,
self.expr_printer(offset), # bytes not numel
self.codegen_dtype(dtype),
self.codegen_shape_tuple(shape),
self.codegen_shape_tuple(stride),
]
)
)
def codegen_reinterpret_view(
self, data, size_list, stride_list, offset, writer
) -> str:
dim = str(len(size_list))
size = self.codegen_shape_tuple(size_list)
stride = self.codegen_shape_tuple(stride_list)
offset = self.codegen_sizevar(offset)
if config.abi_compatible:
tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
# Because the memory planning is done in two passes (see the implementation
# of self.generate), the writeline behavior is different in the two passes.
if writer is None:
writer = self
args = [
f"{data.get_name()}",
dim,
self.codegen_int_array_var(
size,
writer,
known_statically=self.is_statically_known_list_of_ints(size_list),
graph=self.get_codegened_graph(),
),
self.codegen_int_array_var(
stride,
writer,
known_statically=self.is_statically_known_list_of_ints(stride_list),
graph=self.get_codegened_graph(),
),
offset,
]
def gen_reinterpret_call(writer, args):
writer.writeline(
f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});"
)
if (
self.can_stack_allocate_buffer(data)
and self.is_statically_known_list_of_ints(size_list)
and self.is_statically_known_list_of_ints(stride_list)
and ir.is_contiguous_strides_for_shape(stride_list, size_list)
):
gen_reinterpret_call(writer, args)
return tmp_name
gen_reinterpret_call(writer, args)
# NB, the return handle here represents a temporary tensor, which will be automatically
# released.
# Here's a sample usage in the cpp wrapper code:
# ```
# aoti_torch_addmm_out(
# buf1,
# arg1_1,
# RAIIAtenTensorHandle(tmp_tensor_handle_0),
# buf0,
# 1L,
# 1L));
# ```
# RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
# This could be problematic when it's used in a different pattern, for example:
# ````
# AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
# aoti_torch_proxy_executor_call_function(..., tensor_args);
# ````
# RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
# kernel call.
#
# This is solved by updating the proxy_executor invocation to
# ```
# aoti_torch_proxy_executor_call_function(...,
# std::vector<AtenTensorHandle>{
# RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
# }.data()
# );
# ```
return f"wrap_with_raii_handle_if_needed({tmp_name})"
else:
args = [data.get_name(), size, stride, offset]
return f"reinterpret_tensor({', '.join(args)})"
def codegen_device_copy(self, src, dst):
if config.abi_compatible:
# aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
# while stack-allocation results in ArrayRefTensor
# so disable stack allocation here
self.allow_stack_allocation = False
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
)
else:
self.writeline(f"{dst}.copy_({src});")
def codegen_multi_output(self, name, value):
# in the abi_compatible mode, outputs are retrieved by passing
# output pointers, so we skip its codegen here.
if not config.abi_compatible:
super().codegen_multi_output(name, value)
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
if config.abi_compatible:
# in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
# input (outer_input) into another at::Tensor to be used as a subgraph input
# (inner_input) in the nested scope. we can't std::move here, as the codegened
# outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
# can't necessarily std::move it back to the origin (x).
self.writeline(f"AtenTensorHandle {inner_input}_handle;")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
)
self.writeline(
f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);"
)
else:
self.writeline(
f"{self.declare}{inner_input} = {outer_input}{self.ending}"
)
def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
for inner_output, outer_output in zip(
subgraph.graph.graph_outputs, outer_outputs
):
src = inner_output.codegen_reference()
if config.abi_compatible:
# in ABI-compatible mode, we need to std::move subgraph output (inner_output)
# to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
# constructor is deleted.
src = f"std::move({src})"
# in case the outer_output carried a value
# before (e.g., in the while_loop codegen)
self.writeline(f"{outer_output}.reset();")
self.writeline(f"{outer_output} = {src}{self.ending}")
def codegen_conditional(self, conditional):
name = conditional.get_name()
outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands]
if config.abi_compatible:
outer_outputs = []
for out in conditional.outputs:
# in ABI-compatible mode, ir.MultiOutput is not codegened,
# hence pre-declare output variables directly and separately
self.writeline(f"RAIIAtenTensorHandle {out.get_name()};")
outer_outputs.append(out.get_name())
if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
# in ABI-compatible mode, we need to use the ABI shim function
# to extract a C++ bool from the underlying scalar bool Tensor
predicate = f"{conditional.predicate.get_name()}_scalar"
self.codegen_tensor_item(
torch.bool,
conditional.predicate.codegen_reference(),
predicate,
)
else:
# the predicate is not a Tensor: SymBool or Python bool
predicate = conditional.predicate.codegen_reference()
else:
# in non-ABI-compatible mode, we can codegen the conditional outputs
# as array of at::Tensor instances, as the ir.MultiOutput is codegened
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];")
predicate = f"{conditional.predicate.codegen_reference()}"
if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
# move the Tensor predicate to host
predicate = f"{predicate}.item<bool>()"
self.writeline(f"if ({predicate}) {{")
self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
self.writeline(ExitSubgraphLine(self))
self.writeline("} else {")
self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
self.writeline(ExitSubgraphLine(self))
self.writeline("}")
def codegen_while_loop(self, while_loop):
name = while_loop.get_name()
outer_carried_inputs = [
buf.codegen_reference() for buf in while_loop.carried_inputs
]
outer_additional_inputs = [
buf.codegen_reference() for buf in while_loop.additional_inputs
]
cond_result_name = f"{name}_cond_result"
if config.abi_compatible:
self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
cond_outer_inputs = []
for inp, out in zip(outer_carried_inputs, while_loop.outputs):
# in ABI-compatible mode, the carried inputs are codegened
# as buffers outside the while loop and set to the initial
# values. at the end of each while_loop iteration, they
# will be assigned the carried values.
out_name = out.get_name()
self.writeline(f"AtenTensorHandle {out_name}_handle;")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));"
)
self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);")
cond_outer_inputs.append(out_name)
# additional inputs will be assigned within the while_loop
# iteration directly from the corresponding outer graph buffers
cond_outer_inputs.extend(outer_additional_inputs)
else:
self.writeline(f"at::Tensor {cond_result_name};")
self.writeline(f"at::Tensor {name}[{len(outer_carried_inputs)}];")
for i, inp in enumerate(outer_carried_inputs):
# set the initial state before the loop
self.writeline(f"{name}[{i}] = {inp};")
cond_outer_inputs = [
*[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
*outer_additional_inputs,
]
cond_outer_outputs = [cond_result_name]
body_outer_inputs = list(cond_outer_inputs)
body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]
self.writeline("while (1) {")
self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
self.codegen_subgraph(
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
)
if config.abi_compatible:
cond_result = f"{cond_result_name}_scalar"
self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
else:
cond_result = f"{cond_result_name}.item<bool>()"
self.writeline(f"if (!{cond_result}) break;")
self.writeline(ExitSubgraphLine(self))
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.codegen_subgraph(
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
)
self.writeline(ExitSubgraphLine(self))
self.writeline("}")
def generate_extern_kernel_args_decl_if_needed(
self, op_overload, raw_args, output_args
):
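"""
Flatten raw_args of a fallback op into the tensor and int argument lists expected
by the proxy executor, declaring uninitialized output tensor handles for output_args.
Returns (new_tensor_args, new_int_args).
"""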
arg_types = [x.real_type for x in op_overload._schema.arguments]
return_types = [x.type for x in op_overload._schema.returns]
new_tensor_args = []
new_int_args = []
def fill_args(arg, arg_type):
static_arg_types = (
torch.FloatType,
torch.BoolType,
torch.StringType,
torch.Type,
torch.DeviceObjType,
)
inductor_tensor_buffers = (
ir.Buffer,
ir.ReinterpretView,
)
if isinstance(arg_type, torch.TensorType):
assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}"
new_tensor_args.append(f"{arg.codegen_reference()}")
elif isinstance(arg_type, torch.IntType):
# int
new_int_args.append(str(arg))
elif isinstance(arg_type, torch.SymIntType):
# SymInt
expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
new_int_args.append(self.expr_printer(expr))
elif isinstance(arg_type, torch.NumberType):
# Scalar (int, float, or bool)
assert isinstance(arg, (int, float, bool))
# Only treat int Scalar as dynamic
if isinstance(arg, int):
new_int_args.append(str(arg))
elif isinstance(arg_type, torch.ListType):
assert isinstance(arg, (list, tuple))
# List[Tensor]
if isinstance(arg_type.getElementType(), torch.TensorType):
new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg])
# List[Optional[Tensor]]
elif isinstance(
arg_type.getElementType(), torch.OptionalType
) and isinstance(
arg_type.getElementType().getElementType(), torch.TensorType
):
new_tensor_args.extend(
[f"{a.codegen_reference()}" for a in arg if a is not None]
)
# List[int]
elif isinstance(arg_type.getElementType(), torch.IntType):
new_int_args.extend([str(a) for a in arg])
# List[SymInt]
elif isinstance(arg_type.getElementType(), torch.SymIntType):
expressions = [
a.node.expr if isinstance(a, torch.SymInt) else a for a in arg
]
new_int_args.extend(
[self.expr_printer(expr) for expr in expressions]
)
# List[Scalar]
elif isinstance(arg_type.getElementType(), torch.NumberType):
# Only treat int Scalar as dynamic
is_int_type = [isinstance(a, int) for a in arg]
if any(is_int_type):
assert all(
is_int_type
), "AOTInductor only supports int scalars of the same type"
new_int_args.extend([str(a) for a in arg])
else:
assert isinstance(
arg_type.getElementType(), static_arg_types # type: ignore[arg-type]
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
else:
assert isinstance(
arg_type, static_arg_types # type: ignore[arg-type]
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
for arg, arg_type in zip(raw_args, arg_types):
if arg is not None:
if isinstance(arg_type, torch.OptionalType):
fill_args(arg, arg_type.getElementType())
else:
fill_args(arg, arg_type)
def fill_output_arg(arg, return_type):
if isinstance(return_type, torch.TensorType):
self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
)
self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
new_tensor_args.append(f"{arg}")
elif isinstance(return_type, torch.SymIntType):
raise NotImplementedError("NYI support for return type: SymInt")
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.SymIntType
):
raise NotImplementedError("NYI support for return type: List[SymInt]")
else:
raise AssertionError(f"Unsupported return type found: {return_type}")
# TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
for return_type in return_types:
if isinstance(return_type, (torch.TensorType)):
pass
elif isinstance(return_type, torch.OptionalType):
assert isinstance(return_type.getElementType(), torch.TensorType)
elif isinstance(return_type, torch.ListType):
assert isinstance(return_type.getElementType(), torch.TensorType)
else:
raise NotImplementedError(
f"return type {return_type} is not yet supported."
)
for output_arg in output_args:
assert output_arg is not None, "Optional return types are not yet supported"
if isinstance(output_arg, (list, tuple)):
for out in output_arg:
fill_output_arg(out, torch.TensorType.get())
else:
fill_output_arg(output_arg, torch.TensorType.get())
return new_tensor_args, new_int_args
def generate_extern_kernel_alloc_and_find_schema_if_needed(
self,
buf_name: str,
python_kernel_name: str,
cpp_kernel_name: str,
codegen_args: List[str],
cpp_op_schema: str,
cpp_kernel_key: str,
cpp_kernel_overload_name: str = "",
op_overload: Optional[torch._ops.OpOverload] = None,
raw_args=None,
outputs=None,
):
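"""
Codegen a call to a fallback (extern) kernel, routing to the fbcode proxy-executor
path or the OSS path below.
"""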
# No stack allocation when there is a fallback op
self.allow_stack_allocation = False
def extract_output_name(out):
assert out is not None, "None, i.e. an optional output, is not supported"
if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
return out.get_name()
elif isinstance(out, (list, tuple)):
return type(out)(extract_output_name(o) for o in out)
else:
raise AssertionError(f"Unexpected output: {type(out)}")
# output_args has the same pytree structure as outputs
output_args = None
if config.abi_compatible:
output_args = extract_output_name(outputs)
if isinstance(output_args, str):
output_args = [output_args]
if config.is_fbcode():
assert op_overload is not None
assert raw_args is not None
assert outputs is not None
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
cpp_kernel_key,
op_overload,
raw_args,
output_args,
)
else:
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
buf_name,
python_kernel_name,
cpp_kernel_name,
codegen_args,
cpp_op_schema,
cpp_kernel_key,
cpp_kernel_overload_name,
op_overload,
raw_args,
output_args,
)
def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
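"""
Wrap lines_in_scope in a braced C++ scope that holds the Python GIL, emitting
declarations_before_scope first so the results stay visible after the scope ends.
"""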
scoped_lines = IndentedBuffer()
for declaration in declarations_before_scope:
scoped_lines.writeline(declaration)
scoped_lines.writeline("{")
with scoped_lines.indent():
scoped_lines.writeline("py::gil_scoped_acquire acquire;")
scoped_lines.writelines(lines_in_scope.split("\n"))
scoped_lines.writelines("}")
return scoped_lines._lines
def load_custom_op_wrapper(self):
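"""Emit code that imports torch._inductor.codecache.custom_op_wrapper (once) so custom ops can be invoked from C++ via Python."""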
# TODO: need to support control flow
if self.custom_op_wrapper_loaded:
return
lines = """
RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache"));
if (codecache_module.get() == NULL) {
throw std::runtime_error("Failed to load torch._inductor.codecache");
}
custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper");
if (custom_op_wrapper.get() == NULL) {
throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper");
}"""
declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"]
scope_gil_acquire = self.generate_scoped_gil_acquire(
declarations_before_scope, lines
)
self.writelines(scope_gil_acquire)
self.custom_op_wrapper_loaded = True
def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type):
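"""Generate the CPython calls that convert raw_arg into a PyObject and store it at position idx of the py_args_var tuple."""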
def generate_py_arg_inner(raw_arg, arg_type):
if isinstance(arg_type, torch.TensorType):
# Store AtenTensorHandle as void*
return f"PyCapsule_New(reinterpret_cast<void*>({raw_arg.codegen_reference()}.get()), NULL, NULL)"
elif isinstance(arg_type, torch.IntType):
# int
return f"PyLong_FromLongLong({raw_arg})"
elif isinstance(arg_type, torch.SymIntType):
# SymInt
expr = (
raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg
)
return f"PyLong_FromLongLong({self.expr_printer(expr)})"
elif isinstance(arg_type, torch.FloatType):
return f"PyFloat_FromDouble({raw_arg})"
elif isinstance(arg_type, torch.BoolType):
return f"PyBool_FromLong({1 if raw_arg else 0})"
elif isinstance(arg_type, torch.StringType):
return f'PyUnicode_FromString("{raw_arg}")'
else:
raise NotImplementedError(
f"arg type {arg_type} is not yet supported by custom_op_wrapper"
)
lines = ""
if isinstance(arg_type, torch.ListType):
assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list"
lines += f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n"
for i, elem in enumerate(raw_arg):
lines += f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(elem, arg_type.getElementType())});\n"
lines += f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n"
else:
lines += f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(raw_arg, arg_type)});\n"
return lines
def generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
self,
buf_name: str,
python_kernel_name: str,
cpp_kernel_name: str,
codegen_args: List[str],
cpp_op_schema: str,
cpp_kernel_key: str,
cpp_kernel_overload_name: str = "",
op_overload: Optional[torch._ops.OpOverload] = None,
raw_args=None,
output_args: Optional[List[str]] = None,
):
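"""
OSS path for fallback ops: in AOT mode or non-ABI-compatible mode, call the op
through c10::Dispatcher; otherwise call back into Python via custom_op_wrapper
while holding the GIL.
"""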
if V.graph.aot_mode or not config.abi_compatible:
# TODO: update this to use an OSS version of the ProxyExecutor
if cpp_kernel_key not in self.extern_call_ops:
self.writeline(
f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()"
)
self.writeline(
f'\t.findSchemaOrThrow("{cpp_kernel_name}", "{cpp_kernel_overload_name}")'
)
self.writeline(f"\t.typed<{cpp_op_schema}>();")
self.extern_call_ops.add(cpp_kernel_key)
self.writeline(
f"auto {buf_name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
)
else:
# In the JIT mode, because of the ABI-compatible requirement, we can't directly call
# c10::Dispatcher to find the custom op and call it. Instead, we go back to Python
# to invoke this custom op.
self.load_custom_op_wrapper()
assert output_args is not None, "output_args should not be None"
num_args = len(raw_args)
py_args_var = f"py_args_{next(self.arg_var_id)}"
# First arg is always the python op name
lines = f"""
RAIIPyObject {py_args_var}(PyTuple_New({num_args+1}));
if ({py_args_var}.get() == NULL) {{
throw std::runtime_error("PyTuple_New {py_args_var} failed");
}}
PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}"));
"""
assert op_overload is not None, "op_overload should not be None"
for idx, (raw_arg, schema_arg) in enumerate(
zip(raw_args, op_overload._schema.arguments)
):
lines += self.generate_py_arg(
py_args_var, idx + 1, raw_arg, schema_arg.real_type
)
lines += f"""
// Call the custom op in Python
RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var}));
if (py_{buf_name}.get() == NULL) {{
throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed");
}}"""
if len(output_args) == 1:
# result is a single tensor
lines += f"""
{output_args[0]} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));"""
else:
# result is a tuple of tensors
for idx, output_arg in enumerate(output_args):
lines += f"""
{output_arg} =
reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));"""
declarations_before_scope = [
f"RAIIAtenTensorHandle {output_arg};"
for output_arg in output_args
]
scope_gil_acquire = self.generate_scoped_gil_acquire(
declarations_before_scope, lines
)
self.writelines(scope_gil_acquire)
def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
self,
cpp_kernel_key,
op_overload,
raw_args, # contains both args and flatten kwargs
output_args: Optional[List[str]] = None,
):
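"""
fbcode path for fallback ops: pass the flattened int and tensor args to the proxy
executor through aoti_torch_proxy_executor_call_function.
"""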
(
tensor_call_args,
int_call_args,
) = self.generate_extern_kernel_args_decl_if_needed(
op_overload, raw_args, output_args
)
tensor_call_args_str = ", ".join(tensor_call_args)
int_call_args_str = ", ".join(int_call_args)
extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1
self.writeline(
f"aoti_torch_proxy_executor_call_function(proxy_executor, "
f"{extern_kernel_node_index}, "
f"{len(int_call_args)}, "
f"std::vector<int64_t>{{{int_call_args_str}}}.data(), "
f"{len(tensor_call_args)}, "
f"std::vector<AtenTensorHandle>{{{tensor_call_args_str}}}.data());"
)
self.extern_call_ops.add(cpp_kernel_key)
def generate_reset_kernel_saved_flags(self):
pass
def generate_save_uncompiled_kernels(self):
pass
def c_type_for_prim_type(self, type_) -> str:
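"""Map a torch JIT type to the C type used by the ABI-compatible shim, e.g. TensorType -> AtenTensorHandle, IntType/SymIntType -> int64_t."""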
assert (
config.abi_compatible
), "c_type_for_prim_type is only used in ABI compatible mode"
if isinstance(type_, torch.OptionalType):
return f"{self.c_type_for_prim_type(type_.getElementType())}*"
elif isinstance(type_, torch.TensorType):
return "AtenTensorHandle"
elif isinstance(type_, (torch.IntType, torch.SymIntType)):
return "int64_t"
elif isinstance(
type_, (torch.BoolType, torch.SymBoolType, torch.EnumType)
) or repr(type_) in ("ScalarType", "Layout"):
return "int32_t"
elif isinstance(type_, torch.FloatType):
return "double"
else:
raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}")
def val_to_arg_str_for_prim_type(self, val, type_) -> str:
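"""Convert a primitive Python value into the C++ literal/expression used as a kernel argument."""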
# TODO: type_ is not used yet; this is only the first step of a refactoring. Will update this later.
if isinstance(val, bool):
if config.abi_compatible:
return "1" if val else "0"
else:
return "true" if val else "false"
elif isinstance(val, int):
# uint64_t is long on Linux, but long long on MacOS
return f"{val}LL" if sys.platform == "darwin" else f"{val}L"
elif isinstance(val, str):
return f'"{val}"'
elif isinstance(
val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox)
):
return val.codegen_reference()
elif isinstance(val, torch.device):
return self.codegen_device(val)
elif isinstance(val, torch.dtype):
return self.codegen_dtype(val)
elif isinstance(val, float) and val in [float("inf"), float("-inf")]:
if val == float("inf"):
return "std::numeric_limits<float>::infinity()"
else:
return "-std::numeric_limits<float>::infinity()"
elif isinstance(val, (list, tuple)):
# FIXME: This happens because type_ is not always properly set to torch.ListType
return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}"
else:
return repr(val)
def val_to_arg_str(self, val, type_=None) -> str:
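"""
Convert a Python value (including None, Optional, and List values) into the C++
argument string for a kernel call, emitting helper variable declarations as needed.
"""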
if val is None:
# None needs special care. It either represents nullopt or an empty tensor
if config.abi_compatible:
if type_ is None or isinstance(type_, torch.OptionalType):
if type_ is not None and isinstance(
type_.getElementType(),
(
torch.ListType,
torch.TupleType,
torch.DeviceObjType,
),
):
return "0, 0"
else:
return "0" # nullptr is not available in C
elif isinstance(type_, torch.TensorType):
# create an empty tensor, the equivalent of at::Tensor()
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(f"AtenTensorHandle {var_name}_handle;")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));"
)
self.writeline(
f"RAIIAtenTensorHandle {var_name}({var_name}_handle);"
)
return var_name
else:
raise AssertionError("Can not map None to a known data type")
else:
if isinstance(type_, torch.TensorType):
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(f"at::Tensor {var_name} = at::Tensor();")
return var_name
else:
return "std::nullopt"
if isinstance(type_, torch.OptionalType):
element_type = type_.getElementType()
if config.abi_compatible:
if not isinstance(element_type, torch.TensorType):
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(
element_type,
(torch.ListType, torch.TupleType, torch.DeviceObjType),
):
# type_ is something like Optional[List] or Optional[Device]
arg_str = self.val_to_arg_str(val, element_type)
# For datatypes with auxiliary info, we need to hoist out the extra arguments.
# NOTE: This only works if there is one additional argument, though it can easily be generalized.
main_value, aux = arg_str.rsplit(", ")
self.writeline(f"auto {var_name} = {main_value};")
return f"&{var_name}, {aux}"
else:
self.writeline(
f"{self.c_type_for_prim_type(element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};"
)
return f"&{var_name}"
elif config.c_shim_version == "2":
# type_ is Optional[Tensor]
# Similar to other data types, use a pointer to denote an optional tensor arg in the v2 C shim
base_handle = self.val_to_arg_str(val, element_type)
if "wrap_with_raii_handle_if_needed" in base_handle:
# wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
# explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
tmp_var_name = f"var_{next(self.arg_var_id)}"
self.writeline(
f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};"
)
base_handle = tmp_var_name
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(
f"AtenTensorHandle {var_name} = {base_handle}.get();"
)
return f"&{var_name}"
else:
return self.val_to_arg_str(val, element_type)
elif isinstance(type_, torch.ListType):
assert isinstance(
val, (list, tuple)
), f"{val} does not match with arg type {type_}"
element_type = type_.getElementType()
if config.abi_compatible:
assert len(val) > 0, "Empty array is not supported in C"
var_name = f"var_array_{next(self.var_array_id)}"
result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}"
self.writeline(
f"const {self.c_type_for_prim_type(element_type)} {var_name}[] = {result};"
)
# Need to pass the array length because we can't use std::vector
return f"{var_name}, {len(val)}"
else:
return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}"
return self.val_to_arg_str_for_prim_type(val, type_)