Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ onnx / symbolic_opset8.py


import torch
import torch.onnx.symbolic_helper as sym_help
import torch.onnx.symbolic_opset9 as sym_opset9

from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
from torch.onnx.symbolic_opset9 import _cast_Float  # type: ignore

import warnings

# Note [ONNX operators that are added/updated from opset 8 to opset 9]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# New operators:
#   Compress
#   ConstantOfShape
#   EyeLike
#   MaxUnpool
#   OneHot
#   Sinh
#   Cosh
#   Asinh
#   Acosh
#   Atanh
#   Shrink
#   IsNaN
#   Sign
#   Erf
#   Scatter
#   Where
#   NonZero
#   TfIdfVectorizer
#   MeanVarianceNormalization
#
# Updated operators:
#   BatchNormalization: removed spatial attribute.
#   Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types (integers) supported.
#   Cast: more data types (string) supported.
#   Upsample: moved scales from attribute to input.
#   Scan

block_listed_operators = [
    "nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
    "arange", "masked_fill", "index_fill", "index_copy",
]

# Bind each name above as a module-level symbol whose value is whatever
# `_block_list_in_opset` builds for it (presumably a stub that rejects the op
# when exporting at this opset -- see symbolic_helper). At module scope
# `globals()` is the same mapping the original `vars()` call returned.
for block_listed_op in block_listed_operators:
    globals()[block_listed_op] = _block_list_in_opset(block_listed_op)


def _interpolate(name, dim, interpolate_mode):
    """Build an opset-8 symbolic for one of the ``upsample_*`` aten ops.

    Args:
        name: aten op name; only used in "unimplemented" messages.
        dim: rank of the input tensor handled by this symbolic.
        interpolate_mode: ONNX ``Upsample`` mode string (e.g. "nearest", "linear").

    Returns:
        A function ``(g, input, output_size, *args)`` emitting an ``Upsample``
        node with static ``scales``, or an ``_unimplemented`` marker when
        ``align_corners`` is set or ``output_size`` is not a constant.
    """
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
        sym_help._interpolate_warning(interpolate_mode)
        # The opset-8 Upsample node built below carries no align_corners attribute.
        if sym_help._maybe_get_scalar(align_corners):
            return _unimplemented(name, "align_corners == True")

        output_size = sym_help._maybe_get_const(output_size, 'is')
        # Scales are passed as an attribute, so the output size must be static.
        if sym_help._is_value(output_size):
            return _unimplemented(name, "torch._C.Value (output_size) indexing")

        if scales is None:
            # The first two axes keep scale 1.0 (presumably batch and channel);
            # every later axis gets output extent / input extent.
            scales = []
            for axis in range(dim):
                if axis < 2:
                    scales.append(1.)
                    continue
                offset = axis - dim  # same as -(dim - axis)
                in_extent = input.type().sizes()[offset]
                scales.append(float(output_size[offset]) / float(in_extent))
        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)
    return symbolic_fn


# Concrete upsample symbolics: _interpolate(aten op name, input rank, ONNX Upsample mode).
upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest")
upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest")
upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest")
upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear")
upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")


def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
    """Symbolic for ``aten::__interpolate`` targeting opset 8.

    The emitted ``Upsample`` node takes ``scales`` as a static attribute. So
    dynamic ``size`` / ``scale_factor`` Values are rejected, as is
    ``align_corners=True``. ``recompute_scale_factor`` is accepted for signature
    compatibility but unused.

    Returns:
        The ``Upsample`` node, or an ``_unimplemented`` marker.
    """
    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    # After `_maybe_get_const`, `align_corners` is either a plain Python bool
    # (a constant flag) or still a graph Value (a None constant or a runtime
    # flag). `_is_none` is only ever applied to graph Values in this file, so
    # route plain bools around it.
    # NOTE(review): `_is_none` appears to dereference `.node()`, which a bool
    # lacks -- confirm against symbolic_helper.
    if sym_help._is_value(align_corners):
        # None means "default" (False); any other Value cannot be resolved
        # at export time and is treated as requesting align_corners.
        wants_align_corners = not sym_help._is_none(align_corners)
    else:
        wants_align_corners = bool(align_corners)
    if wants_align_corners:
        return _unimplemented("interpolate", "align_corners == True")

    if not sym_help._is_none(scale_factor) and sym_help._is_value(scale_factor):
        return _unimplemented("interpolate", "dynamic scales in opset 8")

    if not sym_help._is_none(size) and sym_help._is_value(size):
        return _unimplemented("interpolate", "dynamic size in opset 8")

    scales, mode = sym_help._interpolate_get_scales_and_mode(g, input, size, scale_factor,
                                                             mode, align_corners)
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)


# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
#       issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
#       is lost after casting.
def _try_cast_integer_to_float(g, *args):
    """Cast every arg to Float when the first arg has a known, non-floating dtype.

    Returns ``(old_type, *args)``:
      * first arg's dtype known and not Half/Float/Double -> ``old_type`` is that
        dtype name and every arg has been wrapped in a Float cast, so the
        caller can cast its result back;
      * dtype already floating -> ``(None, *args)`` with args untouched;
      * dtype unknown -> a warning is emitted and ``(None, *args)`` is returned
        with args untouched.
    """
    scalar_type = args[0].type().scalarType()
    if scalar_type is None:
        warnings.warn("Only floating datatype is supported for these operators: "
                      "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
                      "the onnx model to be incorrect, if inputs have integer datatypes.")
        return (None,) + args
    if scalar_type in ('Half', 'Float', 'Double'):
        return (None,) + args
    casted = tuple(_cast_Float(g, arg, False) for arg in args)
    return (scalar_type,) + casted


def _cast_to_type(g, input, to_type):
    if to_type is None:
        return input
    return getattr(sym_opset9, '_cast_{}'.format(to_type))(g, input, False)


def _comparison_operator(g, input, other, op_name):
    """Emit the binary ONNX comparison ``op_name`` with opset-8 dtype workarounds.

    ``other`` is first run through ``_maybe_get_scalar`` / ``_if_scalar_type_as``
    (presumably turning a Python scalar into a tensor typed like ``input``).
    Integer operands are then cast to Float. The comparison result is not cast
    back (the returned ``old_type`` is discarded).
    """
    other = sym_help._if_scalar_type_as(g, sym_help._maybe_get_scalar(other), input)
    _old_type, lhs, rhs = _try_cast_integer_to_float(g, input, other)
    return g.op(op_name, lhs, rhs)


# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
#       integer input type not supported in opset8. Cast to float if possible.
def gt(g, input, other):
    """Symbolic for ``aten::gt``: ONNX ``Greater`` via the opset-8 comparison helper."""
    return _comparison_operator(g, input, other, op_name="Greater")


def lt(g, input, other):
    """Symbolic for ``aten::lt``: ONNX ``Less`` via the opset-8 comparison helper."""
    return _comparison_operator(g, input, other, op_name="Less")


def bmm(g, self, other):
    """Symbolic for ``aten::bmm``; emits ONNX ``MatMul``.

    If ``self`` has a known scalar type, integer operands are cast to Float
    before the MatMul and the product is cast back to the original type.
    Otherwise the MatMul is emitted on the operands as given.
    """
    if not _try_get_scalar_type(self):
        return g.op("MatMul", self, other)
    old_type, self, other = _try_cast_integer_to_float(g, self, other)
    product = g.op("MatMul", self, other)
    return _cast_to_type(g, product, old_type)


def matmul(g, self, other):
    """Symbolic for ``aten::matmul``; lowered exactly like ``bmm`` (ONNX ``MatMul``)."""
    lowered = bmm(g, self, other)
    return lowered


def prelu(g, self, weight):
    """Symbolic for ``aten::prelu``; emits ONNX ``PRelu``.

    For inputs of rank > 2, ``weight`` is unsqueezed on axes ``1 .. rank-2`` so
    it can broadcast against ``self``. This presumes a 1-D per-channel weight;
    confirm against callers. If ``self`` has a known scalar type, integer
    inputs are cast to Float and the result is cast back.
    """
    rank = sym_help._get_tensor_rank(self)
    if rank is not None and rank > 2:
        unsqueeze_axes = list(range(1, rank - 1))
        weight = g.op("Unsqueeze", weight, axes_i=unsqueeze_axes)
    if not _try_get_scalar_type(self):
        return g.op("PRelu", self, weight)
    old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
    return _cast_to_type(g, g.op("PRelu", self, weight), old_type)


