Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ onnx / _internal / fx / function_dispatcher.py

"""Dispatcher for AtenLib functions from onnx-script."""

from __future__ import annotations

from typing import Callable, Dict, Union

import onnxscript  # type: ignore[import]
from onnxscript import opset18  # type: ignore[import]
from onnxscript.function_libs.torch_aten import ops  # type: ignore[import]

import torch
from torch.onnx._internal import _beartype


TORCH_ONNX_OPSET = onnxscript.values.Opset(domain="torch.onnx", version=1)


@onnxscript.script(opset=TORCH_ONNX_OPSET)
def prims_convert_element_type(tensor, dtype: int):
    return opset18.Cast(tensor, to=dtype)


@onnxscript.script(opset=TORCH_ONNX_OPSET)
def aten_getitem(self, i):
    # TODO(justinchuby): Support
    # i = opset18.Unsqueeze(i, opset18.Constant(value_ints=[0]))
    # return opset18.Gather(self, i, axis=0)
    return opset18.SequenceAt(self, i)


# A simple lookup table for atenlib functions
_ATENLIB_FUNCTIONS = {
    "aten::abs": ops.core.aten_abs,
    "aten::acos": ops.core.aten_acos,
    "aten::acosh": ops.core.aten_acosh,
    "aten::adaptive_avg_pool1d": ops.nn.aten_adaptive_avg_pool1d,
    "aten::adaptive_avg_pool2d": ops.nn.aten_adaptive_avg_pool2d,
    "aten::adaptive_avg_pool3d": ops.nn.aten_adaptive_avg_pool3d,
    "aten::add": ops.core.aten_add,
    "aten::addmm": ops.core.aten_addmm,
    "aten::amax": ops.core.aten_amax,
    "aten::amin": ops.core.aten_amin,
    "aten::arange": ops.core.aten_arange_start,
    "aten::argmax": ops.core.aten_argmax,
    "aten::argmin": ops.core.aten_argmin,
    "aten::asin": ops.core.aten_asin,
    "aten::asinh": ops.core.aten_asinh,
    "aten::atan": ops.core.aten_atan,
    "aten::atanh": ops.core.aten_atanh,
    "aten::bmm": ops.core.aten_bmm,
    "aten::ceil": ops.core.aten_ceil,
    "aten::celu": ops.nn.aten_celu,
    "aten::clamp_max": ops.core.aten_clamp_max,
    "aten::clamp_min": ops.core.aten_clamp_min,
    "aten::clamp": ops.core.aten_clamp,
    "aten::clone": ops.core.aten_clone,
    "aten::convolution": ops.core.aten_convolution,
    "aten::cos": ops.core.aten_cos,
    "aten::cosh": ops.core.aten_cosh,
    "aten::detach": ops.core.aten_detach,
    "aten::div": ops.core.aten_div,
    "aten::dot": ops.core.aten_dot,
    "aten::elu": ops.nn.aten_elu,
    "aten::embedding": ops.core.aten_embedding,
    "aten::empty_like": ops.core.aten_empty_like,
    "aten::empty": ops.core.aten_empty,
    "aten::eq": ops.core.aten_eq,
    "aten::equal": ops.core.aten_equal,
    "aten::erf": ops.core.aten_erf,
    "aten::exp": ops.core.aten_exp,
    "aten::exp2": ops.core.aten_exp2,
    "aten::expand": ops.core.aten_expand,
    "aten::fmod": ops.core.aten_fmod,
    "aten::full_like": ops.core.aten_full_like,
    "aten::full": ops.core.aten_full,
    "aten::ge": ops.core.aten_ge,
    "aten::gelu": ops.nn.aten_gelu,
    "aten::gt": ops.core.aten_gt,
    "aten::isinf": ops.core.aten_isinf,
    "aten::le": ops.core.aten_le,
    "aten::leaky_relu": ops.nn.aten_leaky_relu,
    "aten::linear": ops.nn.aten_linear,
    "aten::log_softmax": ops.special.aten_special_log_softmax,
    "aten::log": ops.core.aten_log,
    "aten::log10": ops.core.aten_log10,
    "aten::log1p": ops.core.aten_log1p,
    "aten::log2": ops.core.aten_log2,
    "aten::logaddexp": ops.core.aten_logaddexp,
    "aten::logaddexp2": ops.core.aten_logaddexp2,
    "aten::logcumsumexp": ops.core.aten_logcumsumexp,
    "aten::logdet": ops.core.aten_logdet,
    "aten::logsigmoid": ops.nn.aten_log_sigmoid,
    "aten::logsumexp": ops.core.aten_logsumexp,
    "aten::lt": ops.core.aten_lt,
    "aten::matmul": ops.core.aten_matmul,
    "aten::maximum": ops.core.aten_maximum,
    "aten::minimum": ops.core.aten_minimum,
    "aten::mm": ops.core.aten_mm,
    "aten::mul": ops.core.aten_mul,
    "aten::native_layer_norm": ops.core.aten_native_layer_norm,
    "aten::ne": ops.core.aten_ne,
    "aten::neg": ops.core.aten_neg,
    "aten::new_full": ops.core.aten_new_full,
    "aten::nonzero": ops.core.aten_nonzero,
    "aten::ones_like": ops.core.aten_ones_like,
    "aten::ones": ops.core.aten_ones,
    "aten::permute": ops.core.aten_permute,
    "aten::pow": ops.core.aten_pow,
    "aten::reciprocal": ops.core.aten_reciprocal,
    "aten::relu": ops.nn.aten_relu,
    "aten::relu6": ops.nn.aten_relu6,
    "aten::remainder": ops.core.aten_remainder,
    "aten::repeat": ops.core.aten_repeat,
    "aten::reshape": ops.core.aten_reshape,
    "aten::round": ops.core.aten_round,
    "aten::rsqrt": ops.core.aten_rsqrt,
    "aten::rsub": ops.core.aten_rsub,
    "aten::selu": ops.core.aten_selu,
    "aten::sigmoid": ops.core.aten_sigmoid,
    "aten::sign": ops.core.aten_sign,
    "aten::sin": ops.core.aten_sin,
    "aten::sinh": ops.core.aten_sinh,
    "aten::slice": ops.core.aten_slice,
    "aten::softmax": ops.special.aten_special_softmax,
    "aten::split": ops.core.aten_split,
    "aten::sqrt": ops.core.aten_sqrt,
    "aten::sub": ops.core.aten_sub,
    "aten::sum": ops.core.aten_sum_dim_IntList,
    "aten::t": ops.core.aten_t,
    "aten::tan": ops.core.aten_tan,
    "aten::tanh": ops.core.aten_tanh,
    "aten::topk": ops.core.aten_topk,
    "aten::transpose": ops.core.aten_transpose,
    "aten::unsqueeze": ops.core.aten_unsqueeze,
    "aten::upsample_nearest2d": ops.nn.aten_upsample_nearest2d,
    "aten::view": ops.core.aten_view,
    "aten::where": ops.core.aten_where,
    "aten::xlogy": ops.special.aten_special_xlogy,
    "aten::zeros_like": ops.core.aten_zeros_like,
    "aten::zeros": ops.core.aten_zeros,
    "getitem": aten_getitem,
    "prims::convert_element_type": prims_convert_element_type,
}


