edgify / torch (Python package, version 2.0.1+cpu)
File: onnx/symbolic_opset12.py

import functools
import sys
from typing import Optional, Tuple

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12

__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
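

# Registration note (illustrative, not part of the original file): each symbolic
# function below is decorated with `@_onnx_symbolic("aten::<op_name>")`, which is
# shorthand for `registration.onnx_symbolic("aten::<op_name>", opset=12)`. The
# exporter dispatches an aten node to the function registered here when a model
# is exported with opset_version >= 12 and no later opset registers an override,
# e.g.:
#
#     torch.onnx.export(model, args, "model.onnx", opset_version=12)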


@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)


@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
@_beartype.beartype
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)
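

# Example sketch (illustrative, not part of the original file): a model that
# calls torch.einsum is lowered through this symbolic into a single ONNX Einsum
# node (plus Casts for bool inputs, per _einsum_helper above). The name
# `BatchedMatmul` is made up for the example:
#
#     class BatchedMatmul(torch.nn.Module):
#         def forward(self, a, b):
#             return torch.einsum("bij,bjk->bik", a, b)
#
#     torch.onnx.export(
#         BatchedMatmul(),
#         (torch.randn(2, 3, 4), torch.randn(2, 4, 5)),
#         "einsum.onnx",
#         opset_version=12,
#     )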


@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g: jit_utils.GraphContext, input, other):
    # Make sure to cast `other` to input's scalar type.
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(input):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
        )
    return _einsum_helper(g, "i,j->ij", [input, other])
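

# Sanity-check sketch (illustrative, not part of the original file): `outer` is
# lowered to Einsum with the equation "i,j->ij", which matches eager-mode
# torch.outer:
#
#     a, b = torch.arange(3.0), torch.arange(4.0)
#     assert torch.allclose(torch.outer(a, b), torch.einsum("i,j->ij", a, b))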


@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op: if the node's train param is set to
    # False, dropout simply returns its input (and no mask).
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)
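

# Note (illustrative, not part of the original file): ONNX Dropout at opset 12
# takes `ratio` and `training_mode` as inputs and returns (output, mask), which
# is why the helper above emits two Constant inputs and requests `outputs=2`.
# A minimal export sketch that exercises the training-mode path:
#
#     torch.onnx.export(
#         torch.nn.Dropout(p=0.3),
#         (torch.randn(4, 8),),
#         "dropout.onnx",
#         opset_version=12,
#         training=torch.onnx.TrainingMode.TRAINING,
#     )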


@_onnx_symbolic("aten::nll_loss")
@_beartype.beartype
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    # In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional and has
    # no default value. Therefore we always set the ignore_index attribute, even when it was
    # not explicitly specified (e.g. the PyTorch default ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return nllloss
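

# Example sketch (illustrative, not part of the original file): the integer
# `reduction` argument of aten::nll_loss (0/1/2) becomes the string attribute
# expected by ONNX NegativeLogLikelihoodLoss, and ignore_index is always
# emitted because the ONNX attribute has no default. The name `Loss` below is
# made up:
#
#     class Loss(torch.nn.Module):
#         def forward(self, log_probs, target):
#             return torch.nn.functional.nll_loss(log_probs, target)
#
#     torch.onnx.export(
#         Loss(),
#         (torch.log_softmax(torch.randn(3, 5), dim=1), torch.tensor([1, 0, 4])),
#         "nll_loss.onnx",
#         opset_version=12,
#     )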


@_onnx_symbolic("aten::nll_loss2d")
@_beartype.beartype
def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::nll_loss_nd")
@_beartype.beartype
def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::cross_entropy_loss")
@_beartype.beartype
def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing is not None and label_smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )

    # In the ONNX SoftmaxCrossEntropyLoss specification, ignore_index is optional and has
    # no default value. Therefore we always set the ignore_index attribute, even when it was
    # not explicitly specified (e.g. the PyTorch default ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return celoss
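

# Note (illustrative, not part of the original file): exporting
# torch.nn.CrossEntropyLoss goes through this symbolic and yields a single
# SoftmaxCrossEntropyLoss node; a non-zero label_smoothing raises
# SymbolicValueError because there is no corresponding ONNX attribute at this
# opset. A sketch with made-up names:
#
#     class Wrapper(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.loss = torch.nn.CrossEntropyLoss(reduction="mean")
#
#         def forward(self, logits, target):
#             return self.loss(logits, target)
#
#     torch.onnx.export(
#         Wrapper(),
#         (torch.randn(3, 5), torch.tensor([1, 0, 4])),
#         "cross_entropy.onnx",
#         opset_version=12,
#     )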


@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)
    if pos_weight is None or symbolic_helper._is_none(pos_weight):
        output = opset9.neg(
            g,
            opset9.add(
                g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
            ),
        )
    else:
        output = opset9.neg(
            g,
            opset9.add(
                g,
                opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
                opset9.mul(g, sub_1_y, log_1_x),
            ),
        )

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )
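

# Worked decomposition (illustrative, not part of the original file): ONNX has
# no direct counterpart for aten::binary_cross_entropy_with_logits, so the
# symbolic above assembles
#     -(pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))
# from opset 9 primitives, optionally scales it by `weight`, and applies the
# requested reduction. An eager-mode sanity check of the same formula:
#
#     x = torch.randn(4)
#     y = torch.randint(0, 2, (4,)).float()
#     manual = -(y * torch.log(torch.sigmoid(x))
#                + (1 - y) * torch.log(1 - torch.sigmoid(x))).mean()
#     ref = torch.nn.functional.binary_cross_entropy_with_logits(x, y)
#     assert torch.allclose(manual, ref)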


@_onnx_symbolic("aten::celu")
@_beartype.beartype
def celu(g: jit_utils.GraphContext, self, alpha):
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # If the input is of type double, cast it to float (and cast the result back to double).
    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.DOUBLE
    ):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)

    return g.op("Celu", self, alpha_f=alpha)


@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")


@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")


@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
    return g.op("Pow", self, exponent)


@_onnx_symbolic("aten::ge")
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
    return g.op("GreaterOrEqual", input, other)


@_onnx_symbolic("aten::le")
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
    return g.op("LessOrEqual", input, other)


@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        # ... (truncated)