#pragma once
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/onnx/onnx.h>
#include <ostream>
// Forward declaration of the ONNX model protobuf message. This header only
// refers to ModelProto through std::shared_ptr (export_onnx's return value and
// serialize_model_proto_to_string's parameter), so the complete type is not
// required here. Presumably this keeps the generated ONNX protobuf headers out
// of this header's include graph — confirm.
namespace ONNX_NAMESPACE {
class ModelProto;
}
namespace torch {
namespace jit {
// Tracks parameters that should be exported externally. When
// `defer_weight_export` is true, the returned map contains kv pairs that map
// {external reference name} -> {at::Tensor to be exported}. It is the
// responsibility of the caller to export these appropriately.
//
// For example, when exporting to a zip archive, the caller may write out files
// for each entry in the export map, with the filename being the key and the
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
// Maps a symbolic shape dimension (c10::ShapeSymbol) to a string name. It is
// returned as the third element of export_onnx's result. Presumably these are
// the names given to symbolic/dynamic dimensions in the exported model — TODO
// confirm against the implementation.
using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
// Exports a JIT `graph` to an ONNX model.
//
// Returns a tuple of:
//   - the ONNX ModelProto,
//   - a RawDataExportMap of tensors the caller is responsible for exporting
//     (populated when `defer_weight_export` is true; see RawDataExportMap),
//   - a SymbolDimMap from symbolic shape dimensions to their names.
//
// Parameters. Meanings beyond what the types show are assumptions; verify
// them against the implementation:
//   graph                     - the JIT graph to export.
//   initializers              - named tensors; presumably emitted as ONNX
//                               graph initializers.
//   onnx_opset_version        - the target ONNX opset version.
//   dynamic_axes              - presumably {input/output name -> {axis index ->
//                               dimension name}}, marking those axes dynamic.
//   defer_weight_export       - when true, tensors to be exported externally
//                               are returned in the RawDataExportMap.
//   operator_export_type      - operator export mode (defaults to ONNX).
//   strip_doc_string          - presumably omits doc strings from the model.
//   keep_initializers_as_inputs - presumably also lists initializers as graph
//                               inputs.
//   custom_opsets             - presumably {custom domain -> opset version}.
//   add_node_names            - presumably assigns names to exported nodes.
//   use_external_data_format  - presumably stores tensor data in files outside
//                               the model proto, located relative to
//                               `onnx_file_path` — confirm.
//   onnx_file_path            - destination path of the model; empty by default.
TORCH_API std::tuple<
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
RawDataExportMap,
SymbolDimMap>
export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
const std::unordered_map<
std::string,
std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool defer_weight_export = false,
::torch::onnx::OperatorExportTypes operator_export_type =
::torch::onnx::OperatorExportTypes::ONNX,
bool strip_doc_string = true,
bool keep_initializers_as_inputs = true,
const std::map<std::string, int>& custom_opsets = {},
bool add_node_names = true,
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());
// Serializes `model_proto` (e.g. the first element returned by export_onnx)
// into a std::string. Presumably this uses the protobuf binary wire format —
// confirm.
TORCH_API std::string serialize_model_proto_to_string(
const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);
// Checks a serialized ONNX model. `proto_string` is presumably a serialized
// ModelProto, such as the output of serialize_model_proto_to_string. Returns
// nothing; failures are presumably reported by throwing — NOTE(review):
// confirm the error behavior in the implementation.
TORCH_API void check_onnx_proto(const std::string& proto_string);
// For testing purposes.
//
// Returns a string rendering of the ONNX model exported from `graph`. Most
// parameters mirror export_onnx, with these differences:
//   - `defer_weight_export` has no default here;
//   - there is no dynamic-axes, doc-string or external-data option.
// `google_printer` presumably selects protobuf's own text printer instead of
// a custom pretty printer — confirm.
TORCH_API std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type =
::torch::onnx::OperatorExportTypes::ONNX,
bool google_printer = false,
bool keep_initializers_as_inputs = true,
const std::map<std::string, int>& custom_opsets = {},
bool add_node_names = true);
// Serializes `module` to one of three destinations:
//   - an output stream (`out`),
//   - a file path (`filename`),
//   - a writer callback (`writer_func`).
//
// The callback receives a pointer to a buffer and a byte count. It presumably
// writes that many bytes and returns the number it wrote — confirm.
//
// Options shared by all overloads:
//   metadata                - extra files (ExtraFilesMap) to include in the
//                             output; empty by default.
//   bytecode_format         - presumably also emits mobile bytecode — confirm.
//   save_mobile_debug_info  - presumably also saves mobile debug information —
//                             confirm.
TORCH_API void ExportModule(
const Module& module,
std::ostream& out,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false);
TORCH_API void ExportModule(
const Module& module,
const std::string& filename,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false);
TORCH_API void ExportModule(
const Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false);
// Writes the bytes of a pickle archive and the tensors referenced inside that
// archive.
//
//   archive_name - name of the archive within the output container.
//   pickle_bytes - pointer to the pickled data; `size` is its length in bytes.
//   tensors      - the tensors referenced by the pickle.
//   out          - the PyTorchStreamWriter the records are written to.
TORCH_API void writeArchiveAndTensors(
const std::string& archive_name,
const char* pickle_bytes,
size_t size,
const std::vector<at::Tensor>& tensors,
caffe2::serialize::PyTorchStreamWriter& out);
// The surrounding system can install an additional hook that produces extra
// files with metadata based on the environment every time a module is
// serialized. The hook receives the Module being serialized and returns an
// ExtraFilesMap of files to add.
using ExportModuleExtraFilesHook = std::function<ExtraFilesMap(const Module&)>;
// Installs `hook` as the extra-files hook described above. Whether this
// replaces a previously installed hook is not visible here — presumably it
// does; confirm.
TORCH_API void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook);
// Callback that converts a Module plus a string -> string map into a
// c10::Dict<std::string, std::string>. Presumably used during export to
// produce mobile-specific metadata from the module and extra files —
// NOTE(review): confirm the meaning of the input map against the caller.
using ExportModuleMobileInfoConverter =
std::function<c10::Dict<std::string, std::string>(
const Module&,
const std::unordered_map<std::string, std::string>&)>;
// Installs `converter` as the mobile info converter described above.
TORCH_API void SetExportModuleMobileInfoConverter(
ExportModuleMobileInfoConverter converter);
// Returns a list of the names of all operators used in module `m` and its
// submodules.
TORCH_API std::vector<std::string> export_opnames(const Module& m);
namespace mobile {
// Forward declaration; only a reference to mobile::Module is needed below.
class Module;
/**
 * Given a torch::jit::mobile::Module, return a set of operator names
 * (with overload name) that are used by any method in this mobile
 * module. This method runs through the bytecode for all methods
 * in the specified model (module), and extracts all the root
 * operator names. Root operators are operators that are called
 * directly by the model (as opposed to non-root operators, which
 * may be called transitively by the root operators).
 *
 * @param module the mobile module to inspect (taken by non-const reference).
 * @return the set of root operator names, including overload names.
 */
TORCH_API std::set<std::string> _export_operator_list(
torch::jit::mobile::Module& module);
} // namespace mobile
} // namespace jit
} // namespace torch