def _create_op_overload_to_exporter_key_table() -> Dict[
    Union[torch._ops.OpOverload, Callable], str
]:
    # TODO(justinchuby): Improve how the table is constructed.
    table: Dict[Union[torch._ops.OpOverload, Callable], str] = {}

    for op_namespace in (torch.ops.aten, torch.ops.prims):
        for attr_name in dir(op_namespace):
            op_overload_packet = getattr(op_namespace, attr_name)

            if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket):
                continue

            exporter_look_up_key = op_overload_packet._qualified_op_name
            if _ATENLIB_FUNCTIONS.get(exporter_look_up_key) is None:
                # This aten op doesn't have ONNX exporter.
                continue

            for overload_name in op_overload_packet.overloads():
                op_overload = getattr(op_overload_packet, overload_name)
                # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc
                # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add".
                # This is applied to all ops under torch.ops.aten.
                #
                # TODO(wechi): in the future, we might want to write individual exporter for each overload, if,
                # for example, they have different type promotion rules. If so, just map different overloads to
                # different exporter keys.

                table[op_overload] = op_overload_packet._qualified_op_name
    # TODO(justinchuby): is baddbmm different?
    table[torch.ops.aten.baddbmm.default] = "aten::baddbmm"
    return table


# Dictionary that maps torch.ops.aten.* to exporter look up key; e.g.,
# _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add".
_OP_OVERLOAD_TO_EXPORTER_KEY_TABLE = _create_op_overload_to_exporter_key_table()


@_beartype.beartype
def _create_onnx_friendly_decomposition_table() -> Dict[
    torch._ops.OpOverload, Callable
]:
    decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
    for op_overload, decomp_fn in torch._decomp.decomposition_table.items():
        # Skip decomposition into "prim::*" ops, because they are not generally supported by ONNX.
        # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX exporter.
        if (
            "torch._refs" in decomp_fn.__module__
            or op_overload in _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE
        ):
            continue
        decomposition_table[op_overload] = decomp_fn
    return decomposition_table


# This is a subset of PyTorch's built-in aten-to-aten decomposition. If an aten
# op (e.g., torch.ops.aten.add.Tensor) has exporter, we exclude the op's decomposition
# function in the _ONNX_FRIENDLY_DECOMPOSITION_TABLE.
_ONNX_FRIENDLY_DECOMPOSITION_TABLE = _create_onnx_friendly_decomposition_table()