import collections
import collections.abc
import operator
from distutils.version import LooseVersion
from functools import reduce, wraps, partial
from itertools import product
from operator import mul, itemgetter
from typing import List, Tuple, Dict, Any

import torch
import numpy as np
from torch._six import inf, istuple
from torch.autograd import Variable

from torch.testing import \
    (make_non_contiguous, _dispatch_dtypes, floating_types, floating_types_and,
     floating_and_complex_types, floating_and_complex_types_and,
     all_types_and_complex_and, all_types_and, all_types_and_complex)
from torch.testing._internal.common_device_type import \
    (skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl,
     skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride)
from torch.testing._internal.common_cuda import CUDA11OrLater
from torch.testing._internal.common_utils import \
    (prod_single_zero, random_square_matrix_of_rank,
     random_symmetric_matrix, random_symmetric_psd_matrix,
     random_symmetric_pd_matrix, make_nonzero_det,
     random_fullrank_matrix_distinct_singular_value, set_rng_seed,
     TEST_WITH_ROCM, TEST_SCIPY, make_tensor,
     torch_to_numpy_dtype_dict, slowTest, TEST_WITH_ASAN)

# scipy is optional; only import it when the test environment has it.
if TEST_SCIPY:
    import scipy.special

class DecorateInfo(object):
    """Describes which test, or type of tests, should be wrapped in the given
       decorators when testing an operator. Any test that matches all provided
       arguments will be decorated. The decorators will only be applied if the
       active_if argument is True.

       Args:
           decorators: a single decorator or a sequence of decorators.
           cls_name: test class name to match, or None to match any class.
           test_name: test name to match, or None to match any test.
           device_type: device type to match, or None to match any device.
           dtypes: container of dtypes to match, or None to match any dtype.
           active_if: when falsy, this entry never matches.
    """

    __slots__ = ['decorators', 'cls_name', 'test_name', 'device_type', 'dtypes', 'active_if']

    def __init__(self, decorators, cls_name=None, test_name=None, *,
                 device_type=None, dtypes=None, active_if=True):
        # Normalize to a list so the decorators can always be iterated.
        self.decorators = list(decorators) if isinstance(decorators, collections.abc.Sequence) else [decorators]
        self.cls_name = cls_name
        self.test_name = test_name
        self.device_type = device_type
        self.dtypes = dtypes
        self.active_if = active_if

    def is_active(self, cls_name, test_name, device_type, dtype):
        """Returns whether this entry applies to the given test.

        Each filter set to None acts as a wildcard; every non-None filter
        must match, and active_if must be truthy.
        """
        return (
            self.active_if and
            (self.cls_name is None or self.cls_name == cls_name) and
            (self.test_name is None or self.test_name == test_name) and
            (self.device_type is None or self.device_type == device_type) and
            (self.dtypes is None or dtype in self.dtypes)
        )

class SkipInfo(DecorateInfo):
    """Describes which test, or type of tests, should be skipped when testing
       an operator. Any test that matches all provided arguments will be skipped.
       The skip will only be checked if the active_if argument is True.

       This is a DecorateInfo whose single decorator is skipIf(True, "Skipped!");
       the filter arguments have the same meaning as in DecorateInfo.
    """

    def __init__(self, cls_name=None, test_name=None, *,
                 device_type=None, dtypes=None, active_if=True):
        # Forward every filter, including active_if, to DecorateInfo.
        super().__init__(decorators=skipIf(True, "Skipped!"), cls_name=cls_name,
                         test_name=test_name, device_type=device_type, dtypes=dtypes,
                         active_if=active_if)

class SampleInput(object):
    """Represents sample inputs to a function.

    Attributes:
        input: tuple of positional inputs. A non-tuple value is wrapped into a
            1-tuple because test_ops.py expects a tuple.
        args: additional positional arguments (empty tuple by default).
        kwargs: keyword arguments (empty dict when not given).
        output_process_fn_grad: optional function that modifies the op's output
            so it is compatible with the input, or None.
    """

    # output_process_fn_grad is a function that modifies the output of op compatible with input
    __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad']

    def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None):
        # test_ops.py expects input to be a tuple
        self.input = input if isinstance(input, tuple) else (input,)
        self.args = args
        # Fresh dict per instance; avoids sharing a mutable default.
        self.kwargs = kwargs if kwargs is not None else {}
        self.output_process_fn_grad = output_process_fn_grad

    def __repr__(self):
        # Empty or absent fields become None and are filtered out below.
        arguments = [
            f'input[{len(self.input)}]',
            f'args={self.args}' if len(self.args) > 0 else None,
            f'kwargs={self.kwargs}' if len(self.kwargs) > 0 else None,
            (f'output_process_fn_grad={self.output_process_fn_grad}'
             if self.output_process_fn_grad is not None else None)]

        return f'SampleInput({", ".join(a for a in arguments if a is not None)})'