def mm(g, self, other):
    """Symbolic for ``aten::mm``; emits ONNX ``Gemm`` with a dummy bias.

    The Gemm node built here takes a C input, so a one-element zero constant is
    passed with ``beta=0``. If ``self`` has a known scalar type, integer
    operands are cast to Float and the result is cast back.
    """
    # The dummy C tensor only exists to fill Gemm's C input; its value is
    # irrelevant because beta = 0. Match the operands' dtype when it is known.
    # `_try_get_scalar_type` can return None (the branch below relies on that),
    # so fall back to float instead of crashing on `None.lower()`.
    scalar_type = sym_help._try_get_scalar_type(self, other)
    c_type = scalar_type.lower() if scalar_type is not None else "float"
    C = g.constant(0, [1], c_type)
    if not _try_get_scalar_type(self):
        return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
    old_type, self, other, C = _try_cast_integer_to_float(g, self, other, C)
    return _cast_to_type(g, g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0), old_type)


@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
    """Symbolic for ``aten::addmm`` as ONNX ``Gemm(mat1, mat2, C=self)``.

    ``beta`` and ``alpha`` arrive as tensors (``parse_args`` 't') and are turned
    into Python scalars for the float attributes. If ``self`` has a known
    scalar type, integer inputs are cast to Float and the result is cast back.
    """
    if not _try_get_scalar_type(self):
        return g.op("Gemm", mat1, mat2, self,
                    beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
    old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
    gemm = g.op("Gemm", mat1, mat2, self,
                beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
    return _cast_to_type(g, gemm, old_type)


def flatten(g, input, start_dim, end_dim):
    """Symbolic for ``aten::flatten`` in opset 8.

    When the result is 2-D, ONNX ``Flatten`` is emitted directly via
    ``_flatten_2d``. That covers flattening ``[1, rank-1]`` (Flatten at axis 1)
    or ``[0, rank-2]`` (Flatten at axis rank-1). Every other case, including an
    input whose rank is unknown, is delegated to the opset-9 implementation.
    """
    start_dim_i = sym_help._get_const(start_dim, 'i', 'start_dim')
    end_dim_i = sym_help._get_const(end_dim, 'i', 'end_dim')

    dim = input.type().dim()
    if dim is not None:
        # Normalize negative dims (for both ends) so e.g. flatten(x, -3, -1)
        # on a 4-D input is recognised as the 2-D case too.
        if start_dim_i < 0:
            start_dim_i += dim
        if end_dim_i < 0:
            end_dim_i += dim
        # use ONNX's Flatten operator for cases where the output shape is 2D
        if start_dim_i == 1 and end_dim_i == dim - 1:
            return _flatten_2d(g, input, start_dim_i)
        if start_dim_i == 0 and end_dim_i == dim - 2:
            return _flatten_2d(g, input, end_dim_i + 1)

    return sym_opset9.flatten(g, input, start_dim, end_dim)


def _flatten_2d(g, input, axis):
    """Emit ONNX ``Flatten`` at ``axis``, round-tripping integer inputs through Float when the dtype is known."""
    if not _try_get_scalar_type(input):
        return g.op("Flatten", input, axis_i=axis)
    old_type, input = _try_cast_integer_to_float(g, input)
    return _cast_to_type(g, g.op("Flatten", input, axis_i=axis), old_type)


def _constant_fill(g, sizes, dtype, const_value):
    """Emit ``ConstantFill`` producing a tensor shaped by ``sizes`` and filled with ``const_value``.

    ``dtype`` is a PyTorch scalar-type index; ``None`` means 6 (float).
    ``const_value`` is passed via the float attribute ``value_f``. For
    non-floating dtypes, the fill is done as Float and then cast to the
    requested type.
    """
    scalar_type = 6 if dtype is None else dtype  # 6 == float
    if sym_help.scalar_type_to_pytorch_type[scalar_type].is_floating_point:
        return g.op("ConstantFill", sizes, dtype_i=sym_help.scalar_type_to_onnx[scalar_type],
                    input_as_shape_i=1, value_f=const_value)
    filled = g.op("ConstantFill", sizes, dtype_i=sym_help.cast_pytorch_to_onnx["Float"],
                  input_as_shape_i=1, value_f=const_value)
    return sym_help._cast_func_template(sym_help.scalar_type_to_onnx[scalar_type], g, filled, None)

@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def empty(g, sizes, dtype, layout, device, pin_memory=False, memory_format=None):
    """Symbolic for ``aten::empty``: lowered as ``zeros`` of shape ``sizes``.

    ``memory_format`` is ignored.
    """
    # zeros is wrapped by parse_args, which only accepts positional args here.
    zero_filled = zeros(g, sizes, dtype, layout, device, pin_memory)
    return zero_filled


@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
    """Symbolic for ``aten::empty_like``: lowered as ``zeros_like(input)``.

    ``memory_format`` is ignored.
    """
    # zeros_like is wrapped by parse_args, which only accepts positional args here.
    zero_filled = zeros_like(g, input, dtype, layout, device, pin_memory)
    return zero_filled

@parse_args('v', 'i', 'v', 'v', 'v')
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
    """Symbolic for ``aten::zeros``: a zero-filled ``ConstantFill`` shaped by ``sizes``.

    ``layout``, ``device`` and ``pin_memory`` cannot be expressed in ONNX and
    are ignored.
    """
    fill_value = 0
    return _constant_fill(g, sizes, dtype, fill_value)


@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
    """Symbolic for ``aten::zeros_like``: zero fill shaped by ``Shape(input)``.

    ``layout``, ``device``, ``pin_memory`` and ``memory_format`` are ignored.
    """
    return _constant_fill(g, g.op("Shape", input), dtype, 0)


@parse_args('v', 'i', 'v', 'v', 'v')
def ones(g, sizes, dtype, layout, device, pin_memory=False):
    """Symbolic for ``aten::ones``: a one-filled ``ConstantFill`` shaped by ``sizes``.

    ``layout``, ``device`` and ``pin_memory`` are ignored.
    """
    fill_value = 1
    return _constant_fill(g, sizes, dtype, fill_value)


@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
    """Symbolic for ``aten::ones_like``: one fill shaped by ``Shape(input)``.

    ``layout``, ``device``, ``pin_memory`` and ``memory_format`` are ignored.
    """
    return _constant_fill(g, g.op("Shape", input), dtype, 1)


def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    """Symbolic for ``aten::full``.

    A constant ``value`` becomes a single ``ConstantFill``. A dynamic ``value``
    is built as ``zeros(sizes) + value`` (``add`` with alpha = 1).
    """
    const_value = sym_help._maybe_get_const(value, 't')
    if not sym_help._is_value(const_value):
        dtype = sym_help._get_const(dtype, 'i', 'dtype')
        return _constant_fill(g, sizes, dtype, const_value)
    # Dynamic fill value: start from zeros and add the runtime value.
    base = zeros(g, sizes, dtype, layout, device)
    alpha_one = g.op("Constant", value_t=torch.tensor(1))
    return sym_opset9.add(g, base, value, alpha_one)


@parse_args('v', 'f', 'i', 'v', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None):
    """Symbolic for ``aten::full_like``: ``fill_value`` (parsed as float) shaped by ``Shape(input)``.

    ``layout``, ``device``, ``pin_memory`` and ``memory_format`` are ignored.
    """
    input_shape = g.op("Shape", input)
    return _constant_fill(g, input_shape, dtype, fill_value)


def repeat(g, self, repeats):
    """Symbolic for ``aten::repeat`` lowered to ONNX ``Tile``.

    A non-Value ``repeats`` is first materialised as a LongTensor ``Constant``.
    If ``self`` has a complete type and fewer dims than ``repeats`` has
    entries, ``self`` is first viewed with leading 1s so both ranks match.
    """
    if not sym_help._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))

    if sym_help._is_packed_list(repeats):
        num_repeats = len(sym_help._unpack_list(repeats))
    else:
        # NOTE(review): if `repeats` is a non-constant, non-packed Value,
        # `_maybe_get_const` presumably returns it unchanged and `len()` would
        # raise -- confirm whether that case can reach this symbolic.
        num_repeats = len(sym_help._maybe_get_const(repeats, 'is'))

    if self.isCompleteTensor():
        self_sizes = self.type().sizes()
        missing_dims = num_repeats - len(self_sizes)
        if missing_dims > 0:
            padded_shape = torch.tensor([1] * missing_dims + self_sizes)
            self = sym_opset9.view(g, self, g.op("Constant", value_t=padded_shape))
    return g.op("Tile", self, repeats)