Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
einops / tests / test_einsum.py
Size: Mime:
import string
from typing import Any, Callable

import numpy as np
import pytest

from einops.einops import EinopsError, _compactify_pattern_for_einsum, einsum
from einops.tests import collect_test_backends


class Arguments:
    def __init__(self, *args: Any, **kargs: Any):
        self.args = args
        self.kwargs = kargs

    def __call__(self, function: Callable):
        return function(*self.args, **self.kwargs)


test_layer_cases = [
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape=None, c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape="c_out", c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    (
        Arguments("b c_in h w -> w c_in h b", "", bias_shape=None, c_in=12),
        (2, 12, 3, 4),
        (4, 12, 3, 2),
    ),
    (
        Arguments("b c_in h w -> b c_out", "c_in h w c_out", bias_shape=None, c_in=12, h=3, w=4, c_out=5),
        (2, 12, 3, 4),
        (2, 5),
    ),
    (
        Arguments("b t head c_in -> b t head c_out", "head c_in c_out", bias_shape=None, head=4, c_in=5, c_out=6),
        (2, 3, 4, 5),
        (2, 3, 4, 6),
    ),
]


# Each of the form:
# (Arguments, true_einsum_pattern, in_shapes, out_shape)
test_functional_cases = [
    (
        # Basic:
        "b c h w, b w -> b h",
        "abcd,ad->ac",
        ((2, 3, 4, 5), (2, 5)),
        (2, 4),
    ),
    (
        # Three tensors:
        "b c h w, b w, b c -> b h",
        "abcd,ad,ab->ac",
        ((2, 3, 40, 5), (2, 5), (2, 3)),
        (2, 40),
    ),
    (
        # Ellipsis, and full names:
        "... one two three, three four five -> ... two five",
        "...abc,cde->...be",
        ((32, 5, 2, 3, 4), (4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # Ellipsis at the end:
        "one two three ..., three four five -> two five ...",
        "abc...,cde->be...",
        ((2, 3, 4, 32, 5), (4, 5, 6)),
        (3, 6, 32, 5),
    ),
    (
        # Ellipsis on multiple tensors:
        "... one two three, ... three four five -> ... two five",
        "...abc,...cde->...be",
        ((32, 5, 2, 3, 4), (32, 5, 4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # One tensor, and underscores:
        "first_tensor second_tensor -> first_tensor",
        "ab->a",
        ((5, 4),),
        (5,),
    ),
    (
        # Trace (repeated index)
        "i i -> ",
        "aa->",
        ((5, 5),),
        (),
    ),
    (
        # Too many spaces in string:
        " one  two  ,  three four->two  four  ",
        "ab,cd->bd",
        ((2, 3), (4, 5)),
        (3, 5),
    ),
    # The following tests were inspired by numpy's einsum tests
    # https://github.com/numpy/numpy/blob/v1.23.0/numpy/core/tests/test_einsum.py
    (
        # Trace with other indices
        "i middle i -> middle",
        "aba->b",
        ((5, 10, 5),),
        (10,),
    ),
    (
        # Ellipsis in the middle:
        "i ... i -> ...",
        "a...a->...",
        ((5, 3, 2, 1, 4, 5),),
        (3, 2, 1, 4),
    ),
    (
        # Product of first and last axes:
        "i ... i -> i ...",
        "a...a->a...",
        ((5, 3, 2, 1, 4, 5),),
        (5, 3, 2, 1, 4),
    ),
    (
        # Triple diagonal
        "one one one -> one",
        "aaa->a",
        ((5, 5, 5),),
        (5,),
    ),
    (
        # Axis swap:
        "i j k -> j i k",
        "abc->bac",
        ((1, 2, 3),),
        (2, 1, 3),
    ),
    (
        # Identity:
        "... -> ...",
        "...->...",
        ((5, 4, 3, 2, 1),),
        (5, 4, 3, 2, 1),
    ),
    (
        # Elementwise product of three tensors
        "..., ..., ... -> ...",
        "...,...,...->...",
        ((3, 2), (3, 2), (3, 2)),
        (3, 2),
    ),
    (
        # Basic summation:
        "index ->",
        "a->",
        ((10,)),
        (()),
    ),
]


def test_layer():
    for backend in collect_test_backends(layers=True, symbolic=False):
        rng = np.random.default_rng()
        if backend.framework_name in ["tensorflow", "torch", "oneflow", "paddle"]:
            layer_type = backend.layers().EinMix
            for args, in_shape, out_shape in test_layer_cases:
                layer = args(layer_type)
                print("Running", layer.einsum_pattern, "for", backend.framework_name)
                input = rng.uniform(size=in_shape).astype("float32")
                input_framework = backend.from_numpy(input)
                output_framework = layer(input_framework)
                output = backend.to_numpy(output_framework)
                assert output.shape == out_shape


valid_backends_functional = [
    "tensorflow",
    "torch",
    "jax",
    "numpy",
    "oneflow",
    "cupy",
    "tensorflow.keras",
    "paddle",
    "pytensor",
]


def test_functional():
    # Functional tests:
    backends = filter(lambda x: x.framework_name in valid_backends_functional, collect_test_backends())
    for backend in backends:
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for {backend.framework_name}")

            # Create pattern:
            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern

            # Generate example data:
            rstate = np.random.RandomState(0)
            in_arrays = [rstate.uniform(size=shape).astype("float32") for shape in in_shapes]
            in_arrays_framework = [backend.from_numpy(array) for array in in_arrays]

            # Loop over whether we call it manually with the backend,
            # or whether we use `einops.einsum`.
            for do_manual_call in [True, False]:
                # Actually run einsum:
                if do_manual_call:
                    out_array = backend.einsum(predicted_pattern, *in_arrays_framework)
                else:
                    out_array = einsum(*in_arrays_framework, einops_pattern)

                # Check shape:
                if tuple(out_array.shape) != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {out_array.shape}")

                # Check values:
                true_out_array = np.einsum(true_pattern, *in_arrays)
                predicted_out_array = backend.to_numpy(out_array)
                np.testing.assert_array_almost_equal(predicted_out_array, true_out_array, decimal=5)


def test_functional_symbolic():
    backends = filter(
        lambda x: x.framework_name in valid_backends_functional, collect_test_backends(symbolic=True, layers=False)
    )
    for backend in backends:
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for symbolic {backend.framework_name}")
            # Create pattern:
            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern

            rstate = np.random.RandomState(0)
            in_syms = [backend.create_symbol(in_shape) for in_shape in in_shapes]
            in_data = [rstate.uniform(size=in_shape).astype("float32") for in_shape in in_shapes]

            expected_out_data = np.einsum(true_pattern, *in_data)

            for do_manual_call in [True, False]:
                if do_manual_call:
                    predicted_out_symbol = backend.einsum(predicted_pattern, *in_syms)
                else:
                    predicted_out_symbol = einsum(*in_syms, einops_pattern)

                predicted_out_data = backend.eval_symbol(
                    predicted_out_symbol,
                    list(zip(in_syms, in_data)),
                )
                if predicted_out_data.shape != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
                np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)


def test_functional_errors():
    # Specific backend does not matter, as errors are raised
    # during the pattern creation.

    rstate = np.random.RandomState(0)

    def create_tensor(*shape):
        return rstate.uniform(size=shape).astype("float32")

    # raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Singleton"):
        einsum(
            create_tensor(5, 1),
            "i () -> i",
        )

    # raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(5, 1),
            "a b -> (a b)",
        )

    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(10, 1),
            "(a b) -> a b",
        )

    # raise RuntimeError("Encountered empty axis name in einsum.")
    # raise RuntimeError("Axis name in einsum must be a string.")
    # ^ Not tested, these are just a failsafe in case an unexpected error occurs.

    # raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Anonymous axes"):
        einsum(
            create_tensor(5, 1),
            "i 2 -> i",
        )

    # ParsedExpression error:
    with pytest.raises(EinopsError, match="^Invalid axis identifier"):
        einsum(
            create_tensor(5, 1),
            "i 2j -> i",
        )

    # raise ValueError("Einsum pattern must contain '->'.")
    with pytest.raises(ValueError, match="^Einsum pattern"):
        einsum(
            create_tensor(5, 3, 2),
            "i j k",
        )

    # raise RuntimeError("Too many axes in einsum.")
    with pytest.raises(RuntimeError, match="^Too many axes"):
        einsum(
            create_tensor(1),
            " ".join(string.ascii_letters) + " extra ->",
        )

    # raise RuntimeError("Unknown axis on right side of einsum.")
    with pytest.raises(RuntimeError, match="^Unknown axis"):
        einsum(
            create_tensor(5, 1),
            "i j -> k",
        )

    # raise ValueError(
    # "The last argument passed to `einops.einsum` must be a string,"
    # " representing the einsum pattern."
    # )
    with pytest.raises(ValueError, match="^The last argument"):
        einsum(
            "i j k -> i",
            create_tensor(5, 4, 3),
        )

    # raise ValueError(
    #     "`einops.einsum` takes at minimum two arguments: the tensors,"
    #     " followed by the pattern."
    # )
    with pytest.raises(ValueError, match="^`einops.einsum` takes"):
        einsum(
            "i j k -> i",
        )
    with pytest.raises(ValueError, match="^`einops.einsum` takes"):
        einsum(
            create_tensor(5, 1),
        )

    # TODO: Include check for giving normal einsum pattern rather than einops.