"""
Note [ONNX operators that are added/updated from opset 8 to opset 9]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
Compress
ConstantOfShape
EyeLike
MaxUnpool
OneHot
Sinh
Cosh
Asinh
Acosh
Atanh
Shrink
IsNaN
Sign
Erf
Scatter
Where
NonZero
TfIdfVectorizer
MeanVarianceNormalization
Updated operators:
BatchNormalization: removed spatial attribute.
Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported.
Cast: more data types{string} supported.
Upsample: moved scales from attribute to input.
Scan
"""
import functools
import warnings
import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
# Registration decorator factory pre-bound to opset 8; every symbolic function
# in this module is registered through it.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)

# ATen operator names that are block-listed for this opset. Most of them
# correspond to ONNX operators that the module docstring lists as new in opset 9.
block_listed_operators = (
    "nonzero",
    "where",
    "scatter",
    "scatter_add",
    "erf",
    "sign",
    "isnan",
    "gather",
    "arange",
    "masked_fill",
    "index_fill",
    "index_copy",
    "repeat_interleave",
    "any",
    "all",
)

# Register, for each block-listed name, the handler built by
# ``symbolic_helper._block_list_in_opset`` under the ``aten::`` namespace.
# NOTE(review): the handler presumably reports the operator as unsupported in
# this opset -- confirm in ``symbolic_helper``.
for block_listed_op in block_listed_operators:
    _onnx_symbolic(f"aten::{block_listed_op}")(
        symbolic_helper._block_list_in_opset(block_listed_op)
    )
def _apply_params(*args, **kwargs):
    """Build a decorator that invokes the decorated factory with fixed arguments.

    The returned decorator does not wrap its target: it calls the target
    (a higher-order function) with ``*args`` and ``**kwargs`` and hands back
    whatever that call produces.
    """

    def _call_with_bound_params(factory):
        produced = factory(*args, **kwargs)
        return produced

    return _call_with_bound_params
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    """Create the opset-8 symbolic function for one ``aten::upsample_*`` operator.

    Each registration above supplies ``(name, dim, interpolate_mode)`` through
    ``_apply_params``, which calls this factory with those values.

    Args:
        name: Operator name used in "unimplemented" reports.
        dim: Number of dimensions of the input (3, 4 or 5 in the registrations
            above); the first two are never rescaled.
        interpolate_mode: Value for the ``mode`` attribute of the emitted
            ``Upsample`` node.

    Returns:
        A symbolic function ``(g, input, output_size, *args)`` that emits an
        ``Upsample`` node with static ``scales``.
    """

    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        output_size = symbolic_helper._maybe_get_const(output_size, "is")
        if symbolic_helper._is_value(output_size):
            # Pass ``input`` like the align_corners report above so both
            # failures carry the same context.
            return symbolic_helper._unimplemented(
                name, "torch._C.Value (output_size) indexing", input
            )

        if scales is None:
            # Scales must be derived from static sizes. Look the sizes up once
            # and bail out cleanly when they are unavailable instead of
            # failing with a TypeError while subscripting ``None``.
            # NOTE(review): ``sizes()`` is expected to return None when the
            # static shape is unknown -- confirm.
            input_sizes = input.type().sizes()
            if input_sizes is None:
                return symbolic_helper._unimplemented(
                    name, "missing input shape", input
                )
            # The two leading dimensions keep scale 1.0; each remaining
            # dimension is scaled by output size / input size, both indexed
            # from the end.
            scales = [
                1.0
                if i < 2
                else float(output_size[i - dim]) / float(input_sizes[i - dim])
                for i in range(dim)
            ]
        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)

    return symbolic_fn
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as an ``Upsample`` node with static scales.

    The export is reported as unimplemented when ``align_corners`` is set, or
    when ``scale_factor`` or ``size`` is still a graph value instead of a
    constant. ``recompute_scale_factor`` and ``antialias`` are accepted for
    signature compatibility but are not read here.
    """
    is_none = symbolic_helper._is_none

    align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
    corners_requested = not is_none(align_corners) and align_corners
    if corners_requested:
        return symbolic_helper._unimplemented("interpolate", "align_corners == True")

    # Neither the scales nor the target size may be dynamic; check them in
    # that order.
    for candidate, reason in (
        (scale_factor, "dynamic scales in opset 8"),
        (size, "dynamic size in opset 8"),
    ):
        if not is_none(candidate) and symbolic_helper._is_value(candidate):
            return symbolic_helper._unimplemented("interpolate", reason)

    static_scales, resolved_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, mode_s=resolved_mode, scales_f=static_scales)
# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
# is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
    """Cast every argument to Float when the first one has a known non-float type.

    The decision is based only on the scalar type of ``args[0]``:

    * unknown type: a warning is issued and the arguments are returned as-is;
    * HALF / FLOAT / DOUBLE: the arguments are returned as-is;
    * anything else: each argument is wrapped in a ``Cast`` to FLOAT.

    Returns:
        A tuple ``(old_type, *args)``. ``old_type`` is the scalar name of the
        original type when a cast was inserted, otherwise ``None``.
    """
    scalar_types = _type_utils.JitScalarType
    float_types = {scalar_types.HALF, scalar_types.FLOAT, scalar_types.DOUBLE}

    first_type = scalar_types.from_value(args[0], scalar_types.UNDEFINED)

    if first_type == scalar_types.UNDEFINED:
        warnings.warn(
            "Only floating datatype is supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the onnx model to be incorrect, if inputs have integer datatypes."
        )
        return (None,) + args

    if first_type in float_types:
        # Already floating point: nothing to cast, nothing to restore later.
        return (None,) + args

    # Remember the original type name before inserting the casts so callers
    # can cast the result back.
    original_name = first_type.scalar_name()
    float_args = tuple(
        g.op("Cast", value, to_i=_C_onnx.TensorProtoDataType.FLOAT) for value in args
    )
    return (original_name,) + float_args
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
    """Cast ``input`` with the opset-9 ``_cast_<to_type>`` helper.

    ``to_type`` is the type-name suffix used to look up the helper on
    ``opset9``; when it is ``None`` the input is returned unchanged.
    """
    if to_type is not None:
        cast_fn = getattr(opset9, f"_cast_{to_type}")
        return cast_fn(g, input, False)
    return input
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
    """Emit the comparison node ``op_name`` for ``input`` and ``other``.

    ``other`` is first reduced to a scalar where possible and matched to the
    type of ``input``; both operands then go through
    ``_try_cast_integer_to_float``. The result is not cast back.
    """
    rhs = symbolic_helper._if_scalar_type_as(
        symbolic_helper._maybe_get_scalar(other), input
    )
    _old_type, lhs, rhs = _try_cast_integer_to_float(g, input, rhs)
    return g.op(op_name, lhs, rhs)
# NOTE: For the symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer input types are not supported in opset 8. Cast to float if possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::gt`` as an ONNX ``Greater`` node."""
    greater = _comparison_operator(g, input, other, "Greater")
    return greater
