from functools import wraps, partial
from itertools import product, chain, islice
import itertools
import functools
import copy
import operator
import random
import unittest
import math
import enum
import torch
import numpy as np
from torch import inf, nan
from typing import Any, Dict, List, Tuple, Union, Sequence
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
all_types, empty_types, complex_types_and, integral_types
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
skipCPUIfNoMklSparse,
toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
SM53OrLater, SM60OrLater, with_tf32_off, TEST_CUDNN,
_get_torch_cuda_version, _get_torch_rocm_version, PLATFORM_SUPPORTS_FUSED_SDPA,
SM80OrLater
)
from torch.testing._internal.common_utils import (
make_fullrank_matrices_with_distinct_singular_values,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW
)
import torch._refs as refs # noqa: F401
import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
import torch._prims as prims # noqa: F401
from torch.utils._pytree import tree_flatten
from distutils.version import LooseVersion
from torch.testing._internal.opinfo.core import ( # noqa: F401
L,
M,
S,
XS,
_NOTHING,
_getattr_qual,
DecorateInfo,
SampleInput,
ErrorInput,
AliasInfo,
NumericsFilter,
OpInfo,
_generate_reduction_inputs,
_generate_reduction_kwargs,
sample_inputs_reduction,
ReductionOpInfo,
reference_inputs_elementwise_binary,
make_error_inputs_elementwise_binary,
generate_elementwise_binary_tensors,
generate_elementwise_binary_arbitrarily_strided_tensors,
generate_elementwise_binary_small_value_tensors,
generate_elementwise_binary_large_value_tensors,
generate_elementwise_binary_extremal_value_tensors,
generate_elementwise_binary_broadcasting_tensors,
generate_elementwise_binary_with_scalar_samples,
generate_elementwise_binary_with_scalar_and_type_promotion_samples,
generate_elementwise_binary_noncontiguous_tensors,
sample_inputs_elementwise_binary,
BinaryUfuncInfo,
sample_inputs_elementwise_unary,
generate_elementwise_unary_tensors,
generate_elementwise_unary_small_value_tensors,
generate_elementwise_unary_large_value_tensors,
generate_elementwise_unary_extremal_value_tensors,
reference_inputs_elementwise_unary,
UnaryUfuncInfo,
sample_inputs_spectral_ops,
SpectralFuncType,
SpectralFuncInfo,
ShapeFuncInfo,
sample_inputs_foreach,
ForeachFuncInfo,
gradcheck_wrapper_hermitian_input,
gradcheck_wrapper_triangular_input,
gradcheck_wrapper_triangular_input_real_positive_diagonal,
gradcheck_wrapper_masked_operation,
gradcheck_wrapper_masked_pointwise_operation,
clone_sample,
)
from torch.testing._internal.opinfo.refs import ( # NOQA: F401
_find_referenced_opinfo,
_inherit_constructor_args,
PythonRefInfo,
ReductionPythonRefInfo,
ElementwiseUnaryPythonRefInfo,
ElementwiseBinaryPythonRefInfo,
)
from torch.testing._internal.opinfo.utils import (
np_unary_ufunc_integer_promotion_wrapper,
reference_reduction_numpy,
prod_numpy
)
from torch.testing._internal import opinfo
from torch.testing._internal.opinfo.definitions.linalg import (
sample_inputs_linalg_cholesky,
sample_inputs_linalg_cholesky_inverse,
sample_inputs_cross,
sample_inputs_linalg_qr_geqrf,
sample_inputs_linalg_invertible,
sample_inputs_lu_solve,
sample_inputs_legacy_solve,
sample_inputs_svd,
sample_inputs_linalg_det_logdet_slogdet,
sample_inputs_linalg_lu,
)
from torch.testing._internal.opinfo.definitions.special import (
sample_inputs_i0_i1,
sample_inputs_polygamma,
reference_polygamma,
)
from torch.testing._internal.opinfo.definitions._masked import (
sample_inputs_softmax_variant,
)
if TEST_SCIPY:
from scipy import stats
import scipy.spatial
import scipy.special
# test if a tensor is close to an integer
def close_to_int(x, eps=0.1):
    """Return a boolean mask marking elements of ``x`` that lie near an integer.

    For a real tensor the distance measure is ``|frac(x)|``. For a complex
    tensor the fractional part is taken separately on the real and imaginary
    components, and the magnitude of the resulting complex number is used.

    An element is flagged when that measure is below ``eps`` or above
    ``1 - eps``.

    Args:
        x: tensor to inspect (real or complex).
        eps: tolerance around an integer value.

    Returns:
        The element-wise result of ``(measure < eps) | (measure > 1 - eps)``.
    """
    if x.is_complex():
        # Apply frac() to the (real, imag) pairs, then collapse each pair back
        # into one complex number so abs() gives a single magnitude per element.
        frac_pairs = torch.view_as_real(x).frac()
        measure = torch.view_as_complex(frac_pairs).abs()
    else:
        measure = x.frac().abs()

    just_above = measure < eps
    just_below = measure > (1 - eps)
    return just_above | just_below
