import json
import logging
import math
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.static_runtime import config

logger: logging.Logger = logging.getLogger()


def has_alias(
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
    for arg in arguments:
        annotation = getattr(arg, "annotation", None)
        if not annotation:
            continue
        alias_set = getattr(annotation, "alias_set", ())
        if alias_set:
            return True
    return False
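
# A minimal sketch of `has_alias` in action. `FunctionSchema.parse` comes from
# torchgen.model; the schema string is the standard ATen `add.out` signature and
# is used here purely for illustration:
#
#   schema = FunctionSchema.parse(
#       "add.out(Tensor self, Tensor other, *, Scalar alpha=1, "
#       "Tensor(a!) out) -> Tensor(a!)"
#   )
#   has_alias(schema.arguments.out)      # True: `out` carries the (a!) alias set
#   has_alias(schema.arguments.non_out)  # False: self/other/alpha are unannotated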


BLOCKED_OPS = frozenset(
    (
        # non-CPU ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training related ops
        "_make_dual",
        # cannot call directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these were added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
    )
)


def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    base_op_name = ""
    func = None
    if isinstance(g, NativeFunctionsViewGroup):
        base_op_name = g.view.root_name
        func = g.view.func
    else:
        base_op_name = g.out.func.name.name.base
        func = g.out.func
    if config.is_hand_written(g):
        logger.info(f"HAND WRITTEN: {base_op_name}")
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info(f"BLOCKED: {base_op_name}")
        return False
    for arg in func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # Type conversion is not supported yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
            return False

    if isinstance(g, NativeFunctionsViewGroup):
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
            # Returns a non-Tensor value.
            logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
            return False
        return True

    # For out variant ops, we need to check the arguments of its functional func.
    for arg in g.functional.func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # Type conversion is not supported yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(g.functional.func)}")
            return False

    if not g.structured:
        # For an unstructured op, check whether it has an out-variant implementation.
        # The minimum requirement for an out variant is that the output tensor is its
        # last parameter.
        if (
            not hasattr(g, "out")
            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(func.name).endswith(".out")
        ):
            return False
    # TODO: stop type testing by converting to C++
    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
        logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        logger.info(f"INPUTS ALIAS: {base_op_name}")
        return False
    return True
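
# Illustrative usage sketch (an assumption about the caller, not code from this
# module): the generator filters grouped native functions, which are built
# elsewhere in torchgen (e.g. by get_grouped_native_functions in torchgen.gen):
#
#   supported = [g for g in grouped_native_functions if is_supported(g)]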


def ivalue_type_conversion_method(
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """
    Return the method call expression of `c10::ivalue' to convert its contained value to
    the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
    this function returns ".toTensor()", so that it can be appended to the ivalue's
    variable name to get the value of the expected type.
    """
    type_conversion_methods = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        if not isinstance(arg_type.elem, BaseType):
            # ListType is currently unsupported.
            return None
        base_ty_object = arg_type.elem.name
    else:
        return None

    if base_ty_object not in type_conversion_methods:
        return None
    methods = type_conversion_methods[base_ty_object]
    if isinstance(arg_type, BaseType):
        return methods[0]
    return methods[1]
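
# A few concrete lookups through the table above (a sketch; BaseType, OptionalType
# and BaseTy are the torchgen.model types imported at the top of this file):
#
#   ivalue_type_conversion_method(BaseType(BaseTy.Tensor))
#       # -> (True, "toTensor()")
#   ivalue_type_conversion_method(OptionalType(BaseType(BaseTy.int)))
#       # -> (False, "toOptional<int64_t>()")
#   ivalue_type_conversion_method(BaseType(BaseTy.Dimname))
#       # -> None: no conversion is registered for this type
#
# The leading bool appears to mark conversions whose result can be bound as a
# const reference in the generated C++ (only the non-optional Tensor case here).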


should_use_int_tensor_ops_ = frozenset(
    (
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "bitwise_left_shift",
        "bitwise_right_shift",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    )
)
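
# The set above is presumably consumed as a simple membership test when choosing
# integer-valued sample tensors for generated tests, since bitwise and index ops
# require integral dtypes (a hedged sketch; the rest of this file is truncated):
#
#   def should_use_int_tensor(op_name: str) -> bool:
#       return op_name in should_use_int_tensor_ops_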