@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::lt`` as an ONNX ``Less`` node."""
    less = _comparison_operator(g, input, other, "Less")
    return less
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::bmm`` as an ONNX ``MatMul`` node.

    When the scalar type of ``self`` is known, both operands go through
    ``_try_cast_integer_to_float`` and the product is cast back to the
    original type if a cast was inserted. When the type is unknown the
    operands are used unchanged.
    """
    old_type = None
    # ``_try_get_scalar_type`` signals "unknown" with None. Compare against
    # None explicitly: a plain truthiness test also rejects valid scalar types
    # whose enum value is 0 (``JitScalarType`` is an IntEnum with UINT8 == 0),
    # which silently skipped the float cast for uint8 inputs.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
    # ``_cast_to_type`` returns its input unchanged when ``old_type`` is None.
    return _cast_to_type(g, g.op("MatMul", self, other), old_type)
@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    """Export ``aten::matmul`` by delegating to the ``bmm`` symbolic."""
    product = bmm(g, self, other)
    return product
@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    """Export ``aten::prelu`` as an ONNX ``PRelu`` node.

    ``weight`` is reshaped first so it lines up with ``self``: it is
    unsqueezed on axes ``1 .. rank-2`` when ``self`` has rank > 2, or squeezed
    when ``self`` is a scalar and ``weight`` has sizes ``[1]``. When the
    scalar type of ``self`` is known, both operands go through
    ``_try_cast_integer_to_float`` and the result is cast back if a cast was
    inserted.
    """
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    if self_rank is not None and self_rank > 2:
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
    elif self_rank == 0 and weight_sizes == [1]:
        # self and weight are both scalar but weight has rank == 1, squeeze weight.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])

    old_type = None
    # Compare against None explicitly: a truthiness test also rejects valid
    # scalar types whose enum value is 0 (``JitScalarType`` is an IntEnum with
    # UINT8 == 0), which silently skipped the float cast for uint8 inputs.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
    # ``_cast_to_type`` returns its input unchanged when ``old_type`` is None.
    return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mm`` as an ONNX ``Gemm`` node with ``beta = 0``.

    When the scalar type of ``self`` is known, all three Gemm operands go
    through ``_try_cast_integer_to_float`` and the result is cast back if a
    cast was inserted.

    Raises:
        errors.SymbolicValueError: If ``_try_get_scalar_type(self, other)``
            finds no known scalar type.
    """
    scalar_type = symbolic_helper._try_get_scalar_type(self, other)
    if scalar_type is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )
    # Create a dummy C tensor. It is only needed to satisfy the Gemm
    # signature; its value is ignored since beta = 0.
    zero_constant = g.op(
        "Constant",
        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
    )

    old_type = None
    # Compare against None explicitly: a truthiness test also rejects valid
    # scalar types whose enum value is 0 (``JitScalarType`` is an IntEnum with
    # UINT8 == 0), which silently skipped the float cast for uint8 inputs.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, other, zero_constant = _try_cast_integer_to_float(
            g, self, other, zero_constant
        )
    gemm = g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)
    # ``_cast_to_type`` returns its input unchanged when ``old_type`` is None.
    return _cast_to_type(g, gemm, old_type)
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    """Export ``aten::addmm`` as an ONNX ``Gemm`` node.

    ``mat1`` and ``mat2`` are the Gemm A and B operands and ``self`` is the C
    operand; ``beta`` and ``alpha`` become the ``beta`` / ``alpha``
    attributes. When the scalar type of ``self`` is known, the three tensor
    operands go through ``_try_cast_integer_to_float`` and the result is cast
    back if a cast was inserted.
    """
    old_type = None
    # Compare against None explicitly: a truthiness test also rejects valid
    # scalar types whose enum value is 0 (``JitScalarType`` is an IntEnum with
    # UINT8 == 0), which silently skipped the float cast for uint8 inputs.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
    gemm = g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
    # ``_cast_to_type`` returns its input unchanged when ``old_type`` is None.
    return _cast_to_type(g, gemm, old_type)
@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Export ``aten::flatten``, using ONNX ``Flatten`` when the result is 2-D.

    ONNX ``Flatten`` is emitted for two cases: flattening dims ``1 .. dim-1``
    (axis 1) and flattening dims ``0 .. dim-2`` (axis ``dim-1``). Every other
    case is delegated to ``opset9.flatten``. When the scalar type of
    ``input`` is known, it goes through ``_try_cast_integer_to_float`` and the
    result is cast back if a cast was inserted.
    """
    start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
    end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")

    dim = input.type().dim()
    if dim is None:
        # Without a rank the 2-D special cases below cannot be recognised, and
        # the arithmetic on ``dim`` would raise TypeError. Use the generic
        # path, which is already the fallback for every other case.
        # NOTE(review): ``dim()`` is expected to return None for unknown rank
        # -- confirm.
        return opset9.flatten(g, input, start_dim, end_dim)

    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i

    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        axis = start_dim_i
    elif start_dim_i == 0 and end_dim_i == dim - 2:
        axis = end_dim_i + 1
    else:
        return opset9.flatten(g, input, start_dim, end_dim)

    old_type = None
    # Compare against None explicitly: a truthiness test also rejects valid
    # scalar types whose enum value is 0 (``JitScalarType`` is an IntEnum with
    # UINT8 == 0), which silently skipped the float cast for uint8 inputs.
    if symbolic_helper._try_get_scalar_type(input) is not None:
        old_type, input = _try_cast_integer_to_float(g, input)
    # ``_cast_to_type`` returns its input unchanged when ``old_type`` is None.
    return _cast_to_type(g, g.op("Flatten", input, axis_i=axis), old_type)
def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
if dtype is None:
scalar_type = _type_utils.JitScalarType.FLOAT
else:
scalar_type = _type_utils.JitScalarType(dtype)
if not scalar_type.dtype().is_floating_point:
result = g.op(
"ConstantFill",
sizes,
dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
input_as_shape_i=1,
value_f=const_value,
)
return g.op("Cast", result, to_i=scalar_type.onnx_type())
else:
return g.op(
Loading ...