# Captured from the package-registry page for "edgify / torch" (python).
# The page's "Repository URL to install this package" field was empty in this capture.
# File: static_runtime/generator.py

import json
import logging
import math
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    # NOTE(review): the original name list was lost from this copy of the file;
    # these cover the names used in the visible code plus BackendIndex and
    # FunctionSchema for the rest of the file. Verify against upstream.
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.static_runtime import config

# Module-level logger named after this module. Records still propagate to the
# root logger's handlers, but the level and handlers can be configured for this
# generator alone instead of mutating global logging state.
logger: logging.Logger = logging.getLogger(__name__)

def has_alias(
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
    """Return True if any argument carries a non-empty alias annotation.

    Each argument's ``annotation`` attribute is read with ``getattr``, so
    arguments that lack the attribute, or whose annotation is falsy, are
    skipped. An annotation counts as aliasing when its ``alias_set`` attribute
    (missing is treated as empty) is non-empty.

    Args:
        arguments: The arguments to inspect.

    Returns:
        True as soon as one aliasing argument is found, otherwise False.
    """
    for arg in arguments:
        # getattr with a default: not every member of the Union is assumed to
        # expose ``annotation``.
        annotation = getattr(arg, "annotation", None)
        if not annotation:
            # No annotation means no alias information, so keep looking.
            continue
        alias_set = getattr(annotation, "alias_set", ())
        if alias_set:
            return True
    return False

# Operator base names that ``is_supported`` always rejects (it logs "BLOCKED"
# and returns False for any op whose base name appears here). Entries are
# grouped by the reason they are blocked.
# NOTE(review): every category except "buggy ops" is empty in this copy of the
# file, which suggests entries were lost; restore the full list from upstream
# torchgen before relying on it.
BLOCKED_OPS = frozenset(
    (
        # non cpu ops
        # sparse ops
        # deprecated ops
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        # uncommon ops
        # training related ops
        # cannot call directly
        # no documentation
        # TODO: these ones got added recently and need manual inspection
    )
)

def _arg_types_convertible(func) -> bool:
    """Return True if every schema-order argument of ``func`` has an IValue conversion.

    ``func`` must provide ``schema_order_arguments()`` (as the ``.func`` schemas
    used by ``is_supported`` do). On the first argument whose type
    ``ivalue_type_conversion_method`` cannot handle, ``func`` is logged and
    False is returned.
    """
    for arg in func.schema_order_arguments():
        if not ivalue_type_conversion_method(arg.type):
            # Type converting is unsupported yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
            return False
    return True


def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    """Return whether op group ``g`` passes every support check of this generator.

    The checked schema is ``g.view.func`` for a view group and the out variant
    ``g.out.func`` otherwise. ``g`` is rejected (returns False) when:

    * ``config.is_hand_written(g)`` is true;
    * its base op name is in ``BLOCKED_OPS``;
    * any argument of the checked schema has no IValue conversion;
    * (view groups) the schema's C++ return type is not ``at::Tensor``.

    Out-variant groups are additionally rejected when:

    * any argument of ``g.functional.func`` has no IValue conversion;
    * ``g`` is unstructured and its out schema does not end in
      ``Tensor(a!) out) -> Tensor(a!)`` or its name does not end in ``.out``
      (this case is not logged);
    * the out schema's C++ return type is not ``at::Tensor &``;
    * any non-out argument carries an alias annotation (see ``has_alias``).

    Every other rejection is logged at INFO level.
    """
    if isinstance(g, NativeFunctionsViewGroup):
        base_op_name = g.view.root_name
        func = g.view.func
    else:
        base_op_name = g.out.func.name.name.base
        func = g.out.func

    if config.is_hand_written(g):
        logger.info(f"HAND WRITTEN: {base_op_name}")
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info(f"BLOCKED: {base_op_name}")
        return False
    if not _arg_types_convertible(func):
        return False

    if isinstance(g, NativeFunctionsViewGroup):
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
            # Returns a non-Tensor value.
            logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
            return False
        return True

    # For out variant ops, we need to check the arguments of its functional func.
    if not _arg_types_convertible(g.functional.func):
        return False

    if not g.structured:
        # In case of unstructured op, we check if it has out variant implementation.
        # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
        # parameter.
        if (
            not hasattr(g, "out")
            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(func.name).endswith(".out")
        ):
            return False
    # TODO: stop type testing by converting to C++
    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
        logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        logger.info(f"INPUTS ALIAS: {base_op_name}")
        return False
    return True

def ivalue_type_conversion_method(
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """Return how to extract a value of type ``arg_type`` from a ``c10::IValue``.

    The result is a ``(flag, method_call)`` tuple. ``method_call`` is the
    ``c10::IValue`` member call (without a leading dot) that converts the
    contained value to the C++ type matching ``arg_type``, so it can be
    appended to an IValue variable name after a ``.``. For example, a
    ``BaseType`` of ``BaseTy.Tensor`` yields ``(True, "toTensor()")``, and an
    optional Tensor yields ``(False, "toOptional<at::Tensor>()")``. ``flag`` is
    True only for a non-optional Tensor; presumably callers use it to decide
    whether to bind the result by reference. Confirm at the call site.

    Returns None when the type is unsupported:

    * an ``OptionalType`` whose element is not a ``BaseType`` (e.g. a list);
    * any type that is neither ``BaseType`` nor ``OptionalType``;
    * a base type that is missing from the conversion table below.
    """
    # Per base type: (conversion for a plain value, conversion for an Optional).
    type_conversion_methods = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }

    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        if not isinstance(arg_type.elem, BaseType):
            # ListType is currently unsupported.
            return None
        base_ty_object = arg_type.elem.name
    else:
        return None

    methods = type_conversion_methods.get(base_ty_object)
    if methods is None:
        return None
    # Index 0 is the plain-value conversion; index 1 is the Optional one.
    return methods[0] if isinstance(arg_type, BaseType) else methods[1]

should_use_int_tensor_ops_ = frozenset(
Loading ...