Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ onnx / symbolic_registry.py

import warnings
import importlib
from inspect import getmembers, isfunction
from typing import Dict, Tuple, Any, Union

# The symbolic registry "_registry" is a two-level dictionary that maps operators
# (for a specific domain and opset version) to their symbolic functions.
# An operator is identified by its domain, opset version, and opname.
# The outer keys are (domain, version) tuples, where domain is a string and
# version is an int; the inner keys are operator names (strings).
# Entries therefore look like: _registry[(domain, version)][op_name] = op_symbolic
_registry: Dict[Tuple[str, int], Dict] = {}

# Maps each opset version to its imported torch.onnx.symbolic_opset{version}
# module. It is filled in eagerly, at import time, for every stable opset plus
# the main opset; get_ops_in_version() reads from it.
_symbolic_versions: Dict[Union[int, str], Any] = {}
from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset
for opset_version in _onnx_stable_opsets + [_onnx_main_opset]:
    # Import the per-opset symbolic module by name and remember it by version.
    module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
    _symbolic_versions[opset_version] = module


def register_version(domain, version):
    """Ensure an op table exists for ``(domain, version)`` and populate it.

    An empty table is created only when the pair has none yet; the ops for
    the version are then registered on every call.
    """
    # setdefault inserts only when the key is missing, so an existing table
    # (and everything already registered in it) is left untouched.
    _registry.setdefault((domain, version), {})
    register_ops_in_version(domain, version)


def register_ops_helper(domain, version, iter_version):
    """Register the functions of opset ``iter_version`` under ``(domain, version)``.

    Each function member of the symbolic module for ``iter_version`` is
    registered unless an op with the same name is already present for
    ``(domain, version)``; existing entries are never overwritten.
    """
    # Module members that are registered under a different op name.
    renamed = {'_len': 'len', '_list': 'list'}
    for member_name, member in get_ops_in_version(iter_version):
        op_name = renamed.get(member_name, member_name)
        if not isfunction(member):
            continue
        if is_registered_op(op_name, domain, version):
            continue
        register_op(op_name, member, domain, version)


def register_ops_in_version(domain, version):
    """Register the symbolic functions available to ``(domain, version)``.

    Starting at ``version``, the symbolic modules are visited one opset at a
    time, moving towards opset 9, and opset 9 itself is visited last. Because
    ``register_ops_helper`` never overwrites an existing entry, a definition
    found in an opset closer to ``version`` takes precedence over one found
    closer to opset 9.
    """
    # Opset 9 is the base version. It was chosen because:
    #   1. It is the first opset version supported by PyTorch export.
    #   2. It is more robust than earlier opsets. Opsets such as 7/8 cannot
    #      express certain basic operators in ONNX; rather than building on
    #      those limitations, they are handled separately as special cases.
    # Backward support for opset versions beyond opset 7 is not in our roadmap.
    #
    # Every opset other than 9 inherits, by default, the symbolic functions
    # defined in symbolic_opset9.py. To support an operator that changed in a
    # particular opset, add the updated symbolic function to that opset's
    # symbolic_opset{version}.py file. See topk in symbolic_opset10.py and
    # upsample_nearest2d in symbolic_opset8.py for examples.
    base_opset = 9
    current = version
    while True:
        register_ops_helper(domain, version, current)
        if current == base_opset:
            break
        # Step one opset towards the base, from whichever side we started.
        current = current - 1 if current > base_opset else current + 1


def get_ops_in_version(version):
    """Return the ``(name, value)`` member pairs of the symbolic module for ``version``.

    Raises ``KeyError`` when no module is stored for ``version`` in
    ``_symbolic_versions``.
    """
    symbolic_module = _symbolic_versions[version]
    return getmembers(symbolic_module)


def is_registered_version(domain, version):
    """Return True when the registry holds an op table for ``(domain, version)``."""
    key = (domain, version)
    return key in _registry


def register_op(opname, op, domain, version):
    """Store ``op`` as the symbolic function for ``opname`` under ``(domain, version)``.

    A warning is emitted when ``domain`` or ``version`` is None, but the entry
    is still written. An existing entry for ``opname`` is overwritten.
    """
    if domain is None or version is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version to register are None.")
    # Create the op table on first use, then add or replace the entry.
    _registry.setdefault((domain, version), {})[opname] = op


def is_registered_op(opname, domain, version):
    """Return True when ``opname`` has a symbolic function under ``(domain, version)``.

    A warning is emitted when ``domain`` or ``version`` is None; the lookup is
    still performed.
    """
    if domain is None or version is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
    key = (domain, version)
    if key not in _registry:
        return False
    return opname in _registry[key]

def get_op_supported_version(opname, domain, version):
    """Return the lowest opset >= ``version`` whose symbolic module defines ``opname``.

    Opsets are scanned upwards from ``version`` through ``_onnx_main_opset``.
    A module member counts as the op only when it would be registered under
    ``opname`` by ``register_ops_helper``: it must be a function, and the
    ``_len``/``_list`` members are matched under their registered names
    ``len``/``list``.

    Args:
        opname: Operator name to look for.
        domain: Accepted for signature compatibility; not consulted, since the
            scan only inspects the per-opset symbolic modules.
        version: First opset version to inspect.

    Returns:
        The first opset version found to define the operator, or None when no
        scanned opset defines it.
    """
    # Same renaming that register_ops_helper applies when registering.
    aliases = {'_len': 'len', '_list': 'list'}
    iter_version = version
    while iter_version <= _onnx_main_opset:
        # A version with no symbolic module cannot provide the op. Skip it
        # rather than raising KeyError, which would mask the RuntimeError the
        # caller is in the middle of building.
        module = _symbolic_versions.get(iter_version)
        if module is not None:
            for name, member in getmembers(module):
                if isfunction(member) and aliases.get(name, name) == opname:
                    return iter_version
        iter_version += 1
    return None

def get_registered_op(opname, domain, version):
    """Return the symbolic function registered for ``opname`` under ``(domain, version)``.

    A warning is emitted when ``domain`` or ``version`` is None.

    Raises:
        RuntimeError: when the operator is not registered. The message names
            a later opset that defines the operator when
            ``get_op_supported_version`` finds one, and otherwise invites a
            support request or pull request.
    """
    if domain is None or version is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")

    # Success path first: the op is registered, hand it back.
    if is_registered_op(opname, domain, version):
        return _registry[(domain, version)][opname]

    msg = "Exporting the operator " + opname + " to ONNX opset version " + str(version) + " is not supported. "
    hint_version = get_op_supported_version(opname, domain, version)
    if hint_version is None:
        msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
    else:
        msg += "Support for this operator was added in version " + str(hint_version) + ", try exporting with this version."
    raise RuntimeError(msg)