class AliasInfo(object):
    """Class holds alias information. For example, torch.abs ->
    torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_

    Attributes:
        name: the alias name as given (may be dotted, e.g. 'linalg.norm').
        op: the function variant, resolved under the torch namespace.
        method_variant: torch.Tensor.<alias_name>, or None if it does not exist.
        inplace_variant: torch.Tensor.<alias_name>_, or None if it does not exist.
    """

    def __init__(self, alias_name):
        self.name = alias_name
        # Resolves qualified names; raises AttributeError if the alias is missing.
        self.op = _getattr_qual(torch, alias_name)
        self.method_variant = getattr(torch.Tensor, alias_name, None)
        self.inplace_variant = getattr(torch.Tensor, alias_name + "_", None)

    def __call__(self, *args, **kwargs):
        """Calls the function variant of the alias."""
        return self.op(*args, **kwargs)

_NOTHING = object()  # Unique value to distinguish default from anything else

# Extension of getattr to support qualified names
# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm
def _getattr_qual(obj, name, default=_NOTHING):
    """Resolves a (possibly dotted) attribute path starting from ``obj``.

    Args:
        obj: root object the lookup starts from.
        name: attribute name; dots separate nested attributes.
        default: value returned when any component is missing. When omitted,
            the AttributeError propagates, like builtin getattr.

    Returns:
        The resolved attribute, or ``default`` if lookup failed and one was given.

    Raises:
        AttributeError: if a component is missing and no default was given.
    """
    try:
        for path in name.split('.'):
            obj = getattr(obj, path)
        return obj
    except AttributeError:
        if default is not _NOTHING:
            return default
        # No default supplied: re-raise instead of silently returning None.
        raise

