#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>
#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
try { \
__VA_ARGS__ \
} catch (const std::exception& e) { \
std::cerr << "Error: " << e.what() << std::endl; \
return AOTI_RUNTIME_FAILURE; \
} catch (...) { \
std::cerr << "Unknown exception occurred." << std::endl; \
return AOTI_RUNTIME_FAILURE; \
} \
return AOTI_RUNTIME_SUCCESS;
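// Illustrative expansion (hypothetical entry point, not part of this file):
// the macro guarantees that C++ exceptions never cross the C ABI boundary and
// that every code path returns an AOTIRuntimeError.
//
//   AOTIRuntimeError SomeEntryPoint(/* ... */) {
//     CONVERT_EXCEPTION_TO_ERROR_CODE({
//       body_that_may_throw();
//     })
//     // expands to a try/catch returning AOTI_RUNTIME_FAILURE on throw,
//     // and falls through to `return AOTI_RUNTIME_SUCCESS;` otherwise.
//   }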
#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \
do { \
AOTI_RUNTIME_CHECK( \
actual_size == expected_size, \
"expected " + std::string(name) + " vector size to be " + \
std::to_string(expected_size) + ", but got " + \
std::to_string(actual_size)); \
} while (0)
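// Illustrative failure (assumed values): passing two inputs to a model that
// expects three trips the check with a message built from the arguments.
//
//   AOTI_VECTOR_SIZE_CHECK(2, 3, "inputs");
//   // -> "expected inputs vector size to be 3, but got 2"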
// AOTInductor uses at::addmm_out, which doesn't support
// arguments that require gradient. For this reason, we
// enforce a no_grad context for the run APIs.
//
// A RAII, thread local (!) guard that enables or disables grad mode upon
// construction, and sets it back to the original value upon destruction.
struct AOTINoGradGuard {
AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) {
aoti_torch_grad_mode_set_enabled(false);
}
~AOTINoGradGuard() {
aoti_torch_grad_mode_set_enabled(prev_mode);
}
bool prev_mode;
};
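// Illustrative scope (not part of this file): grad mode is disabled for the
// lifetime of the guard and restored afterwards, even if an exception unwinds
// the stack.
//
//   {
//     AOTINoGradGuard guard;   // grad mode forced off here
//     // ... run inference ...
//   }                          // original grad mode restored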
extern "C" {
AOTIRuntimeError AOTInductorModelContainerCreate(
AOTInductorModelContainerHandle* container_handle,
size_t num_models,
bool is_cpu,
const char* cubin_dir) {
return AOTInductorModelContainerCreateWithDevice(
container_handle,
num_models,
is_cpu ? "cpu" : "cuda",
cubin_dir);
}
AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
AOTInductorModelContainerHandle* container_handle,
size_t num_models,
const char* device_str,
const char* cubin_dir) {
if (num_models == 0) {
std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
return AOTI_RUNTIME_FAILURE;
}
CONVERT_EXCEPTION_TO_ERROR_CODE({
std::optional<std::string> cubin_dir_opt;
if (cubin_dir != nullptr) {
cubin_dir_opt.emplace(cubin_dir);
}
auto* container = new torch::aot_inductor::AOTInductorModelContainer(
num_models, std::string(device_str), cubin_dir_opt);
*container_handle =
reinterpret_cast<AOTInductorModelContainerHandle>(container);
})
}
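// Illustrative caller-side sketch (local variable names are assumptions):
//
//   AOTInductorModelContainerHandle container = nullptr;
//   AOTIRuntimeError err = AOTInductorModelContainerCreateWithDevice(
//       &container, /*num_models=*/1, "cuda", /*cubin_dir=*/nullptr);
//   if (err != AOTI_RUNTIME_SUCCESS) { /* handle failure */ }
//   // ... use the container ...
//   AOTInductorModelContainerDelete(container);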
AOTIRuntimeError AOTInductorModelContainerDelete(
AOTInductorModelContainerHandle container_handle) {
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
delete container;
});
}
AOTIRuntimeError AOTInductorModelContainerRun(
AOTInductorModelContainerHandle container_handle,
AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
size_t num_inputs,
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
size_t num_outputs,
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
auto stream =
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
AOTINoGradGuard guard;
container->run(
input_handles, output_handles, stream, proxy_executor_handle);
})
}
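// Illustrative run sketch (vector names are assumptions). Per the comments
// above, the callee steals the input handles and the caller takes ownership
// of the handles written into the output array:
//
//   std::vector<AtenTensorHandle> inputs = /* one handle per model input */;
//   std::vector<AtenTensorHandle> outputs(num_outputs, nullptr);
//   AOTIRuntimeError err = AOTInductorModelContainerRun(
//       container, inputs.data(), inputs.size(),
//       outputs.data(), outputs.size(),
//       /*stream_handle=*/nullptr, /*proxy_executor_handle=*/nullptr);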
AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
AOTInductorModelContainerHandle container_handle,
size_t* num_constants) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *num_constants = container->num_constants(); })
}
AOTIRuntimeError AOTInductorModelContainerGetConstantName(
AOTInductorModelContainerHandle container_handle,
size_t idx,
const char** name) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *name = container->constant_name(idx); })
}
AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
AOTInductorModelContainerHandle container_handle,
size_t idx,
const char** original_fqn) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *original_fqn = container->constant_original_fqn(idx); })
}
AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
AOTInductorModelContainerHandle container_handle,
size_t idx,
bool* from_folded) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *from_folded = container->constant_from_folded(idx); })
}
AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
AOTInductorModelContainerHandle container_handle,
size_t idx,
int32_t* dtype) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *dtype = container->constant_dtype(idx); })
}
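// Illustrative sketch enumerating constants via the accessors above
// (variable names are assumptions):
//
//   size_t num_constants = 0;
//   AOTInductorModelContainerGetNumConstants(container, &num_constants);
//   for (size_t i = 0; i < num_constants; ++i) {
//     const char* name = nullptr;
//     int32_t dtype = 0;
//     AOTInductorModelContainerGetConstantName(container, i, &name);
//     AOTInductorModelContainerGetConstantDtype(container, i, &dtype);
//   }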
AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
AOTInductorModelContainerHandle container_handle,
AOTInductorConstantMapHandle constant_map_handle,
bool use_inactive,
bool validate_full_update) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
auto input_map =
reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
constant_map_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
container->update_constant_buffer(
*input_map, use_inactive, validate_full_update);
})
}
AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
AOTInductorModelContainerHandle container_handle,
AOTInductorConstantMapHandle constant_map_handle) {
return AOTInductorModelContainerUpdateConstantBuffer(container_handle,
constant_map_handle,
/*use_inactive*/ true,
/*validate_full_update*/ true);
}
AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
AOTInductorModelContainerHandle container_handle,
bool use_inactive,
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
auto stream =
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
AOTINoGradGuard guard;
container->run_const_fold(use_inactive, stream, proxy_executor_handle);
})
}
AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
AOTInductorModelContainerHandle container_handle) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
container->swap_constant_buffer();
})
}
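// Illustrative double-buffered weight update (handle names are assumptions):
// stage new constants into the inactive buffer, re-run constant folding
// against it, then swap it in as the active buffer:
//
//   AOTInductorModelContainerUpdateInactiveConstantBuffer(container, new_map);
//   AOTInductorModelContainerRunConstantFolding(
//       container, /*use_inactive=*/true, stream, proxy_executor);
//   AOTInductorModelContainerSwapConstantBuffer(container);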
AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
AOTInductorModelContainerHandle container_handle,
size_t* ret_num_inputs) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_num_inputs = container->num_inputs(); })
}
AOTIRuntimeError AOTInductorModelContainerGetInputName(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
const char** ret_input_names) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_input_names = container->input_name(input_idx); })
}
AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
AOTInductorModelContainerHandle container_handle,
size_t* ret_num_outputs) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_num_outputs = container->num_outputs(); })
}
AOTIRuntimeError AOTInductorModelContainerGetOutputName(
AOTInductorModelContainerHandle container_handle,
size_t output_idx,
const char** ret_output_names) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_output_names = container->output_name(output_idx); })
}
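// Illustrative I/O metadata query sketch (variable names are assumptions):
//
//   size_t num_in = 0, num_out = 0;
//   AOTInductorModelContainerGetNumInputs(container, &num_in);
//   AOTInductorModelContainerGetNumOutputs(container, &num_out);
//   const char* first_input_name = nullptr;
//   AOTInductorModelContainerGetInputName(container, 0, &first_input_name);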
AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
AOTInductorModelContainerHandle container_handle,
const char** in_spec,
const char** out_spec) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
*in_spec = container->get_in_spec();
*out_spec = container->get_out_spec();
})
}
AOTIRuntimeError AOTInductorModelCreate(
AOTInductorModelHandle* model_handle,
AOTInductorConstantMapHandle constant_map_handle) {
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
auto constant_array = std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
auto input_map = reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(constant_map_handle);
auto model = new torch::aot_inductor::AOTInductorModel(
constant_map,
constant_array,
"cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models
""
);
if (input_map) {
for (auto const& kv : *input_map) {
constant_map->emplace(kv.first, kv.second);
}
} else {
model->load_constants();
}
*model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
})
}
AOTIRuntimeError AOTInductorModelRun(
AOTInductorModelHandle model_handle,
AtenTensorHandle* input_handles,
AtenTensorHandle* output_handles) {
auto model =
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
AOTINoGradGuard guard;
model->run_impl(
input_handles,
output_handles,
(torch::aot_inductor::DeviceStreamType) nullptr,
nullptr);
})
}
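// Illustrative single-model CPU path (no container or stream involved;
// variable names are assumptions):
//
//   AOTInductorModelHandle model = nullptr;
//   AOTInductorModelCreate(&model, /*constant_map_handle=*/nullptr);
//   AOTInductorModelRun(model, inputs.data(), outputs.data());
//   AOTInductorModelDelete(model);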
AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle) {
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(
model_handle);
delete model;
})
}
AOTIRuntimeError AOTInductorModelGetNumOutputs(
AOTInductorModelHandle model_handle,
size_t* ret_num_outputs) {
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
*ret_num_outputs = model->num_outputs();
})
}
AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
AOTInductorModelHandle model_handle,
AOTInductorConstantMapHandle constant_map_handle) {
auto model =
reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
auto input_map =
reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
constant_map_handle);
for (auto const& kv : *input_map) {
constant_map->emplace(kv.first, kv.second);
}
model->update_constants_map(std::move(constant_map));
})
}
} // extern "C"