import torch
import warnings
import inspect
from sys import maxsize
from typing import Set
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from functools import wraps
from torch._C import OptionalType
# Note [Edit Symbolic Files]
# EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
#
# - These files are ONLY for ATen operators (e.g., operators that show up in the
# trace as aten::blah). If you need to special case a primitive operator,
# look at _run_symbolic_function
# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
# tensors are always first, then non-tensor arguments.
# - Parameter names must *exactly* match the names in VariableType.cpp, because
# dispatch is done with keyword arguments.
# - Looking for inplace ops? They're detected by the trailing underscore, and
# transparently dispatched to their non-inplace versions in
# _run_symbolic_function. See Note [Export inplace]
#
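#
# As a sketch (the operator name `foo` is illustrative, not one defined in this
# file), a typical symbolic entry point looks like:
#
#     @parse_args('v', 'i')
#     def foo(g, self, dim):
#         return g.op("Foo", self, axis_i=dim)
#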
# ----------------------------------------------------------------------------------
# A note on Tensor types
# ----------------------------------------------------------------------------------
#
# In general, we should avoid depending on the type of Tensor Values contained
# within the trace graph. However, this is sometimes unavoidable (due to ONNX
# spec requirements, etc). The TensorType object has accessors for these properties
# that return the property if it is statically known and return nullopt otherwise.
#
# In general, we should prefer to rely on the least specific information possible.
# For example, not relying on tensor properties at all is better than relying
# on the number of dimensions which is better than relying on
# concrete shapes. Doing so will make the export symbolics
# more robust to different graphs.
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
# Save some builtins as locals, because we'll shadow them below
_sum = sum
def _parse_arg(value, desc, arg_name=None, node_name=None):
if desc == 'none':
return value
if desc == 'v' or not _is_value(value):
return value
if value.node().mustBeNone():
return None
if value.node().kind() == 'onnx::Constant':
tval = value.node()['value']
if desc == 'i':
return int(tval)
elif desc == 'f':
return float(tval)
elif desc == 'b':
return bool(tval)
elif desc == 's':
return str(tval)
elif desc == 't':
return tval
elif desc == 'is':
return [int(v) for v in tval]
elif desc == 'fs':
return [float(v) for v in tval]
else:
raise RuntimeError("ONNX symbolic doesn't know to interpret Constant node")
elif value.node().kind() == 'prim::ListConstruct':
if desc == 'is':
for v in value.node().inputs():
if v.node().kind() != 'onnx::Constant':
raise RuntimeError("Failed to export an ONNX attribute '" + v.node().kind() +
"', since it's not constant, please try to make "
"things (e.g., kernel size) static if possible")
return [int(v.node()['value']) for v in value.node().inputs()]
else:
raise RuntimeError("ONNX symbolic doesn't know to interpret ListConstruct node")
if arg_name is None or node_name is None:
raise RuntimeError("Expected node type 'onnx::Constant', got '{}'.".format(value.node().kind()))
else:
raise RuntimeError("Expected node type 'onnx::Constant' "
"for argument '{}' of node '{}', got '{}'.".format(arg_name, node_name, value.node().kind()))
def _maybe_get_const(value, desc):
if _is_value(value) and value.node().kind() == 'onnx::Constant':
return _parse_arg(value, desc)
return value
def _maybe_get_scalar(value):
value_t = _maybe_get_const(value, 't')
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
return value_t
return value
def _get_const(value, desc, arg_name):
if _is_value(value) and value.node().kind() not in ('onnx::Constant', 'prim::Constant'):
raise RuntimeError("ONNX symbolic expected a constant value of the {} argument, got `{}`".format(arg_name, value))
return _parse_arg(value, desc)
def _unpack_list(list_value):
list_node = list_value.node()
assert list_node.kind() == "prim::ListConstruct"
return list(list_node.inputs())
# Check if list_value is output from prim::ListConstruct
# This is usually called before _unpack_list to ensure the list can be unpacked.
def _is_packed_list(list_value):
return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
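# Example (a sketch): guard unpacking behind the check above, e.g.
#
#     if _is_packed_list(sizes):
#         size_values = _unpack_list(sizes)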
def parse_args(*arg_descriptors):
def decorator(fn):
fn._arg_descriptors = arg_descriptors
def wrapper(g, *args, **kwargs):
            # Some args may be optional, so there may be fewer args than descriptors.
assert len(arg_descriptors) >= len(args)
try:
sig = inspect.signature(fn)
arg_names = list(sig.parameters.keys())[1:]
fn_name = fn.__name__
except Exception:
arg_names = [None] * len(args) # type: ignore
fn_name = None # type: ignore
args = [_parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore
for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)] # type: ignore
# only support _outputs in kwargs
assert len(kwargs) <= 1
if len(kwargs) == 1:
assert '_outputs' in kwargs
return fn(g, *args, **kwargs)
# In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
try:
wrapper = wraps(fn)(wrapper)
except Exception:
pass
return wrapper
return decorator
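# Example (a sketch; `my_reduce` is an illustrative name): a symbolic whose first
# argument stays a graph value and whose attributes are extracted as Python ints:
#
#     @parse_args('v', 'i', 'i')
#     def my_reduce(g, input, dim, keepdim):
#         return g.op("ReduceMax", input, axes_i=[dim], keepdims_i=keepdim)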
def _scalar(x):
"""Convert a scalar tensor into a Python value."""
assert x.numel() == 1
return x.item()
def _if_scalar_type_as(g, self, tensor):
"""
Convert self into the same type of tensor, as necessary.
We only support implicit casting for scalars, so we never
actually need to insert an ONNX cast operator here; just
fix up the scalar.
"""
if isinstance(self, torch._C.Value):
return self
scalar_type = tensor.type().scalarType()
if scalar_type:
ty = scalar_type.lower()
return getattr(self, ty)()
return self
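# Example (a sketch): if `tensor` is statically typed Float and `self` is the
# wrapped scalar torch.tensor(2), this returns torch.tensor(2).float(); if
# `self` is already a torch._C.Value, it is returned untouched.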
def _is_none(x):
return x.node().mustBeNone()
def _is_value(x):
return isinstance(x, torch._C.Value)
def _is_tensor(x):
return x.type().isSubtypeOf(torch._C.TensorType.get())
def _is_tensor_list(x):
return isinstance(x.type(), torch._C.ListType) and isinstance(x.type().getElementType(), torch._C.TensorType)
def _get_tensor_rank(x):
if not _is_tensor(x) or x.type() is None:
return None
return x.type().dim()
def _get_tensor_sizes(x, allow_nonstatic=True):
if not _is_tensor(x) or x.type() is None:
return None
if allow_nonstatic:
# Each individual symbol is returned as None.
# e.g. [1, 'a', 'b'] -> [1, None, None]
return x.type().varyingSizes()
    # Returns None if any symbol exists in sizes.
# e.g. [1, 'a', 'b'] -> None
return x.type().sizes()
def _get_tensor_dim_size(x, dim):
try:
sizes = _get_tensor_sizes(x)
return sizes[dim]
except Exception:
pass
return None
def _unimplemented(op, msg):
warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
def _onnx_unsupported(op_name):
raise RuntimeError('Unsupported: ONNX export of operator {}. '
'Please feel free to request support or submit a pull request on PyTorch GitHub.'.format(op_name))
def _onnx_opset_unsupported(op_name, current_opset, supported_opset):
raise RuntimeError('Unsupported: ONNX export of {} in '
'opset {}. Please try opset version {}.'.format(op_name, current_opset, supported_opset))
def _onnx_opset_unsupported_detailed(op_name, current_opset, supported_opset, reason):
raise RuntimeError('Unsupported: ONNX export of {} in '
'opset {}. {}. Please try opset version {}.'.format(op_name, current_opset, reason, supported_opset))
def _block_list_in_opset(name):
def symbolic_fn(*args, **kwargs):
raise RuntimeError("ONNX export failed on {}, which is not implemented for opset {}. "
"Try exporting with other opset versions."
.format(name, _export_onnx_opset_version))
return symbolic_fn
def _try_get_scalar_type(*args):
for arg in args:
try:
return arg.type().scalarType()
except RuntimeError:
pass
return None
def _select_helper(g, self, dim, index, apply_reshape=True):
index_const = _maybe_get_scalar(index)
index_dim = _get_tensor_rank(index)
if not _is_value(index_const):
# Index is a constant scalar. Make it a size 1 constant tensor.
index = g.op("Constant", value_t=torch.LongTensor([index_const]))
elif index_dim is not None and apply_reshape:
if index_dim == 0:
# Index is a scalar. Reshape it to a size 1 tensor.
index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1])))
index_scalar_type = index.type().scalarType()
if index_scalar_type is None or index_scalar_type not in ['Long', 'Int']:
index = g.op("Cast", index, to_i=cast_pytorch_to_onnx["Long"])
return g.op("Gather", self, index, axis_i=dim)
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import _slice as _slice9
return _slice9(g, input, axes, starts, ends)
else:
from torch.onnx.symbolic_opset10 import _slice as _slice10
return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)
def _hardtanh_helper(g, input, min_val, max_val):
if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import hardtanh
return hardtanh(g, input, min_val, max_val)
else:
from torch.onnx.symbolic_opset11 import hardtanh # type: ignore[no-redef]
return hardtanh(g, input, min_val, max_val)
def _is_fp(value):
    if value:
        if isinstance(value, torch.Tensor):
            # Compare against torch dtypes directly; comparing a dtype to a
            # string such as 'torch.float32' is always False.
            return value.dtype in (torch.float16, torch.float32, torch.float64)
        else:
            scalar_type = value.type().scalarType()
            if scalar_type is None:
                warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.")
            return scalar_type in ('Float', 'Double', 'Half')
    return False
def _generate_wrapped_number(g, scalar):
"""
Create a wrapped number based on https://github.com/pytorch/pytorch/issues/9515
    A Tensor is considered a "wrapped number" if it is
auto-wrapped from a C++ or Python number type. Integer types are
wrapped as 0-dim int64 tensors and floating-point types are
wrapped as 0-dim double tensors.
    The input to this function is a constant value. If the data type
    is a floating point type, it is converted to a 0-dim double
    tensor, else it is converted to a 0-dim tensor of its original type.
"""
assert not isinstance(scalar, torch.Tensor)
if isinstance(scalar, float):
return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double))
return g.op("Constant", value_t=torch.tensor(scalar))
def _sort_helper(g, input, dim, descending=True, out=None):
    if out is not None:
        _unimplemented("Sort", "Out parameter")
shape_ = g.op("Shape", input)
dim_size_ = g.op("Gather", shape_, g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)))
if _export_onnx_opset_version <= 10:
        if not descending:
            _unimplemented("Sort", "Ascending sort")
return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
else:
return g.op("TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2)
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
if out is not None:
_unimplemented("TopK", "Out parameter is not supported")
if not _is_value(k):
k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64))
else:
k = g.op("Reshape", k, g.op("Constant", value_t=torch.tensor([1])))
if _export_onnx_opset_version <= 10:
if not largest:
_unimplemented("TopK", "Ascending is not supported")
return g.op("TopK", input, k, axis_i=dim, outputs=2)
else:
return g.op("TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2)
def _interpolate_warning(interpolate_mode):
onnx_op = "onnx:Resize" if _export_onnx_opset_version >= 10 else "onnx:Upsample"
    warnings.warn("You are trying to export the model with " + onnx_op + " for ONNX opset version " +
                  str(_export_onnx_opset_version) + ". "
                  "This operator might cause results to not match the expected results by PyTorch.\n"
                  "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. "
                  "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 "
                  "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n"
                  "We recommend using opset 11 and above for models using this operator.")