# Classes and methods for the operator database
class OpInfo(object):
    """Operator information and helper functions for acquiring it."""

    def __init__(self,
                 name,  # the string name of the function
                 op=None,  # the function variant of the operation, populated as torch.<name> if None
                 dtypes=floating_types(),  # dtypes this function is expected to work with
                 dtypesIfCPU=None,  # dtypes this function is expected to work with on CPU
                 dtypesIfCUDA=None,  # dtypes this function is expected to work with on CUDA
                 dtypesIfROCM=None,  # dtypes this function is expected to work with on ROCM
                 default_test_dtypes=None,  # dtypes to test with by default. Gets intersected
                                            # with the dtypes supported on the tested device
                 test_inplace_grad=True,  # whether to gradcheck and gradgradcheck the inplace variant
                 test_complex_grad=True,  # whether to gradcheck and gradgradcheck for complex dtypes
                 skip_bfloat16_grad=False,  # whether to skip grad and gradgradcheck for bfloat16 dtype
                 assert_autodiffed=False,  # if a op's aten::node is expected to be symbolically autodiffed
                 autodiff_nonfusible_nodes=None,  # a list of strings with node names that are expected to be in a
                                                  # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'],
                                                  # default is populated to be ['aten::(name of Python operator)']
                 autodiff_fusible_nodes=None,  # a list of strings with node names that are expected to be in FusionGroups
                                               # inside of DifferentiableGraphs when this operation is autodiffed.
                                               # Ex: ['aten::add', 'aten::mm'], defaults to an empty list
                                               # Note: currently no ops use fusible nodes
                 output_func=lambda x: x,  # fn mapping output to part that should be gradcheck'ed
                 supports_tensor_out=True,  # whether the op supports the out kwarg, returning a Tensor
                 skips=tuple(),  # information about which tests to skip
                 decorators=None,  # decorators to apply to generated tests
                 safe_casts_outputs=False,  # whether op allows safe casting when writing to out arguments
                 sample_inputs_func=None,  # function to generate sample inputs
                 aten_name=None,  # name of the corresponding aten:: operator
                 aliases=None,  # iterable of aliases, e.g. ("absolute",) for torch.abs
                 variant_test_name='',  # additional string to include in the test name
                 supports_sparse=False,  # supported for sparse
                 check_batched_grad=True,  # check batched grad when doing gradcheck
                 check_batched_gradgrad=True,  # check batched grad grad when doing gradgradcheck
                 ):

        # Validates the dtypes are generated from the dispatch-related functions
        for dtype_list in (dtypes, dtypesIfCPU, dtypesIfCUDA, dtypesIfROCM):
            assert isinstance(dtype_list, (_dispatch_dtypes, type(None)))

        self.name = name
        self.aten_name = aten_name if aten_name is not None else name
        self.variant_test_name = variant_test_name

        # Device-specific dtype sets fall back to the generic set when unspecified.
        self.dtypes = set(dtypes)
        self.dtypesIfCPU = set(dtypesIfCPU) if dtypesIfCPU is not None else self.dtypes
        self.dtypesIfCUDA = set(dtypesIfCUDA) if dtypesIfCUDA is not None else self.dtypes
        self.dtypesIfROCM = set(dtypesIfROCM) if dtypesIfROCM is not None else self.dtypes
        self._default_test_dtypes = set(default_test_dtypes) if default_test_dtypes is not None else None

        # NOTE: if the op is unspecified it is assumed to be under the torch namespace
        self.op = op if op else _getattr_qual(torch, self.name)
        self.method_variant = getattr(torch.Tensor, name, None)
        inplace_name = name + "_"
        self.inplace_variant = getattr(torch.Tensor, inplace_name, None)
        self.operator_variant = getattr(operator, name, None)
        self.skip_bfloat16_grad = skip_bfloat16_grad

        self.test_inplace_grad = test_inplace_grad
        self.test_complex_grad = test_complex_grad
        self.supports_tensor_out = supports_tensor_out
        self.safe_casts_outputs = safe_casts_outputs

        self.skips = skips
        self.decorators = decorators
        self.output_func = output_func
        self.sample_inputs_func = sample_inputs_func

        self.assert_autodiffed = assert_autodiffed
        self.autodiff_fusible_nodes = autodiff_fusible_nodes if autodiff_fusible_nodes else []
        # Only fall back to the default node list when the caller gave none;
        # an explicit list must not be overwritten.
        if autodiff_nonfusible_nodes is None:
            self.autodiff_nonfusible_nodes = ['aten::' + self.name]
        else:
            self.autodiff_nonfusible_nodes = autodiff_nonfusible_nodes
        self.supports_sparse = supports_sparse
        self.check_batched_grad = check_batched_grad
        self.check_batched_gradgrad = check_batched_gradgrad

        self.aliases = ()  # type: ignore
        if aliases is not None:
            self.aliases = tuple(AliasInfo(a) for a in aliases)  # type: ignore

    def __call__(self, *args, **kwargs):
        """Calls the function variant of the operator."""
        return self.op(*args, **kwargs)

    def get_op(self):
        """Returns the function variant of the operator, torch.<op_name>."""
        return self.op

    def get_method(self):
        """Returns the method variant of the operator, torch.Tensor.<op_name>.
        Returns None if the operator has no method variant.
        """
        return self.method_variant

    def get_inplace(self):
        """Returns the inplace variant of the operator, torch.Tensor.<op_name>_.
        Returns None if the operator has no inplace variant.
        """
        return self.inplace_variant

    def get_operator_variant(self):
        """Returns operator variant of the operator, e.g. operator.neg
        Returns None if the operator has no operator variant.
        """
        return self.operator_variant

    def sample_inputs(self, device, dtype, requires_grad=False):
        """Returns an iterable of SampleInputs.

        These samples should be sufficient to test the function works correctly
        with autograd, TorchScript, etc. Delegates to the sample_inputs_func
        given at construction, which must therefore not be None.
        """
        return self.sample_inputs_func(self, device, dtype, requires_grad)

    # Returns True if the test should be skipped and False otherwise
    def should_skip(self, cls_name, test_name, device_type, dtype):
        return any(si.is_active(cls_name, test_name, device_type, dtype)
                   for si in self.skips)

    def supported_dtypes(self, device_type):
        """Returns the set of dtypes supported on ``device_type``.

        'cpu' maps to dtypesIfCPU; 'cuda' maps to dtypesIfROCM when
        TEST_WITH_ROCM is set and to dtypesIfCUDA otherwise. Every other
        device type falls back to the generic dtypes.
        """
        if device_type == 'cpu':
            return self.dtypesIfCPU
        if device_type == 'cuda':
            return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA
        # Previously unreachable, so other devices got None back.
        return self.dtypes

    def supports_dtype(self, dtype, device_type):
        """Returns True if ``dtype`` is supported on ``device_type``."""
        return dtype in self.supported_dtypes(device_type)

    def default_test_dtypes(self, device_type):
        """Returns the default dtypes used to test this operator on the device.

        Equal to the operator's default_test_dtypes filtered to remove dtypes
        not supported by the device.
        """
        supported = self.supported_dtypes(device_type)
        return (supported if self._default_test_dtypes is None
                else supported.intersection(self._default_test_dtypes))

