"""Options for FX exporter."""
from __future__ import annotations
import dataclasses
from typing import Callable, Dict
import torch
from torch.onnx import _constants
from torch.onnx._internal.fx import function_dispatcher
@dataclasses.dataclass
class ExportOptions:
"""Options for FX-ONNX export.
Attributes:
opset_version: The export ONNX version.
use_binary_format: Whether to Return ModelProto in binary format.
decomposition_table: The decomposition table for graph ops. Default is for torch ops, including aten and prim.
op_level_debug: Whether to export the model with op level debug information with onnxruntime evaluator.
"""
opset_version: int = _constants.ONNX_DEFAULT_OPSET
use_binary_format: bool = True
op_level_debug: bool = False
decomposition_table: Dict[torch._ops.OpOverload, Callable] = dataclasses.field(
default_factory=lambda: function_dispatcher._ONNX_FRIENDLY_DECOMPOSITION_TABLE
)
def update(self, **kwargs):
for key, value in kwargs.items():
if hasattr(self, key):
if value is not None:
setattr(self, key, value)
else:
raise KeyError(f"ExportOptions has no attribute {key}")