Repository URL to install this package:
|
Version:
1.23.0 ▾
|
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import torch
from torch._C._onnx import OperatorExportTypes
TrainingMode = torch.onnx.TrainingMode
from packaging.version import Version # noqa: E402
def torch_onnx_export(
model,
args,
f,
export_params=True,
verbose=False,
training=TrainingMode.EVAL,
input_names=None,
output_names=None,
operator_export_type=OperatorExportTypes.ONNX,
opset_version=None,
_retain_param_name=None,
do_constant_folding=True,
example_outputs=None,
strip_doc_string=None,
dynamic_axes=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
enable_onnx_checker=None,
use_external_data_format=None,
export_modules_as_functions=False,
):
if Version(torch.__version__) >= Version("1.11.0"):
torch.onnx.export(
model=model,
args=args,
f=f,
export_params=export_params,
verbose=verbose,
training=training,
input_names=input_names,
output_names=output_names,
operator_export_type=operator_export_type,
opset_version=opset_version,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets,
export_modules_as_functions=export_modules_as_functions,
)
else:
torch.onnx.export(
model=model,
args=args,
f=f,
export_params=export_params,
verbose=verbose,
training=training,
input_names=input_names,
output_names=output_names,
operator_export_type=operator_export_type,
opset_version=opset_version,
_retain_param_name=_retain_param_name,
do_constant_folding=do_constant_folding,
example_outputs=example_outputs,
strip_doc_string=strip_doc_string,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets,
enable_onnx_checker=enable_onnx_checker,
use_external_data_format=use_external_data_format,
)