# Dimension sizes used to build sample-input tensors, e.g. (L,) or (S, S, S).
# Presumably Large / Medium / Small, given 20 > 10 > 5.
L, M, S = 20, 10, 5

def sample_inputs_unary(op_info, device, dtype, requires_grad):
    """Returns sample inputs for a unary ufunc: a 1-D tensor of length L and a 0-D tensor.

    The op's [low, high) domain is shrunk by op_info._domain_eps on each bounded
    side, and the result is passed as make_tensor's low/high bounds. This keeps
    grad/gradgrad checks from probing values outside the function's domain.

    Args:
        op_info: operator info exposing ``domain`` and ``_domain_eps``.
        device: device to create the tensors on.
        dtype: dtype of the created tensors.
        requires_grad: forwarded to make_tensor for every created tensor.

    Returns:
        A tuple of two SampleInput objects.
    """
    low, high = op_info.domain
    # None means the side is unbounded and is left untouched.
    low = low if low is None else low + op_info._domain_eps
    high = high if high is None else high - op_info._domain_eps

    return (SampleInput(make_tensor((L,), device, dtype,
                                    low=low, high=high,
                                    requires_grad=requires_grad)),
            SampleInput(make_tensor((), device, dtype,
                                    low=low, high=high,
                                    requires_grad=requires_grad)))

# Metadata class for unary "universal functions (ufuncs)" that accept a single
# tensor; the properties these ufuncs share are listed in UnaryUfuncInfo's docstring.
class UnaryUfuncInfo(OpInfo):
    """Operator information for 'universal unary functions (unary ufuncs).'
    These are functions of a single tensor with common properties like:
      - they are elementwise functions
      - the input shape is the output shape
      - they typically have method and inplace variants
      - they typically support the out kwarg
      - they typically have NumPy or SciPy references
    See NumPy's universal function documentation
    (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
    about the concept of ufuncs.

    Any keyword argument not listed in the signature is forwarded to OpInfo.
    """

    def __init__(self,
                 name,  # the string name of the function
                 ref,  # a reference function
                 domain=(None, None),  # the [low, high) domain of the function
                 handles_large_floats=True,  # whether the op correctly handles large float values (like 1e20)
                 handles_extremals=True,  # whether the op correctly handles extremal values (like inf)
                 handles_complex_extremals=True,  # whether the op correctly handles complex extremals (like inf -infj)
                 supports_complex_to_float=False,  # op supports casting from complex input to real output safely eg. angle
                 sample_inputs_func=sample_inputs_unary,  # function to generate sample inputs
                 **kwargs):  # remaining OpInfo arguments (dtypes, skips, ...)
        super(UnaryUfuncInfo, self).__init__(name,
                                             sample_inputs_func=sample_inputs_func,
                                             **kwargs)
        self.ref = ref
        self.domain = domain
        self.handles_large_floats = handles_large_floats
        self.handles_extremals = handles_extremals
        self.handles_complex_extremals = handles_complex_extremals
        self.supports_complex_to_float = supports_complex_to_float

        # Epsilon to ensure grad and gradgrad checks don't test values
        #   outside a function's domain.
        self._domain_eps = 1e-5

def sample_inputs_tensor_split(op_info, device, dtype, requires_grad):
    return (SampleInput(make_tensor((S, S, S), device, dtype,
