import json
import logging
import math
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.static_runtime import config

logger: logging.Logger = logging.getLogger()


def has_alias(
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
    """Return True if any argument carries an alias annotation (e.g. the `(a!)` in `Tensor(a!)`)."""
    for arg in arguments:
        annotation = getattr(arg, "annotation", None)
        if not annotation:
            continue
        alias_set = getattr(annotation, "alias_set", ())
        if alias_set:
            return True
    return False
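
# A minimal usage sketch (hedged): `FunctionSchema.parse` is torchgen's schema
# parser, and `flat_all` flattens `self` into a plain `Argument` so that its
# annotation is visible to `has_alias`.
#
#   schema = FunctionSchema.parse(
#       "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
#   )
#   has_alias(schema.arguments.flat_all)  # True: `self` carries alias set (a!)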


BLOCKED_OPS = frozenset(
    (
        # non cpu ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training related ops
        "_make_dual",
        # cannot call directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these ones got added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
    )
)


def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    """Return True if Static Runtime code generation supports the given op group."""
    base_op_name = ""
    func = None
    if isinstance(g, NativeFunctionsViewGroup):
        base_op_name = g.view.root_name
        func = g.view.func
    else:
        base_op_name = g.out.func.name.name.base
        func = g.out.func
    if config.is_hand_written(g):
        logger.info(f"HAND WRITTEN: {base_op_name}")
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info(f"BLOCKED: {base_op_name}")
        return False
    for arg in func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # Type conversion is not supported yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
            return False

    if isinstance(g, NativeFunctionsViewGroup):
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
            # Returns a non-Tensor value.
            logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
            return False
        return True

    # For out-variant ops, we also need to check the arguments of the
    # corresponding functional variant.
    for arg in g.functional.func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # Type conversion is not supported yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(g.functional.func)}")
            return False

    if not g.structured:
        # For an unstructured op, check that it has an out-variant implementation
        # satisfying the minimum requirement: the output tensor is the last
        # parameter.
        if (
            not hasattr(g, "out")
            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(func.name).endswith(".out")
        ):
            return False

    # TODO: stop type testing by converting to C++
    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
        logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of its inputs.
        logger.info(f"INPUTS ALIAS: {base_op_name}")
        return False
    return True
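
# Sketch of the intended call pattern (hedged; `groups` stands for the parsed
# NativeFunctionsGroup / NativeFunctionsViewGroup entries from
# native_functions.yaml, and the actual emission step lives elsewhere in this
# module):
#
#   supported = [g for g in groups if is_supported(g)]
#   # ... generate Static Runtime out-variant wrappers only for `supported` ...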


def ivalue_type_conversion_method(
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """
    Return the method call expression on `c10::IValue` that converts its contained
    value to the type expected by `arg_type`. For example, for `arg_type` ==
    BaseTy.Tensor this function returns (True, "toTensor()"), so that "toTensor()"
    can be appended to the ivalue's variable name to get the value of the expected
    type. Returns None if the type conversion is not supported.
    """
    type_conversion_methods = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        if not isinstance(arg_type.elem, BaseType):
            # ListType (and other nested types) are currently unsupported.
            return None
        base_ty_object = arg_type.elem.name
    else:
        return None

    if base_ty_object not in type_conversion_methods:
        return None
    methods = type_conversion_methods[base_ty_object]
    if isinstance(arg_type, BaseType):
        return methods[0]
    return methods[1]
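
# Usage sketch (hedged): a plain `BaseType` selects the direct conversion,
# while an `OptionalType` wrapper selects the `toOptional<...>()` form.
#
#   ivalue_type_conversion_method(BaseType(BaseTy.Tensor))
#   # -> (True, "toTensor()")
#   ivalue_type_conversion_method(OptionalType(BaseType(BaseTy.int)))
#   # -> (False, "toOptional<int64_t>()")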


should_use_int_tensor_ops_ = frozenset(
    (
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "bitwise_left_shift",
        "bitwise_right_shift",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    )
)