def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for a slice-style op.

    The first sample uses a 1-D tensor of length 3 with a single positional
    argument ``0``. The remaining samples use a ``(20, 30, 40)`` tensor and
    pass ``dim``/``start``/``end`` (and optionally ``step``) as keywords,
    including negative ``start``/``end`` values.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    new_tensor = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        low=None,
        high=None,
        requires_grad=requires_grad,
    )

    # Positional-argument form on a small 1-D input.
    yield SampleInput(new_tensor(3), 0)

    # Keyword-argument forms on a 3-D input.
    big_shape = (20, 30, 40)
    keyword_cases = (
        dict(dim=1, start=1, end=-2),
        dict(dim=1, start=1, end=-2, step=3),
        dict(dim=0, start=-10, end=-2, step=2),
    )
    for slice_kwargs in keyword_cases:
        yield SampleInput(new_tensor(*big_shape), **slice_kwargs)
def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for tensor_split over an ``(S, S, S)`` input.

    Three families of split arguments are covered, each with and without an
    explicit dimension: index tensors, index tuples, and an integer number
    of sections.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    build = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        low=None,
        high=None,
        requires_grad=requires_grad,
    )

    # Indices supplied as tensors. These are built without a device argument,
    # independent of the ``device`` used for the input tensor.
    tensor_index_cases = (
        (torch.tensor([1, 2, 3]),),
        (torch.tensor(1),),
        (torch.tensor([1, 2, 3]), 1),
        # Every-other-element view; holds the values 1, 2, 3.
        (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1),
    )
    # Indices supplied as a plain tuple.
    index_tuple_cases = (
        ((2, 4),),
        ((2, 4), 1),
        ((2, 4), -1),
    )
    # An integer number of sections.
    section_cases = (
        (3,),
        (3, 1),
        (3, -1),
    )

    input_shape = (S, S, S)
    for split_args in tensor_index_cases + index_tuple_cases + section_cases:
        yield SampleInput(build(input_shape), args=split_args)
def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for hsplit.

    Covers a 1-D tensor of length 6 with the integer ``2``, and an
    ``(S, S, S)`` tensor with the index list ``[1, 2, 3]``.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    tensor_of = partial(
        make_tensor,
        dtype=dtype,
        device=device,
        low=None,
        high=None,
        requires_grad=requires_grad,
    )
    # (input dims, indices_or_sections)
    cases = (
        ((6,), 2),
        ((S, S, S), [1, 2, 3]),
    )
    for dims, indices_or_sections in cases:
        yield SampleInput(tensor_of(*dims), indices_or_sections)
def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for vsplit.

    Covers a ``(6, S)`` tensor with the integer ``2``, and an ``(S, S, S)``
    tensor with the index list ``[1, 2, 3]``.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    tensor_of = partial(
        make_tensor,
        dtype=dtype,
        device=device,
        low=None,
        high=None,
        requires_grad=requires_grad,
    )
    # (input dims, indices_or_sections)
    cases = (
        ((6, S), 2),
        ((S, S, S), [1, 2, 3]),
    )
    for dims, indices_or_sections in cases:
        yield SampleInput(tensor_of(*dims), indices_or_sections)
def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for dsplit.

    Covers an ``(S, S, S)`` tensor with the index list ``[1, 2, 3]``, and an
    ``(S, S, 6)`` tensor with the integer ``2``.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    tensor_of = partial(
        make_tensor,
        dtype=dtype,
        device=device,
        low=None,
        high=None,
        requires_grad=requires_grad,
    )

    # Split at explicit indices.
    split_points = [1, 2, 3]
    yield SampleInput(tensor_of(S, S, S), split_points)

    # Split into a number of sections; the last dimension has size 6.
    section_count = 2
    yield SampleInput(tensor_of(S, S, 6), section_count)
