from functools import reduce, wraps, partial
from itertools import product
from operator import mul, itemgetter
import collections
import operator
import torch
import numpy as np
from torch._six import inf, istuple
from torch.autograd import Variable
import collections.abc
from typing import List, Tuple, Dict, Any
from torch.testing import \
(make_non_contiguous, _dispatch_dtypes, floating_types, floating_types_and,
floating_and_complex_types, floating_and_complex_types_and,
all_types_and_complex_and, all_types_and, all_types_and_complex)
from torch.testing._internal.common_device_type import \
(skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl,
skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride)
from torch.testing._internal.common_cuda import CUDA11OrLater
from torch.testing._internal.common_utils import \
(prod_single_zero, random_square_matrix_of_rank,
random_symmetric_matrix, random_symmetric_psd_matrix,
random_symmetric_pd_matrix, make_nonzero_det,
random_fullrank_matrix_distinct_singular_value, set_rng_seed,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY,
torch_to_numpy_dtype_dict, slowTest, TEST_WITH_ASAN)
from distutils.version import LooseVersion
if TEST_SCIPY:
import scipy.special
class DecorateInfo(object):
    """Ties one or more decorators to a subset of an operator's generated tests.

    A test is decorated when every filter that was provided (cls_name,
    test_name, device_type, dtypes) matches it; a filter left as None matches
    anything. Nothing is decorated unless active_if is truthy.
    """

    __slots__ = ['decorators', 'cls_name', 'test_name', 'device_type', 'dtypes', 'active_if']

    def __init__(self, decorators, cls_name=None, test_name=None, *,
                 device_type=None, dtypes=None, active_if=True):
        # Normalise to a list: a Sequence of decorators is copied, anything
        # else is treated as a single decorator and wrapped.
        if isinstance(decorators, collections.abc.Sequence):
            self.decorators = list(decorators)
        else:
            self.decorators = [decorators]
        self.cls_name = cls_name
        self.test_name = test_name
        self.device_type = device_type
        self.dtypes = dtypes
        self.active_if = active_if

    def is_active(self, cls_name, test_name, device_type, dtype):
        """Return a truthy value when this entry applies to the described test.

        If ``active_if`` is falsy it is returned as-is; otherwise the result of
        the filter comparisons is returned.
        """
        def _field_matches(expected, actual):
            # An unset (None) filter accepts every value.
            return expected is None or expected == actual

        return (
            self.active_if
            and _field_matches(self.cls_name, cls_name)
            and _field_matches(self.test_name, test_name)
            and _field_matches(self.device_type, device_type)
            and (self.dtypes is None or dtype in self.dtypes)
        )
class SkipInfo(DecorateInfo):
    """A DecorateInfo whose sole decorator is ``skipIf(True, "Skipped!")``.

    Matching follows DecorateInfo: every provided filter (cls_name,
    test_name, device_type, dtypes) must match the test, and the entry is
    inactive unless active_if is truthy.
    """

    def __init__(self, cls_name=None, test_name=None, *,
                 device_type=None, dtypes=None, active_if=True):
        always_skip = skipIf(True, "Skipped!")
        super().__init__(always_skip, cls_name, test_name,
                         device_type=device_type,
                         dtypes=dtypes,
                         active_if=active_if)
class SampleInput(object):
    """Bundles the input(s), extra positional args and kwargs for one call of
    an operator under test."""

    # output_process_fn_grad is an optional function applied to the op's
    # output to make it compatible with the input.
    __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad']

    def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None):
        # test_ops.py expects input to be a tuple, so a lone input is wrapped.
        if not isinstance(input, tuple):
            input = (input,)
        self.input = input
        self.args = args
        self.kwargs = {} if kwargs is None else kwargs
        self.output_process_fn_grad = output_process_fn_grad

    def __repr__(self):
        # Only non-empty / non-None fields are shown, in a fixed order.
        parts = [f'input[{len(self.input)}]']
        if len(self.args) > 0:
            parts.append(f'args={self.args}')
        if len(self.kwargs) > 0:
            parts.append(f'kwargs={self.kwargs}')
        if self.output_process_fn_grad is not None:
            parts.append(f'output_process_fn_grad={self.output_process_fn_grad}')
        return f'SampleInput({", ".join(parts)})'
