
import unittest
from functools import partial
from itertools import product
from typing import Callable, List, Tuple

import numpy

import torch
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_utils import TEST_SCIPY
from torch.testing._internal.opinfo.core import (
    DecorateInfo,
    ErrorInput,
    OpInfo,
    SampleInput,
)

if TEST_SCIPY:
    import scipy.signal


def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
    r"""Base function used to create sample inputs for windows.

    Pass additional required positional arguments via *args and additional
    keyword arguments via **kwargs.
    """

    # Test window sizes from 0 up to 5 samples, both symmetric and asymmetric.
    for size, sym in product(range(6), (True, False)):
        yield SampleInput(
            size,
            *args,
            sym=sym,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            **kwargs,
        )
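

# Illustrative sketch (hypothetical helper, not used by the test suite): roughly
# how the OpInfo machinery consumes the generator above. torch.signal.windows.cosine
# stands in as a representative window op; op_info is unused, hence None.
def _example_sample_inputs_window_usage():
    for sample in sample_inputs_window(None, "cpu", torch.float32, False):
        # sample.input holds the window length; sym, device, dtype, and
        # requires_grad travel in sample.kwargs.
        torch.signal.windows.cosine(sample.input, *sample.args, **sample.kwargs)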


def reference_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
    r"""Reference inputs function to use for windows which have a common signature, i.e.,
    window size and sym only.

    Implement other special functions for windows that have a specific signature.
    See exponential and gaussian windows for instance.
    """
    yield from sample_inputs_window(
        op_info, device, dtype, requires_grad, *args, **kwargs
    )

    cases = (8, 16, 32, 64, 128, 256)

    for size in cases:
        yield SampleInput(size, sym=False)
        yield SampleInput(size, sym=True)


def reference_inputs_exponential_window(
    op_info, device, dtype, requires_grad, **kwargs
):
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"center": 4, "tau": 0.5}),
        (16, {"center": 8, "tau": 2.5}),
        (32, {"center": 16, "tau": 43.5}),
        (64, {"center": 20, "tau": 3.7}),
        (128, {"center": 62, "tau": 99}),
        (256, {"tau": 10}),
    )

    for size, kw in cases:
        yield SampleInput(size, sym=False, **kw)
        kw["center"] = None
        yield SampleInput(size, sym=True, **kw)


def reference_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"std": 0.1}),
        (16, {"std": 1.2}),
        (32, {"std": 2.1}),
        (64, {"std": 3.9}),
        (128, {"std": 4.5}),
        (256, {"std": 10}),
    )

    for size, kw in cases:
        yield SampleInput(size, sym=False, **kw)
        yield SampleInput(size, sym=True, **kw)


def reference_inputs_kaiser_window(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"beta": 2}),
        (16, {"beta": 12}),
        (32, {"beta": 30}),
        (64, {"beta": 35}),
        (128, {"beta": 41.2}),
        (256, {"beta": 100}),
    )

    for size, kw in cases:
        yield SampleInput(size, sym=False, **kw)
        yield SampleInput(size, sym=True, **kw)


def reference_inputs_general_cosine_window(
    op_info, device, dtype, requires_grad, **kwargs
):
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"a": [0.5, 0.5]}),
        (16, {"a": [0.46, 0.54]}),
        (32, {"a": [0.46, 0.23, 0.31]}),
        (64, {"a": [0.5]}),
        (128, {"a": [0.1, 0.8, 0.05, 0.05]}),
        (256, {"a": [0.2, 0.2, 0.2, 0.2, 0.2]}),
    )

    for size, kw in cases:
        yield SampleInput(size, sym=False, **kw)
        yield SampleInput(size, sym=True, **kw)


def reference_inputs_general_hamming_window(
    op_info, device, dtype, requires_grad, **kwargs
):
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"alpha": 0.54}),
        (16, {"alpha": 0.5}),
        (32, {"alpha": 0.23}),
        (64, {"alpha": 0.8}),
        (128, {"alpha": 0.9}),
        (256, {"alpha": 0.05}),
    )

    for size, kw in cases:
        yield SampleInput(size, sym=False, **kw)
        yield SampleInput(size, sym=True, **kw)


def error_inputs_window(op_info, device, *args, **kwargs):
    # Tests for windows that have a negative size
    yield ErrorInput(
        SampleInput(-1, *args, dtype=torch.float32, device=device, **kwargs),
        error_type=ValueError,
        error_regex="requires non-negative window length, got M=-1",
    )

    # Tests for window tensors that are not torch.strided, for instance, torch.sparse_coo.
    yield ErrorInput(
        SampleInput(
            3,
            *args,
            layout=torch.sparse_coo,
            device=device,
            dtype=torch.float32,
            **kwargs,
        ),
        error_type=ValueError,
        error_regex="is implemented for strided tensors only, got: torch.sparse_coo",
    )

    # Tests for window tensors that are not floating point dtypes, for instance, torch.long.
    yield ErrorInput(
        SampleInput(3, *args, dtype=torch.long, device=device, **kwargs),
        error_type=ValueError,
        error_regex="expects float32 or float64 dtypes, got: torch.int64",
    )

    # Tests for window tensors that are bfloat16
    yield ErrorInput(
        SampleInput(3, *args, dtype=torch.bfloat16, device=device, **kwargs),
        error_type=ValueError,
        error_regex="expects float32 or float64 dtypes, got: torch.bfloat16",
    )

    # Tests for window tensors that are float16
    yield ErrorInput(
        SampleInput(3, *args, dtype=torch.float16, device=device, **kwargs),
        error_type=ValueError,
        error_regex="expects float32 or float64 dtypes, got: torch.float16",
    )
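

# Illustrative sketch (hypothetical helper): roughly what the test harness
# asserts for each ErrorInput, with torch.signal.windows.cosine standing in
# as a representative window op.
def _example_check_window_error_inputs():
    import re

    for ei in error_inputs_window(None, "cpu"):
        si = ei.sample_input
        try:
            torch.signal.windows.cosine(si.input, *si.args, **si.kwargs)
        except ei.error_type as e:
            assert re.search(ei.error_regex, str(e))
        else:
            raise AssertionError("expected the op to raise")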


def error_inputs_exponential_window(op_info, device, **kwargs):
    # Yield common error inputs
    yield from error_inputs_window(op_info, device, **kwargs)

    # Tests for negative decay values.
    yield ErrorInput(
        SampleInput(3, tau=-1, dtype=torch.float32, device=device, **kwargs),
        error_type=ValueError,
        error_regex="Tau must be positive, got: -1 instead.",
    )

    # Tests for symmetric windows and a given center value.
    yield ErrorInput(
        SampleInput(3, center=1, sym=True, dtype=torch.float32, device=device),
        error_type=ValueError,
        error_regex="Center must be None for symmetric windows",
    )


def error_inputs_gaussian_window(op_info, device, **kwargs):
    # Yield common error inputs
    yield from error_inputs_window(op_info, device, std=0.5, **kwargs)

    # Tests for negative standard deviations
    yield ErrorInput(
        SampleInput(3, std=-1, dtype=torch.float32, device=device, **kwargs),
        error_type=ValueError,
        error_regex="Standard deviation must be positive, got: -1 instead.",
    )


def error_inputs_kaiser_window(op_info, device, **kwargs):
    # Yield common error inputs
    yield from error_inputs_window(op_info, device, beta=12, **kwargs)

    # Tests for negative beta
    yield ErrorInput(
        SampleInput(3, beta=-1, dtype=torch.float32, device=device, **kwargs),
        error_type=ValueError,
        error_regex="beta must be non-negative, got: -1 instead.",
    )


def error_inputs_general_cosine_window(op_info, device, **kwargs):
    # Yield common error inputs
    yield from error_inputs_window(op_info, device, a=[0.54, 0.46], **kwargs)

    # Tests for coefficients that are not a list/tuple
    yield ErrorInput(
        SampleInput(3, a=None, dtype=torch.float32, device=device, **kwargs),
        error_type=TypeError,
        error_regex="Coefficients must be a list/tuple",
    )

    # Tests for empty coefficients
    yield ErrorInput(
        SampleInput(3, a=[], dtype=torch.float32, device=device, **kwargs),
        error_type=ValueError,
        error_regex="Coefficients cannot be empty",
    )


def reference_signal_window(fn: Callable):
    r"""Wrapper for scipy signal window references.

    Discards keyword arguments for window reference functions that don't have a matching signature with
    torch, e.g., gaussian window.
    """

    def _fn(
        *args,
        dtype=numpy.float64,
        device=None,
        layout=torch.strided,
        requires_grad=False,
        **kwargs,
    ):
        r"""The unused arguments are defined to disregard those values"""
        return fn(*args, **kwargs).astype(dtype)

    return _fn
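

# Illustrative sketch (hypothetical helper): wrapping a scipy window so the
# torch-only keyword arguments are dropped before calling into scipy.
def _example_reference_signal_window_usage():
    assert TEST_SCIPY  # scipy.signal is only imported when SciPy is available
    ref = reference_signal_window(scipy.signal.windows.gaussian)
    # device is accepted and ignored; std is forwarded to scipy unchanged.
    return ref(10, std=1.2, dtype=numpy.float32, device="cpu")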


def make_signal_windows_opinfo(
    name: str,
    ref: Callable,
    sample_inputs_func: Callable,
    reference_inputs_func: Callable,
    error_inputs_func: Callable,
    *,
    skips: Tuple[DecorateInfo, ...] = (),
):
    r"""Helper function to create OpInfo objects related to different windows."""
    return OpInfo(
        name=name,
        ref=ref if TEST_SCIPY else None,
        dtypes=floating_types(),
        dtypesIfCUDA=floating_types(),
        sample_inputs_func=sample_inputs_func,
        reference_inputs_func=reference_inputs_func,
        error_inputs_func=error_inputs_func,
        supports_out=False,
        supports_autograd=False,
        skips=(
            # TODO: same as this?
            # https://github.com/pytorch/pytorch/issues/81774
            # also see: arange, new_full
            # fails to match any schemas despite working in the interpreter
            DecorateInfo(
                unittest.expectedFailure,
                "TestOperatorSignatures",
                "test_get_torch_func_signature_exhaustive",
            ),
            # fails to match any schemas despite working in the interpreter
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            # skip these tests since we have a non-tensor input
            DecorateInfo(
                unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_variant_consistency_eager",
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_conj_view"),
            DecorateInfo(
                unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view"
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestVmapOperatorsOpInfo",
                "test_vmap_exhaustive",
            ),
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestVmapOperatorsOpInfo",
                "test_op_has_batch_rule",
            ),
            DecorateInfo(
                unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
                "TestCommon",
                "test_numpy_ref_mps",
            ),
            *skips,
        ),
    )
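

# Illustrative sketch (hypothetical helper; the actual op database entries
# follow later in the full module): how an OpInfo for one window, e.g.
# signal.windows.cosine, could be assembled with the helper above.
def _example_make_cosine_window_opinfo():
    return make_signal_windows_opinfo(
        name="signal.windows.cosine",
        # Guard the scipy lookup at the call site: scipy.signal is only
        # imported when TEST_SCIPY is true.
        ref=reference_signal_window(scipy.signal.windows.cosine)
        if TEST_SCIPY
        else None,
        sample_inputs_func=sample_inputs_window,
        reference_inputs_func=reference_inputs_window,
        error_inputs_func=error_inputs_window,
    )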