def error_inputs_hsplit(op_info, device, **kwargs):
    """Yield ErrorInput cases pairing invalid hsplit calls with expected messages.

    Cases, in order:
      1. a 0-dim input,
      2. an ``(S, S, S)`` input with a split argument of ``0``,
      3. a string passed as the split argument (expects ``TypeError``).

    All inputs are float32 tensors on ``device``. ``op_info`` and
    ``**kwargs`` are accepted but not used here.
    """
    f32_tensor = partial(make_tensor, dtype=torch.float32, device=device)

    # 1. Input has no dimensions at all.
    too_few_dims = (
        "torch.hsplit requires a tensor with at least 1 dimension, "
        "but got a tensor with 0 dimensions!"
    )
    yield ErrorInput(SampleInput(f32_tensor(()), 0), error_regex=too_few_dims)

    # 2. A split argument of zero.
    zero_split = (
        "torch.hsplit attempted to split along dimension 1, "
        f"but the size of the dimension {S} "
        "is not divisible by the split_size 0!"
    )
    yield ErrorInput(SampleInput(f32_tensor((S, S, S)), 0), error_regex=zero_split)

    # 3. Wrong type for the indices_or_sections argument.
    bad_type = "received an invalid combination of arguments."
    yield ErrorInput(
        SampleInput(f32_tensor((S, S, S)), "abc"),
        error_type=TypeError,
        error_regex=bad_type,
    )
def error_inputs_vsplit(op_info, device, **kwargs):
    """Yield ErrorInput cases pairing invalid vsplit calls with expected messages.

    Cases, in order:
      1. a 1-D input of length ``S``,
      2. an ``(S, S, S)`` input with a split argument of ``0``,
      3. a string passed as the split argument (expects ``TypeError``).

    All inputs are float32 tensors on ``device``. ``op_info`` and
    ``**kwargs`` are accepted but not used here.
    """
    f32_tensor = partial(make_tensor, dtype=torch.float32, device=device)

    # 1. Input has only one dimension.
    too_few_dims = (
        "torch.vsplit requires a tensor with at least 2 dimension, "
        "but got a tensor with 1 dimensions!"
    )
    yield ErrorInput(SampleInput(f32_tensor(S), 0), error_regex=too_few_dims)

    # 2. A split argument of zero.
    zero_split = (
        "torch.vsplit attempted to split along dimension 0, "
        f"but the size of the dimension {S} "
        "is not divisible by the split_size 0!"
    )
    yield ErrorInput(SampleInput(f32_tensor(S, S, S), 0), error_regex=zero_split)

    # 3. Wrong type for the indices_or_sections argument.
    bad_type = "received an invalid combination of arguments."
    yield ErrorInput(
        SampleInput(f32_tensor(S, S, S), "abc"),
        error_type=TypeError,
        error_regex=bad_type,
    )
def error_inputs_dsplit(op_info, device, **kwargs):
    """Yield ErrorInput cases pairing invalid dsplit calls with expected messages.

    Cases, in order:
      1. a 1-D input of length ``S``,
      2. an ``(S, S, S)`` input with a split argument of ``0``.

    All inputs are float32 tensors on ``device``. ``op_info`` and
    ``**kwargs`` are accepted but not used here.
    """
    f32_tensor = partial(make_tensor, dtype=torch.float32, device=device)

    # 1. Input has only one dimension.
    too_few_dims = (
        "torch.dsplit requires a tensor with at least 3 dimension, "
        "but got a tensor with 1 dimensions!"
    )
    yield ErrorInput(SampleInput(f32_tensor(S), 0), error_regex=too_few_dims)

    # 2. A split argument of zero.
    zero_split = (
        "torch.dsplit attempted to split along dimension 2, "
        f"but the size of the dimension {S} "
        "is not divisible by the split_size 0!"
    )
    yield ErrorInput(SampleInput(f32_tensor(S, S, S), 0), error_regex=zero_split)