class AliasInfo(object):
    """Holds the callables reachable under an alias name of an operator.

    For example, the alias "absolute" of torch.abs gives torch.absolute,
    torch.Tensor.absolute and torch.Tensor.absolute_.
    """

    def __init__(self, alias_name):
        self.name = alias_name
        # Function variant: resolved (possibly through a dotted path) on the
        # torch module; no default is given, so a missing name raises.
        self.op = _getattr_qual(torch, alias_name)
        # Method and in-place variants are optional and fall back to None.
        tensor_cls = torch.Tensor
        self.method_variant = getattr(tensor_cls, alias_name, None)
        self.inplace_variant = getattr(tensor_cls, alias_name + "_", None)

    def __call__(self, *args, **kwargs):
        """Invoke the alias's function variant."""
        return self.op(*args, **kwargs)
_NOTHING = object() # Unique value to distinguish default from anything else
# Extension of getattr to support qualified names
# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm
def _getattr_qual(obj, name, default=_NOTHING):
try:
for path in name.split('.'):
obj = getattr(obj, path)
return obj
except AttributeError:
if default is not _NOTHING:
return default
else:
raise
# Classes and methods for the operator database
class OpInfo(object):
    """Operator information and helper functions for acquiring it.

    An instance describes one operator under test: its callable variants
    (function, Tensor method, in-place method, ``operator`` module function,
    and aliases), the dtypes it supports per device type, how to generate
    sample inputs, and which tests to skip or decorate.
    """

    def __init__(self,
                 name,  # the string name of the function
                 *,
                 op=None,  # the function variant of the operation, populated as torch.<name> if None
                 dtypes=floating_types(),  # dtypes this function is expected to work with
                 dtypesIfCPU=None,  # dtypes this function is expected to work with on CPU
                 dtypesIfCUDA=None,  # dtypes this function is expected to work with on CUDA
                 dtypesIfROCM=None,  # dtypes this function is expected to work with on ROCM
                 default_test_dtypes=None,  # dtypes to test with by default. Gets intersected
                                            # with the dtypes supported on the tested device
                 test_inplace_grad=True,  # whether to gradcheck and gradgradcheck the inplace variant
                 test_complex_grad=True,  # whether to gradcheck and gradgradcheck for complex dtypes
                 skip_bfloat16_grad=False,  # whether to skip grad and gradgradcheck for bfloat16 dtype
                 assert_autodiffed=False,  # if a op's aten::node is expected to be symbolically autodiffed
                 autodiff_nonfusible_nodes=None,  # a list of strings with node names that are expected to be in a
                                                  # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'],
                                                  # default is populated to be ['aten::(name of Python operator)']
                 autodiff_fusible_nodes=None,  # a list of strings with node names that are expected to be in FusionGroups
                                               # inside of DifferentiableGraphs when this operation is autodiffed.
                                               # Ex: ['aten::add', 'aten::mm'], defaults to an empty list
                                               # Note: currently no ops use fusible nodes
                 output_func=lambda x: x,  # fn mapping output to part that should be gradcheck'ed
                 supports_tensor_out=True,  # whether the op supports the out kwarg, returning a Tensor
                 skips=tuple(),  # information about which tests to skip
                 decorators=None,  # decorators to apply to generated tests
                 safe_casts_outputs=False,  # whether op allows safe casting when writing to out arguments
                 sample_inputs_func=None,  # function to generate sample inputs
                 aten_name=None,  # name of the corresponding aten:: operator
                 aliases=None,  # iterable of aliases, e.g. ("absolute",) for torch.abs
                 variant_test_name='',  # additional string to include in the test name
                 supports_sparse=False,  # supported for sparse
                 check_batched_grad=True,  # check batched grad when doing gradcheck
                 check_batched_gradgrad=True,  # check batched grad grad when doing gradgradcheck
                 ):
        """Raises AssertionError if a dtypes argument is neither None nor a
        _dispatch_dtypes, and AttributeError if ``op`` is None and ``name``
        cannot be resolved on the torch module."""
        # Validates the dtypes are generated from the dispatch-related functions
        # (such as floating_types()); the message names the offending argument.
        for arg_name, dtype_list in (('dtypes', dtypes),
                                     ('dtypesIfCPU', dtypesIfCPU),
                                     ('dtypesIfCUDA', dtypesIfCUDA),
                                     ('dtypesIfROCM', dtypesIfROCM)):
            assert isinstance(dtype_list, (_dispatch_dtypes, type(None))), (
                f"OpInfo '{name}': {arg_name} must be None or a _dispatch_dtypes "
                f"(e.g. floating_types()), got {type(dtype_list).__name__}")

        self.name = name
        self.aten_name = aten_name if aten_name is not None else name
        self.variant_test_name = variant_test_name

        # Device-specific dtype sets fall back to (and share) the generic set.
        self.dtypes = set(dtypes)
        self.dtypesIfCPU = set(dtypesIfCPU) if dtypesIfCPU is not None else self.dtypes
        self.dtypesIfCUDA = set(dtypesIfCUDA) if dtypesIfCUDA is not None else self.dtypes
        self.dtypesIfROCM = set(dtypesIfROCM) if dtypesIfROCM is not None else self.dtypes
        self._default_test_dtypes = set(default_test_dtypes) if default_test_dtypes is not None else None

        # NOTE: if the op is unspecified it is assumed to be under the torch namespace
        self.op = op if op is not None else _getattr_qual(torch, self.name)
        # Optional variants; each is None when the corresponding attribute is absent.
        self.method_variant = getattr(torch.Tensor, name, None)
        inplace_name = name + "_"
        self.inplace_variant = getattr(torch.Tensor, inplace_name, None)
        self.operator_variant = getattr(operator, name, None)

        self.skip_bfloat16_grad = skip_bfloat16_grad
        self.test_inplace_grad = test_inplace_grad
        self.test_complex_grad = test_complex_grad
        self.supports_tensor_out = supports_tensor_out
        self.safe_casts_outputs = safe_casts_outputs
        self.skips = skips
        self.decorators = decorators
        self.output_func = output_func
        self.sample_inputs_func = sample_inputs_func

        self.assert_autodiffed = assert_autodiffed
        self.autodiff_fusible_nodes = autodiff_fusible_nodes if autodiff_fusible_nodes else []
        if autodiff_nonfusible_nodes is None:
            self.autodiff_nonfusible_nodes = ['aten::' + self.name]
        else:
            self.autodiff_nonfusible_nodes = autodiff_nonfusible_nodes

        self.supports_sparse = supports_sparse
        self.check_batched_grad = check_batched_grad
        self.check_batched_gradgrad = check_batched_gradgrad

        self.aliases = ()  # type: ignore
        if aliases is not None:
            self.aliases = tuple(AliasInfo(a) for a in aliases)  # type: ignore

    def __call__(self, *args, **kwargs):
        """Calls the function variant of the operator."""
        return self.op(*args, **kwargs)

    def get_op(self):
        """Returns the function variant of the operator, torch.<op_name>."""
        return self.op

    def get_method(self):
        """Returns the method variant of the operator, torch.Tensor.<op_name>.
        Returns None if the operator has no method variant.
        """
        return self.method_variant

    def get_inplace(self):
        """Returns the inplace variant of the operator, torch.Tensor.<op_name>_.
        Returns None if the operator has no inplace variant.
        """
        return self.inplace_variant

    def get_operator_variant(self):
        """Returns operator variant of the operator, e.g. operator.neg
        Returns None if the operator has no operator variant.
        """
        return self.operator_variant

    def sample_inputs(self, device, dtype, requires_grad=False):
        """Returns an iterable of SampleInputs.
        These samples should be sufficient to test the function works correctly
        with autograd, TorchScript, etc.

        Raises TypeError if this OpInfo was constructed without a
        sample_inputs_func.
        """
        if self.sample_inputs_func is None:
            # Previously this surfaced as "'NoneType' object is not callable";
            # keep the TypeError but say which op is misconfigured.
            raise TypeError(f"OpInfo '{self.name}' has no sample_inputs_func; "
                            f"pass one to the constructor to generate sample inputs")
        return self.sample_inputs_func(self, device, dtype, requires_grad)

    def should_skip(self, cls_name, test_name, device_type, dtype):
        """Returns True if any entry in ``skips`` is active for the given test,
        False otherwise."""
        return any(si.is_active(cls_name, test_name, device_type, dtype)
                   for si in self.skips)

    def supported_dtypes(self, device_type):
        """Returns the set of dtypes supported on ``device_type``.

        'cpu' uses dtypesIfCPU; 'cuda' uses dtypesIfROCM when TEST_WITH_ROCM
        is set and dtypesIfCUDA otherwise; any other device uses dtypes.
        """
        if device_type == 'cpu':
            return self.dtypesIfCPU
        elif device_type == 'cuda':
            return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA
        return self.dtypes

    def supports_dtype(self, dtype, device_type):
        """Returns whether ``dtype`` is supported on ``device_type``."""
        return dtype in self.supported_dtypes(device_type)

    def default_test_dtypes(self, device_type):
        """Returns the default dtypes used to test this operator on the device.
        Equal to the operator's default_test_dtypes filtered to remove dtypes
        not supported by the device.
        """
        supported = self.supported_dtypes(device_type)
        return (supported if self._default_test_dtypes is None
                else supported.intersection(self._default_test_dtypes))
