import unittest
from functools import partial
from typing import List
import numpy as np
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater
from torch.testing._internal.common_device_type import precisionOverride
from torch.testing._internal.common_dtype import (
all_types_and,
all_types_and_complex_and,
)
from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
from torch.testing._internal.opinfo.core import (
DecorateInfo,
ErrorInput,
OpInfo,
SampleInput,
SpectralFuncInfo,
SpectralFuncType,
)
from torch.testing._internal.opinfo.refs import (
_find_referenced_opinfo,
_inherit_constructor_args,
PythonRefInfo,
)
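
# scipy.fft supplies the reference implementations for the Hermitian FFT
# variants below (fft.hfft2 / fft.hfftn); when SciPy is unavailable their
# "ref" falls back to None.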
has_scipy_fft = False
if TEST_SCIPY:
try:
import scipy.fft
has_scipy_fft = True
except ModuleNotFoundError:
pass
class SpectralFuncPythonRefInfo(SpectralFuncInfo):
"""
    An OpInfo for a Python reference of a spectral function.
"""
def __init__(
self,
        name,  # the string name of the callable Python reference
*,
op=None, # the function variant of the operation, populated as torch.<name> if None
torch_opinfo_name, # the string name of the corresponding torch opinfo
torch_opinfo_variant="",
supports_nvfuser=True,
        # additional kwargs override kwargs inherited from the torch opinfo
        **kwargs,
    ):
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo = _find_referenced_opinfo(
torch_opinfo_name, torch_opinfo_variant, op_db=op_db
)
self.supports_nvfuser = supports_nvfuser
assert isinstance(self.torch_opinfo, SpectralFuncInfo)
inherited = self.torch_opinfo._original_spectral_func_args
ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
super().__init__(**ukwargs)
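
# Hypothetical usage sketch (illustration only, not an op_db entry): a
# Python-reference OpInfo inheriting its configuration from the torch
# "fft.fft" entry below could be constructed as
#
#   SpectralFuncPythonRefInfo("_refs.fft.fft", torch_opinfo_name="fft.fft")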
def error_inputs_fft(op_info, device, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=torch.float32)
# Zero-dimensional tensor has no dimension to take FFT of
yield ErrorInput(
SampleInput(make_arg()),
error_type=IndexError,
error_regex="Dimension specified as -1 but tensor has no dimensions",
)
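
# For example, torch.fft.fft(torch.tensor(1.0)) raises IndexError with the
# message matched above, since a 0-d tensor has no dimension to transform.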
def error_inputs_fftn(op_info, device, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=torch.float32)
# Specifying a dimension on a zero-dimensional tensor
yield ErrorInput(
SampleInput(make_arg(), dim=(0,)),
error_type=IndexError,
error_regex="Dimension specified as 0 but tensor has no dimensions",
)
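
# For example, torch.fft.fftn(torch.tensor(1.0), dim=(0,)) raises IndexError
# with the message matched above.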
def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
def mt(shape, **kwargs):
return make_tensor(
shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
)
yield SampleInput(mt((9, 10)))
yield SampleInput(mt((50,)), kwargs=dict(dim=0))
yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,)))
yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1)))
yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2)))
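
# Together these samples cover fftshift's default (shift all dims), a single
# int dim, a one-element dim tuple, and multi-dim tuples including the
# non-adjacent selection (0, 2).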
# Operator database
op_db: List[OpInfo] = [
SpectralFuncInfo(
"fft.fft",
aten_name="fft_fft",
decomp_aten_name="_fft_c2c",
ref=np.fft.fft,
ndimensional=SpectralFuncType.OneD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool,
*(
()
if (TEST_WITH_ROCM or not SM53OrLater)
else (torch.half, torch.complex32)
),
),
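        # i.e. half/chalf are only added on CUDA builds (not ROCm) running on
        # hardware with compute capability SM53 (5.3) or newer.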
error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
),
SpectralFuncInfo(
"fft.fft2",
aten_name="fft_fft2",
ref=np.fft.fft2,
decomp_aten_name="_fft_c2c",
ndimensional=SpectralFuncType.TwoD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool,
*(
()
if (TEST_WITH_ROCM or not SM53OrLater)
else (torch.half, torch.complex32)
),
),
error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
),
SpectralFuncInfo(
"fft.fftn",
aten_name="fft_fftn",
decomp_aten_name="_fft_c2c",
ref=np.fft.fftn,
ndimensional=SpectralFuncType.ND,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool,
*(
()
if (TEST_WITH_ROCM or not SM53OrLater)
else (torch.half, torch.complex32)
),
),
error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
),
SpectralFuncInfo(
"fft.hfft",
aten_name="fft_hfft",
decomp_aten_name="_fft_c2r",
ref=np.fft.hfft,
ndimensional=SpectralFuncType.OneD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool,
*(
()
if (TEST_WITH_ROCM or not SM53OrLater)
else (torch.half, torch.complex32)
),
),
error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
check_batched_gradgrad=False,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
"TestSchemaCheckModeOpInfo",
"test_schema_correctness",
dtypes=(torch.complex64, torch.complex128),
),
),
),
SpectralFuncInfo(
"fft.hfft2",
aten_name="fft_hfft2",
decomp_aten_name="_fft_c2r",
ref=scipy.fft.hfft2 if has_scipy_fft else None,
ndimensional=SpectralFuncType.TwoD,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool,
*(
()
if (TEST_WITH_ROCM or not SM53OrLater)
else (torch.half, torch.complex32)
),
),
error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_gradgrad=False,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
decorators=[
DecorateInfo(
precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
"TestFFT",
"test_reference_nd",
)
],
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
"TestSchemaCheckModeOpInfo",
"test_schema_correctness",
),
),
),
SpectralFuncInfo(
"fft.hfftn",
aten_name="fft_hfftn",
decomp_aten_name="_fft_c2r",
ref=scipy.fft.hfftn if has_scipy_fft else None,
ndimensional=SpectralFuncType.ND,
dtypes=all_types_and_complex_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and_complex_and(
torch.bool,
*(
()
if (TEST_WITH_ROCM or not SM53OrLater)
else (torch.half, torch.complex32)
),
),
error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_gradgrad=False,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
decorators=[
DecorateInfo(
precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
"TestFFT",
"test_reference_nd",
),
],
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
"TestSchemaCheckModeOpInfo",
"test_schema_correctness",
),
),
),
SpectralFuncInfo(
"fft.rfft",
aten_name="fft_rfft",
decomp_aten_name="_fft_r2c",
ref=np.fft.rfft,
ndimensional=SpectralFuncType.OneD,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(
torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
),
error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_grad=False,
check_batched_gradgrad=False,
),
SpectralFuncInfo(
"fft.rfft2",
aten_name="fft_rfft2",
decomp_aten_name="_fft_r2c",
ref=np.fft.rfft2,
ndimensional=SpectralFuncType.TwoD,
dtypes=all_types_and(torch.bool),
# rocFFT doesn't support Half/Complex Half Precision FFT
# CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
dtypesIfCUDA=all_types_and(
torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
),
error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_grad=False,
check_batched_gradgrad=False,
decorators=[
precisionOverride({torch.float: 1e-4}),
],
),
SpectralFuncInfo(
"fft.rfftn",
aten_name="fft_rfftn",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfftn,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and(torch.bool),
        # rocFFT doesn't support Half/Complex Half Precision FFT
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
        ),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            precisionOverride({torch.float: 1e-4}),
        ],
    ),
]
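
# A minimal consumption sketch (assuming a full PyTorch source checkout where
# these testing internals import cleanly): look up an entry and run its op on
# generated samples.
#
#   fft_info = next(op for op in op_db if op.name == "fft.fft")
#   for sample in fft_info.sample_inputs("cpu", torch.float32):
#       fft_info.op(sample.input, *sample.args, **sample.kwargs)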