def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for as_strided.

    Each sample passes ``(size, stride)`` through ``args`` and
    ``storage_offset`` through ``kwargs``. The cases include a non-zero
    storage offset and stride patterns with repeated strides.

    ``op_info`` and the incoming ``**kwargs`` are accepted but not used here.
    """
    fresh_tensor = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # Columns: input shape, output size, output stride, storage offset.
    cases = (
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
        ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
        ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
    )
    for in_shape, out_size, out_stride, offset in cases:
        yield SampleInput(
            fresh_tensor(in_shape),
            args=(out_size, out_stride),
            kwargs={"storage_offset": offset},
        )
def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for as_strided where the input is a partial view.

    Every input is the ``[5:15]`` slice of a freshly created 20-element
    tensor. Three variants are produced with size ``(2, 2)`` and stride
    ``(1, 2)``: no ``storage_offset``, ``storage_offset=0``, and
    ``storage_offset=10``.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    def sliced_view():
        # requires_grad is set on the slice, not on the 20-element backing tensor.
        backing = make_tensor((20,), device=device, dtype=dtype)
        return backing[5:15].requires_grad_(requires_grad)

    out_size = (2, 2)
    out_stride = (1, 2)

    # No explicit storage offset.
    yield SampleInput(sliced_view(), out_size, out_stride)

    # Explicit storage offsets.
    for offset in (0, 10):
        yield SampleInput(sliced_view(), out_size, out_stride, storage_offset=offset)
def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for as_strided_scatter.

    Each sample is ``(input, src, size, stride)`` plus a ``storage_offset``
    keyword. ``src`` is created with the same shape as ``size``.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    new_tensor = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # Columns: input shape, src/output shape, output stride, storage offset.
    cases = [
        ((1,), (), (), 0),
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
        ((3, 3), (2, 2), (2, 1), 0),
        # Scatter into a view with more dimensions than the input.
        ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0),
        # Same idea, with the stride order inverted.
        ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0),
    ]
    for in_shape, src_shape, src_stride, offset in cases:
        # Create the destination before the source so tensors are generated
        # in a fixed order.
        destination = new_tensor(in_shape)
        source = new_tensor(src_shape)
        yield SampleInput(
            destination, source, src_shape, src_stride, storage_offset=offset
        )
def error_inputs_as_strided_scatter(op_info, device, **kwargs):
    """Yield an ErrorInput for an out-of-bounds as_strided_scatter call.

    A ``2x2`` source is scattered into a ``4x4`` float32 tensor using strides
    of 200 in both dimensions.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    f32_tensor = partial(
        make_tensor, device=device, dtype=torch.float32, requires_grad=False
    )

    destination = f32_tensor([4, 4])
    source = f32_tensor([2, 2])

    # NOTE(review): the figures in the message look consistent with 4-byte
    # elements: 16 elements -> 64 bytes of storage, and a furthest element at
    # index 200 + 200 -> (400 + 1) * 4 = 1604 bytes. Confirm against the
    # operator's bounds check.
    out_of_bounds_msg = (
        "itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64"
    )
    yield ErrorInput(
        SampleInput(destination, source, [2, 2], [200, 200], storage_offset=0),
        error_regex=out_of_bounds_msg,
    )
def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for combinations.

    Produces every pairing of three 1-D inputs (1, 2 and 4 elements) with
    ``r`` in ``{1, 2, 4}`` and ``with_replacement`` in ``{False, True}``.
    The input varies slowest and ``with_replacement`` fastest.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    element_sets = (
        (0,),
        (0, 1),
        (0, 1, 2, 3),
    )
    for elements in element_sets:
        for r in (1, 2, 4):
            for with_replacement in (False, True):
                # A new tensor is built for every sample rather than shared.
                sample_tensor = torch.tensor(
                    elements,
                    device=device,
                    dtype=dtype,
                    requires_grad=requires_grad,
                )
                yield SampleInput(
                    sample_tensor, r=r, with_replacement=with_replacement
                )
def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInput objects for cartesian_prod with 1, 2 and 3 tensors.

    Three 1-D tensors of 1, 2 and 4 elements are created once. The samples
    then use the first one, the first two, and all three, so the same tensor
    objects are shared between samples.

    ``op_info`` and ``**kwargs`` are accepted but not used here.
    """
    to_tensor = partial(
        torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # 1-D tensors with a varying number of elements.
    operands = [to_tensor(values) for values in ((0,), (0, 1), (0, 1, 2, 3))]

    # Growing prefixes: one tensor, then two, then three.
    for count in range(1, len(operands) + 1):
        yield SampleInput(*operands[:count])
def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs):
Loading ...