"""This file exports ONNX ops for opset 14.
Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
HardSwish, Trilu
Updated operators:
Reshape
Add, Sub, Mul, Div
GRU, LSTM, RNN
BatchNormalization, CumSum, Relu
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools
import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# Opset-14 flavour of the registration decorator: `functools.partial` pins
# `opset=14`, so each `@_onnx_symbolic("<op name>")` use below forwards that
# opset to `registration.onnx_symbolic`.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
    """Export `aten::hardswish` as a single ONNX `HardSwish` node.

    `HardSwish` is one of the operators introduced in opset 14.

    Args:
        g: Graph context used to emit the ONNX node.
        self: Input value the activation is applied to.

    Returns:
        The result of `g.op("HardSwish", ...)`.
    """
    hardswish_node = g.op("HardSwish", self)
    return hardswish_node
@_onnx_symbolic("aten::tril")
@_beartype.beartype
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
    """Export `aten::tril` as an ONNX `Trilu` node that keeps the lower triangle.

    Args:
        g: Graph context used to emit the ONNX node.
        self: Input value whose lower triangle is retained.
        diagonal: Value forwarded unchanged as `Trilu`'s second input
            (the diagonal offset).
        out: Accepted for signature compatibility; not used by this export.

    Returns:
        The result of `g.op("Trilu", ...)`.
    """
    # Trilu's `upper` attribute picks the retained triangle: 0 -> lower, 1 -> upper.
    keep_upper = 0
    return g.op("Trilu", self, diagonal, upper_i=keep_upper)
@_onnx_symbolic("aten::triu")
@_beartype.beartype
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
    """Export `aten::triu` as an ONNX `Trilu` node that keeps the upper triangle.

    Args:
        g: Graph context used to emit the ONNX node.
        self: Input value whose upper triangle is retained.
        diagonal: Value forwarded unchanged as `Trilu`'s second input
            (the diagonal offset).
        out: Accepted for signature compatibility; not used by this export.

    Returns:
        The result of `g.op("Trilu", ...)`.
    """
    # Trilu's `upper` attribute picks the retained triangle: 0 -> lower, 1 -> upper.
    keep_upper = 1
    return g.op("Trilu", self, diagonal, upper_i=keep_upper)
@_onnx_symbolic("aten::reshape")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
    """Export `aten::reshape` through the shared reshape helper.

    Opset 14 introduced the `allowzero` attribute on ONNX `Reshape`. Because of
    an ONNX Runtime bug (https://github.com/microsoft/onnxruntime/issues/10664),
    Reshape export cannot use it yet, so it is always passed as 0.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Input value to reshape.
        shape: Target shape value.

    Returns:
        The result of `symbolic_helper._reshape_helper`.
    """
    allowzero = 0  # pinned off until the ORT issue above is resolved
    return symbolic_helper._reshape_helper(g, self, shape, allowzero=allowzero)
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Export `aten::batch_norm` as an ONNX `BatchNormalization` node.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Value to normalize.
        weight, bias: Affine parameters. Together with the running statistics,
            they are passed through `symbolic_helper._batchnorm_helper` before use.
        running_mean, running_var: Running statistics.
        training: Parsed as an int; a truthy value selects training mode.
        momentum: PyTorch-convention momentum, emitted as `1 - momentum`.
        eps: Forwarded as the `epsilon` attribute.
        cudnn_enabled: Parsed but not used by this export.

    Returns:
        The normalized output value. When autocast is enabled, the inputs have
        differing dtypes, and the target opset is below 15, it instead returns
        the result of `symbolic_helper._onnx_opset_unsupported_detailed`
        (presumably an unsupported-opset error; verify in symbolic_helper).
    """
    # Before opset 15, every BatchNormalization input must share one dtype,
    # and autocast can break that. Report the problem instead of exporting.
    autocast_dtype_mismatch = (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    )
    if autocast_dtype_mismatch:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            14,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    symbolic_helper.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )

    # ONNX's momentum weights the running statistic, whereas PyTorch's weights
    # the batch statistic, so the ONNX value is the complement.
    onnx_momentum = 1 - momentum
    bn_outputs = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=onnx_momentum,
        training_mode_i=1 if training else 0,
        # In training mode the node also yields the updated running mean/var.
        outputs=3 if training else 1,
    )
    if training:
        normalized, updated_mean, updated_var = bn_outputs
        # Give the running-stat outputs the same types as the incoming stats;
        # only the normalized output is returned to the caller.
        updated_mean.setType(running_mean.type())
        updated_var.setType(running_var.type())
        return normalized
    return bn_outputs
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export `quantized::hardswish` as dequantize -> `hardswish` -> quantize.

    Args:
        g: Graph context used to emit ONNX nodes.
        x: Quantized input. Only the first element of the 4-tuple returned by
            `symbolic_helper.dequantize_helper` is used.
        op_scale, op_zero_point: Forwarded to `symbolic_helper.quantize_helper`
            to requantize the activation output.

    Returns:
        The result of `symbolic_helper.quantize_helper`.
    """
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = hardswish(g, dequantized)
    return symbolic_helper.quantize_helper(
        g, activated, op_scale, op_zero_point
    )