import unittest
from functools import partial
from itertools import product
from typing import List
import numpy as np
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
precisionOverride,
tol,
toleranceOverride,
)
from torch.testing._internal.common_dtype import all_types_and, floating_types
from torch.testing._internal.common_utils import TEST_SCIPY, torch_to_numpy_dtype_dict
from torch.testing._internal.opinfo.core import (
BinaryUfuncInfo,
DecorateInfo,
L,
NumericsFilter,
OpInfo,
S,
SampleInput,
UnaryUfuncInfo,
)
from torch.testing._internal.opinfo.refs import (
ElementwiseBinaryPythonRefInfo,
ElementwiseUnaryPythonRefInfo,
)
from torch.testing._internal.opinfo.utils import (
np_unary_ufunc_integer_promotion_wrapper,
)
if TEST_SCIPY:
import scipy.special
# TODO: Consolidate `i0e` with sample_inputs_unary once `make_tensor`
# supports an `exclude` argument.
# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs):
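    # i0e(x) = exp(-|x|) * i0(x) has a kink at 0, so its gradient is ill-defined there;
    # presumably that is why zero is excluded from the samples when gradients are
    # requested for `torch.special.i0e`.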
exclude_zero = requires_grad and op_info.op == torch.special.i0e
make_arg = partial(
make_tensor,
dtype=dtype,
device=device,
requires_grad=requires_grad,
exclude_zero=exclude_zero,
)
yield SampleInput(make_arg((S,)))
yield SampleInput(make_arg(()))
if requires_grad and not exclude_zero:
        # Special case for the gradient:
        # sample with a `0` in the input
t = make_arg((S,))
t[0] = 0
yield SampleInput(t)
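# Yields (S, S)-shaped and scalar tensors paired with polygamma orders n = 1..5.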
def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
tensor_shapes = ((S, S), ())
ns = (1, 2, 3, 4, 5)
for shape, n in product(tensor_shapes, ns):
yield SampleInput(make_arg(shape), args=(n,))
def reference_polygamma(x, n):
# WEIRD `scipy.special.polygamma` behavior
# >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
# dtype('float64')
# >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
# dtype('float32')
#
    # Thus we cast the output to the default torch dtype, preserving double inputs.
result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
if x.dtype == np.double:
result_dtype = np.double
return scipy.special.polygamma(n, x).astype(result_dtype)
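# When gradients are required the lower bound is nudged just above 0, since
# entr(x) = -x * log(x) has an unbounded derivative (-log(x) - 1) as x -> 0+.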
def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
low, _ = op_info.domain
if requires_grad:
low = 0 + op_info._domain_eps
make_arg = partial(
make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad
)
yield SampleInput(make_arg((L,)))
yield SampleInput(make_arg(()))
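# Database of OpInfo entries for `torch.special` operators, consumed by the
# OpInfo-based test generators.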
op_db: List[OpInfo] = [
UnaryUfuncInfo(
"special.i0e",
aten_name="special_i0e",
ref=scipy.special.i0e if TEST_SCIPY else None,
decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
backward_dtypes=floating_types(),
sample_inputs_func=sample_inputs_i0_i1,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
UnaryUfuncInfo(
"special.i1",
aten_name="special_i1",
ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
if TEST_SCIPY
else None,
dtypes=all_types_and(torch.bool),
dtypesIfCUDA=all_types_and(torch.bool),
sample_inputs_func=sample_inputs_i0_i1,
decorators=(
DecorateInfo(
toleranceOverride(
{
torch.float32: tol(atol=1e-4, rtol=0),
torch.bool: tol(atol=1e-4, rtol=0),
}
)
),
),
skips=(
DecorateInfo(
unittest.skip("Incorrect result!"),
"TestUnaryUfuncs",
"test_reference_numerics_large",
dtypes=(torch.int8,),
),
),
supports_fwgrad_bwgrad=True,
supports_forward_ad=True,
),
UnaryUfuncInfo(
"special.i1e",
aten_name="special_i1e",
ref=scipy.special.i1e if TEST_SCIPY else None,
dtypes=all_types_and(torch.bool),
dtypesIfCUDA=all_types_and(torch.bool),
sample_inputs_func=sample_inputs_i0_i1,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
UnaryUfuncInfo(
"special.ndtr",
aten_name="special_ndtr",
decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
ref=scipy.special.ndtr if TEST_SCIPY else None,
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# Dispatch stub: unsupported device typemeta
DecorateInfo(
unittest.expectedFailure,
"TestFwdGradients",
"test_fn_fwgrad_bwgrad",
device_type="meta",
),
),
),
# A separate OpInfo entry for special.polygamma is needed to reorder the arguments
# for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939
UnaryUfuncInfo(
"special.polygamma",
op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
variant_test_name="special_polygamma_n_0",
ref=reference_polygamma if TEST_SCIPY else None,
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_polygamma,
skips=(
            # The op is implemented as a lambda, which these tests cannot handle
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
),
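        # The two dicts are the kwargs for the op and for the reference, respectively;
        # this variant exercises n=0.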
sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
        # polygamma functions have singularities at the non-positive integers (x <= 0)
reference_numerics_filter=NumericsFilter(
condition=lambda x: x < 0.1, safe_val=1
),
),
BinaryUfuncInfo(
"special.xlog1py",
aten_name="special_xlog1py",
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
promotes_int_to_float=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_one_python_scalar=True,
        # We don't sample -1 since the gradient is NaN there and breaks the gradient checks
rhs_make_tensor_kwargs=dict(low=-0.99),
),
BinaryUfuncInfo(
"special.zeta",
aten_name="special_zeta",
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
supports_autograd=False,
supports_one_python_scalar=True,
skips=(
            # The reference inputs produce nans and infs on CUDA, and nan, inf, 0., -inf on CPU
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
),
# TODO: FIXME
# OpInfo entry to verify the gradient formula of `other`/`q`
# BinaryUfuncInfo('special.zeta',
# op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
# aten_name='special_zeta',
# variant_test_name='grad',
# dtypes=all_types_and(torch.bool),
# promotes_int_to_float=True,
# supports_autograd=True,
# supports_rhs_python_scalar=False,
# decorators=[
# # Derivative wrt first tensor not implemented
# DecorateInfo(unittest.expectedFailure, "TestCommon",
# "test_floating_inputs_are_differentiable")
# ],
# skips=(
# # Lambda doesn't work in JIT test
# # AssertionError: JIT Test does not execute any logic
# DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"),
# )),
UnaryUfuncInfo(
"special.entr",
ref=scipy.special.entr if TEST_SCIPY else None,
aten_name="special_entr",
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
dtypes=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
"TestUnaryUfuncs",
"test_reference_numerics_large",
dtypes=[torch.bfloat16, torch.float16],
),
),
supports_inplace_autograd=False,
sample_inputs_func=sample_inputs_entr,
),
UnaryUfuncInfo(
"special.ndtri",
ref=scipy.special.ndtri if TEST_SCIPY else None,
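        # ndtri is the inverse of the Gaussian CDF (ndtr), hence the (0, 1) domain.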
domain=(0, 1),
aten_name="special_ndtri",
dtypes=all_types_and(torch.bool),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
UnaryUfuncInfo(
"special.log_ndtr",
aten_name="special_log_ndtr",
ref=scipy.special.log_ndtr if TEST_SCIPY else None,
dtypes=all_types_and(torch.bool),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
UnaryUfuncInfo(
"special.erfcx",
ref=scipy.special.erfcx if TEST_SCIPY else None,
aten_name="special_erfcx",
decorators=(
toleranceOverride(
{
torch.float32: tol(atol=0, rtol=4e-6),
}
),
),
dtypes=all_types_and(torch.bool),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
UnaryUfuncInfo(
"special.airy_ai",
decorators=(
precisionOverride(
{
torch.float32: 1e-03,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
        ref=(lambda x: scipy.special.airy(x)[0]) if TEST_SCIPY else None,
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
"TestUnaryUfuncs",
"test_reference_numerics_large",
),
),
supports_autograd=False,
),
UnaryUfuncInfo(
"special.bessel_j0",
decorators=(
precisionOverride(
{
torch.float32: 1e-04,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.j0 if TEST_SCIPY else None,
supports_autograd=False,
),
UnaryUfuncInfo(
"special.bessel_j1",
decorators=(
precisionOverride(
{
torch.float32: 1e-04,
torch.float64: 1e-05,
},
),
),
dtypes=all_types_and(torch.bool),
ref=scipy.special.j1 if TEST_SCIPY else None,
supports_autograd=False,
),