# Dimension sizes used when building sample-input tensor shapes
# (names presumably stand for Large / Medium / Small).
L, M, S = 20, 10, 5
def sample_inputs_unary(op_info, device, dtype, requires_grad):
    """Build the default samples for a unary ufunc: a tensor of shape (L,)
    followed by a tensor of shape ().

    Values are drawn from ``op_info.domain``. Each bound that is not None is
    moved inward by ``op_info._domain_eps`` so samples stay inside the domain;
    a None bound is passed to make_tensor unchanged.
    """
    lower, upper = op_info.domain
    # Adjust only the bounds that are set; _domain_eps is read lazily.
    if lower is not None:
        lower = lower + op_info._domain_eps
    if upper is not None:
        upper = upper - op_info._domain_eps

    def _sample(shape):
        return SampleInput(make_tensor(shape, device, dtype,
                                       low=lower, high=upper,
                                       requires_grad=requires_grad))

    return (_sample((L,)), _sample(()))
# Metadata class for unary "universal functions (ufuncs)" that accept a single
# tensor and share common properties, which are listed in the class docstring.
class UnaryUfuncInfo(OpInfo):
    """Operator information for unary universal functions ("unary ufuncs").

    These operators take a single tensor and typically:
      - act elementwise,
      - produce an output whose shape equals the input shape,
      - provide method and in-place variants,
      - accept the out kwarg,
      - have a NumPy or SciPy reference implementation.

    See NumPy's universal function documentation
    (https://numpy.org/doc/1.18/reference/ufuncs.html) for background on
    the concept of ufuncs.
    """

    def __init__(self,
                 name,  # the string name of the function
                 *,
                 ref,  # a reference function
                 dtypes=floating_types(),
                 dtypesIfCPU=floating_and_complex_types_and(torch.bfloat16),
                 dtypesIfCUDA=floating_and_complex_types_and(torch.half),
                 dtypesIfROCM=floating_types_and(torch.half),
                 domain=(None, None),  # the [low, high) domain of the function
                 handles_large_floats=True,  # whether the op correctly handles large float values (like 1e20)
                 handles_extremals=True,  # whether the op correctly handles extremal values (like inf)
                 handles_complex_extremals=True,  # whether the op correctly handles complex extremals (like inf -infj)
                 supports_complex_to_float=False,  # op supports casting from complex input to real output safely eg. angle
                 sample_inputs_func=sample_inputs_unary,
                 supports_sparse=False,
                 **kwargs):
        # Arguments forwarded to OpInfo; any extra kwargs pass straight through.
        base_kwargs = dict(dtypes=dtypes,
                           dtypesIfCPU=dtypesIfCPU,
                           dtypesIfCUDA=dtypesIfCUDA,
                           dtypesIfROCM=dtypesIfROCM,
                           sample_inputs_func=sample_inputs_func,
                           supports_sparse=supports_sparse)
        super().__init__(name, **base_kwargs, **kwargs)

        # Unary-ufunc specific metadata.
        self.ref = ref
        self.domain = domain
        self.handles_large_floats = handles_large_floats
        self.handles_extremals = handles_extremals
        self.handles_complex_extremals = handles_complex_extremals
        self.supports_complex_to_float = supports_complex_to_float

        # Epsilon to ensure grad and gradgrad checks don't test values
        # outside a function's domain.
        self._domain_eps = 1e-5
def sample_inputs_tensor_split(op_info, device, dtype, requires_grad):
return (SampleInput(make_tensor((S, S, S), device, dtype,
Loading ...