import torch
from torch._C import ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from functools import partial
from functools import wraps
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
from typing import Optional
import numpy
import math
import warnings
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 9.
# Opset 9 is supported by ONNX release 1.4.1, released on 01/23/19.
# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)? There are
# some moving parts to implementing the ONNX translation in this case:
#
# - By the time we get the scalar in a symbolic function here, it is no longer
#   a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
#   want it to be a zero-dim tensor, but this change has not happened yet).
#   However, the type of this scalar is *exactly* what the user wrote in
#   Python, which may not match the tensor it is being added to. PyTorch
#   will do implicit conversions on scalars; however, ONNX will not, so
#   we must do the conversion ourselves. This is what _if_scalar_type_as
#   does.
#
# - Dispatch to these functions takes advantage of an outrageous coincidence
#   between the tensor and scalar name. When we add two tensors together,
#   you get the dispatch:
#
# add(*[self, other], **{"alpha": alpha})
#
# When you add a tensor and a scalar, you get the dispatch:
#
# add(*[self], **{"other": other, "alpha": alpha})
#
# By having the argument name line up with the name of the scalar attribute
# if it exists, we can write a single function for both overloads.
#
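# For illustration only, a minimal sketch of the scalar-handling pattern
# (assuming `other` arrives as a numel-1 tensor typed as the user's Python
# literal; the real symbolics below differ in detail):
#
#     def add_scalar_sketch(g, self, other):
#         # cast the scalar to self's dtype before emitting the ONNX op
#         other = sym_help._if_scalar_type_as(g, other, self)
#         return g.op("Add", self, other)
#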
# used to represent "missing" optional inputs
def unused(g):
n = g.op("prim::Constant")
n.setType(OptionalType.ofTensor())
return n
def _shape_as_tensor(g, input):
return g.op('Shape', input)
def _reshape_from_tensor(g, input, shape):
return g.op('Reshape', input, shape)
def reshape(g, self, shape):
return view(g, self, shape)
def reshape_as(g, self, other):
shape = g.op('Shape', other)
return reshape(g, self, shape)
def add(g, self, other, alpha=None):
if sym_help._is_value(self) and sym_help._is_tensor_list(self):
return sym_help._onnx_opset_unsupported_detailed('Add', 9, 11, 'Add between list of tensors not supported')
    # The default alpha arg allows add() calls with no alpha
    # (some aten::add overloads take no alpha).
if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1:
return _unimplemented("add", "alpha != 1")
return g.op("Add", self, other)
def sub(g, self, other, alpha=None):
    # The default alpha arg allows sub() calls with no alpha
    # (some aten::sub overloads take no alpha).
if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1:
return _unimplemented("sub", "alpha != 1")
return g.op("Sub", self, other)
def rsub(g, self, other, alpha=None):
return sub(g, other, self, alpha=alpha)
def mul(g, self, other):
return g.op("Mul", self, other)
def div(g, self, other):
return true_divide(g, self, other)
def floor_divide(g, self, other):
out = g.op('Div', self, other)
    # The correct operation is truncation, which is not supported in ONNX;
    # we cannot use Floor, since it behaves differently for negative numbers
    # (e.g. -0.1 should become -0). The Cast to Long below discards the
    # fractional part instead.
    # If scalar type information is not available, assume the inputs should
    # be treated as float.
    out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    # Matching PyTorch's behavior:
    # - if self is fp, the output's type is self's type
    # - if self is not fp and other is fp, the output's type is 'Float'
    # - if neither self nor other is fp, the output's type is self's type
    # - if no type information is available, the output type defaults to 'Float'
scalar_type = self.type().scalarType()
if scalar_type is not None:
if not sym_help._is_fp(self) and \
other.type().scalarType() is not None and \
sym_help._is_fp(other):
out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx['Float'])
else:
out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])
else:
out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx['Float'])
return out
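# Illustrative sketch of the nodes floor_divide above emits for two Long
# inputs (schematic spelling, not actual exporter output):
#     out = Div(self, other)
#     out = Cast(out, to=INT64)   # discard the fractional part
#     out = Cast(out, to=INT64)   # match self's scalar type ('Long')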
def floordiv(g, self, other):
return floor_divide(g, self, other)
# Division that promotes both inputs to a floating type:
# - if both inputs are floating, performs Div as usual
# - if only one input is a floating type, the other input is cast to its type
# - if neither input is a floating type, both inputs are cast to the default scalar type
def true_divide(g, self, other):
# Case 1: both values are floating
# Performs div as usual
if sym_help._is_fp(self) and sym_help._is_fp(other):
return g.op("Div", self, other)
# Case 2: self is floating, other is not
# Casts other to self's dtype
if sym_help._is_fp(self):
other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()])
return g.op("Div", self, other)
# Case 3: other is floating, self is not
# Casts self to other's dtype
if sym_help._is_fp(other):
self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other.type().scalarType()])
return g.op("Div", self, other)
# Case 4: neither is floating
# Casts both inputs to the default scalar type
    scalar_type = torch.get_default_dtype()
    assert scalar_type is torch.float or scalar_type is torch.double
    onnx_scalar_type = sym_help.cast_pytorch_to_onnx['Float']
    if scalar_type is torch.double:
        onnx_scalar_type = sym_help.cast_pytorch_to_onnx['Double']
self = g.op("Cast", self, to_i=onnx_scalar_type)
other = g.op("Cast", other, to_i=onnx_scalar_type)
return g.op("Div", self, other)
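# Illustrative sketch of the nodes true_divide above emits when neither input
# is floating point and the default dtype is torch.float (schematic spelling):
#     self  = Cast(self,  to=FLOAT)
#     other = Cast(other, to=FLOAT)
#     out   = Div(self, other)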
def reciprocal(g, self):
return g.op("Div", torch.ones(1), self)
@parse_args('v', 'i')
def cat(g, tensor_list, dim):
tensors = sym_help._unpack_list(tensor_list)
return g.op("Concat", *tensors, axis_i=dim)
@parse_args('v', 'i')
def stack(g, tensor_list, dim):
unsqueezed = [sym_help._unsqueeze_helper(g, t, [dim]) for t in sym_help._unpack_list(tensor_list)]
return g.op("Concat", *unsqueezed, axis_i=dim)
def _list(g, self):
return self
def mm(g, self, other):
    # Create a dummy C tensor. Only needed for API purposes; the value is
    # not used, since beta = 0.
C = g.op("Constant", value_t=torch.tensor([1]))
return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
def bmm(g, self, other):
return g.op("MatMul", self, other)
def matmul(g, self, other):
return g.op("MatMul", self, other)
@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
dtype = None
self_dtype = sym_help._try_get_scalar_type(self)
mat1_dtype = sym_help._try_get_scalar_type(mat1)
mat2_dtype = sym_help._try_get_scalar_type(mat2)
if self_dtype is not None:
dtype = self_dtype
elif mat1_dtype is not None:
dtype = mat1_dtype
elif mat2_dtype is not None:
dtype = mat2_dtype
mat1_rank = sym_help._get_tensor_rank(mat1)
mat2_rank = sym_help._get_tensor_rank(mat2)
def isNotNoneAnd(v, u):
return v is not None and v != u
if dtype is not None and (isNotNoneAnd(mat1_rank, 2) or isNotNoneAnd(mat2_rank, 2)):
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
dtype = sym_help.scalar_type_to_pytorch_type[dtype]
res1 = g.op("MatMul", mat1, mat2)
res2 = self
alpha = sym_help._scalar(alpha)
beta = sym_help._scalar(beta)
if alpha != 1:
alpha = g.op("Constant",
value_t=torch.tensor(alpha, dtype=dtype))
res1 = g.op("Mul", res1, alpha)
if beta != 1:
            beta = g.op("Constant",
                        value_t=torch.tensor(beta, dtype=dtype))
res2 = g.op("Mul", res2, beta)
return g.op("Add", res1, res2)
return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
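# Note on addmm above: ONNX Gemm only accepts rank-2 inputs, so when mat1 or
# mat2 has a known rank other than 2 the export decomposes into
# MatMul -> optional Mul by alpha/beta -> Add, computing
# alpha * (mat1 @ mat2) + beta * self; otherwise a single Gemm node carries
# alpha and beta as attributes.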
def neg(g, self):
return g.op("Neg", self)
def sqrt(g, self):
return g.op("Sqrt", self)
def rsqrt(g, self):
return g.op("Div", sym_help._if_scalar_type_as(g, torch.ones(1), self), sqrt(g, self))
def tanh(g, self):
return g.op("Tanh", self)
def sin(g, self):
return g.op("Sin", self)
def cos(g, self):
return g.op("Cos", self)
def tan(g, self):
return g.op("Tan", self)
def asin(g, self):
return g.op("Asin", self)
def acos(g, self):
return g.op("Acos", self)
def atan(g, self):
return g.op("Atan", self)
def sigmoid(g, self):
return g.op("Sigmoid", self)
def sign(g, self):
return g.op("Sign", self)
def _slice(g, input, axes, starts, ends):
assert len(starts) == len(ends)
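    # 9223372036854775807 is 2**63 - 1 (INT64_MAX), the sentinel PyTorch uses
    # for "slice to the end"; a full-dimension slice is a no-op, so skip it.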
if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
return input
return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
def _maybe_cast_reduce_op_input(g, self):
dtype = self.type().scalarType()
# This check only covers traced modules where dtype is present
if dtype is not None:
# pytorch reduce-ops cast all other integral types to int64
if not sym_help._is_fp(self) and not (dtype == 'Long'):
self = _cast_Long(g, self, False) # type: ignore
return self
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
def symbolic(g, self, dim=None, keepdim=None):
self = _maybe_cast_reduce_op_input(g, self)
if dim is None:
# all-reduce path
return g.op(onnx_op_name, self, keepdims_i=0)
else:
# dim-reduce path
desc = 'is' if allow_multi_dim_support else 'i'
dim, keepdim = sym_help._get_const(dim, desc, 'dim'), sym_help._get_const(keepdim, 'i', 'keepdim')
dim_list = dim if allow_multi_dim_support else [dim]
return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
return symbolic
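# A hypothetical usage sketch for _reduce_op_symbolic (names are illustrative,
# not actual registrations):
#     sum_symbolic = _reduce_op_symbolic('ReduceSum')
#     # sum_symbolic(g, self)               -> ReduceSum over all axes
#     # sum_symbolic(g, self, dim, keepdim) -> ReduceSum over the given dims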
def overload_by_arg_count(fn):
@wraps(fn)
def wrapper(g, *args):
overloads = fn(g, *args)
for overload in overloads:
arg_descriptors = overload._arg_descriptors
if len(arg_descriptors) == len(args):
return overload(g, *args)
raise NotImplementedError("Unknown aten::{} signature".format(fn.__name__))
    return wrapper
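# A hypothetical usage sketch for overload_by_arg_count (names are
# illustrative): the decorated function returns a tuple of
# @parse_args-annotated overloads, and dispatch picks the overload whose
# descriptor count matches the number of call arguments:
#
#     @overload_by_arg_count
#     def _max_sketch(g, *args):
#         @parse_args('v')
#         def max_all(g, self):
#             return g.op("ReduceMax", self, keepdims_i=0)
#
#         @parse_args('v', 'i', 'i')
#         def max_dim(g, self, dim, keepdim):
#             return g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
#
#         return max_all, max_dim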