Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aaronreidsmith / numpy   python

Repository URL to install this package:

Version: 1.17.4 

/ core / tests / test_einsum.py

from __future__ import division, absolute_import, print_function

import itertools

import numpy as np
from numpy.testing import (
    assert_, assert_equal, assert_array_equal, assert_almost_equal,
    assert_raises, suppress_warnings, assert_raises_regex, assert_allclose
    )

# Shared setup for the optimize-einsum tests: each index label in ``chars``
# is paired with the dimension size at the same position in ``sizes``.
chars = 'abcdefghij'
sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3])
# ``sizes`` has one more entry than ``chars``; zip stops at the shorter
# sequence, so the final size is never assigned to a label.
global_size_dict = {label: size for label, size in zip(chars, sizes)}


class TestEinsum(object):
    def test_einsum_errors(self):
        for do_opt in [True, False]:
            # Need enough arguments
            assert_raises(ValueError, np.einsum, optimize=do_opt)
            assert_raises(ValueError, np.einsum, "", optimize=do_opt)

            # subscripts must be a string
            assert_raises(TypeError, np.einsum, 0, 0, optimize=do_opt)

            # out parameter must be an array
            assert_raises(TypeError, np.einsum, "", 0, out='test',
                          optimize=do_opt)

            # order parameter must be a valid order
            assert_raises(TypeError, np.einsum, "", 0, order='W',
                          optimize=do_opt)

            # casting parameter must be a valid casting
            assert_raises(ValueError, np.einsum, "", 0, casting='blah',
                          optimize=do_opt)

            # dtype parameter must be a valid dtype
            assert_raises(TypeError, np.einsum, "", 0, dtype='bad_data_type',
                          optimize=do_opt)

            # other keyword arguments are rejected
            assert_raises(TypeError, np.einsum, "", 0, bad_arg=0,
                          optimize=do_opt)

            # issue 4528 revealed a segfault with this call
            assert_raises(TypeError, np.einsum, *(None,)*63, optimize=do_opt)

            # number of operands must match count in subscripts string
            assert_raises(ValueError, np.einsum, "", 0, 0, optimize=do_opt)
            assert_raises(ValueError, np.einsum, ",", 0, [0], [0],
                          optimize=do_opt)
            assert_raises(ValueError, np.einsum, ",", [0], optimize=do_opt)

            # can't have more subscripts than dimensions in the operand
            assert_raises(ValueError, np.einsum, "i", 0, optimize=do_opt)
            assert_raises(ValueError, np.einsum, "ij", [0, 0], optimize=do_opt)
            assert_raises(ValueError, np.einsum, "...i", 0, optimize=do_opt)
            assert_raises(ValueError, np.einsum, "i...j", [0, 0], optimize=do_opt)
            assert_raises(ValueError, np.einsum, "i...", 0, optimize=do_opt)
            assert_raises(ValueError, np.einsum, "ij...", [0, 0], optimize=do_opt)

            # invalid ellipsis
            assert_raises(ValueError, np.einsum, "i..", [0, 0], optimize=do_opt)
            assert_raises(ValueError, np.einsum, ".i...", [0, 0], optimize=do_opt)
            assert_raises(ValueError, np.einsum, "j->..j", [0, 0], optimize=do_opt)
            assert_raises(ValueError, np.einsum, "j->.j...", [0, 0], optimize=do_opt)

            # invalid subscript character
            assert_raises(ValueError, np.einsum, "i%...", [0, 0], optimize=do_opt)
            assert_raises(ValueError, np.einsum, "...j$", [0, 0], optimize=do_opt)
            assert_raises(ValueError, np.einsum, "i->&", [0, 0], optimize=do_opt)

            # output subscripts must appear in input
            assert_raises(ValueError, np.einsum, "i->ij", [0, 0], optimize=do_opt)

            # output subscripts may only be specified once
            assert_raises(ValueError, np.einsum, "ij->jij", [[0, 0], [0, 0]],
                          optimize=do_opt)

            # dimensions much match when being collapsed
            assert_raises(ValueError, np.einsum, "ii",
                          np.arange(6).reshape(2, 3), optimize=do_opt)
            assert_raises(ValueError, np.einsum, "ii->i",
                          np.arange(6).reshape(2, 3), optimize=do_opt)

            # broadcasting to new dimensions must be enabled explicitly
            assert_raises(ValueError, np.einsum, "i", np.arange(6).reshape(2, 3),
                          optimize=do_opt)
            assert_raises(ValueError, np.einsum, "i->i", [[0, 1], [0, 1]],
                          out=np.arange(4).reshape(2, 2), optimize=do_opt)
            with assert_raises_regex(ValueError, "'b'"):
                # gh-11221 - 'c' erroneously appeared in the error message
                a = np.ones((3, 3, 4, 5, 6))
                b = np.ones((3, 4, 5))
                np.einsum('aabcb,abc', a, b)

    def test_einsum_views(self):
        # pass-through
        for do_opt in [True, False]:
            a = np.arange(6)
            a.shape = (2, 3)

            b = np.einsum("...", a, optimize=do_opt)
            assert_(b.base is a)

            b = np.einsum(a, [Ellipsis], optimize=do_opt)
            assert_(b.base is a)

            b = np.einsum("ij", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, a)

            b = np.einsum(a, [0, 1], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, a)

            # output is writeable whenever input is writeable
            b = np.einsum("...", a, optimize=do_opt)
            assert_(b.flags['WRITEABLE'])
            a.flags['WRITEABLE'] = False
            b = np.einsum("...", a, optimize=do_opt)
            assert_(not b.flags['WRITEABLE'])

            # transpose
            a = np.arange(6)
            a.shape = (2, 3)

            b = np.einsum("ji", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, a.T)

            b = np.einsum(a, [1, 0], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, a.T)

            # diagonal
            a = np.arange(9)
            a.shape = (3, 3)

            b = np.einsum("ii->i", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[i, i] for i in range(3)])

            b = np.einsum(a, [0, 0], [0], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[i, i] for i in range(3)])

            # diagonal with various ways of broadcasting an additional dimension
            a = np.arange(27)
            a.shape = (3, 3, 3)

            b = np.einsum("...ii->...i", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [[x[i, i] for i in range(3)] for x in a])

            b = np.einsum(a, [Ellipsis, 0, 0], [Ellipsis, 0], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [[x[i, i] for i in range(3)] for x in a])

            b = np.einsum("ii...->...i", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [[x[i, i] for i in range(3)]
                             for x in a.transpose(2, 0, 1)])

            b = np.einsum(a, [0, 0, Ellipsis], [Ellipsis, 0], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [[x[i, i] for i in range(3)]
                             for x in a.transpose(2, 0, 1)])

            b = np.einsum("...ii->i...", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[:, i, i] for i in range(3)])

            b = np.einsum(a, [Ellipsis, 0, 0], [0, Ellipsis], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[:, i, i] for i in range(3)])

            b = np.einsum("jii->ij", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[:, i, i] for i in range(3)])

            b = np.einsum(a, [1, 0, 0], [0, 1], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[:, i, i] for i in range(3)])

            b = np.einsum("ii...->i...", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])

            b = np.einsum(a, [0, 0, Ellipsis], [0, Ellipsis], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a.transpose(2, 0, 1)[:, i, i] for i in range(3)])

            b = np.einsum("i...i->i...", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])

            b = np.einsum(a, [0, Ellipsis, 0], [0, Ellipsis], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a.transpose(1, 0, 2)[:, i, i] for i in range(3)])

            b = np.einsum("i...i->...i", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [[x[i, i] for i in range(3)]
                             for x in a.transpose(1, 0, 2)])

            b = np.einsum(a, [0, Ellipsis, 0], [Ellipsis, 0], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [[x[i, i] for i in range(3)]
                             for x in a.transpose(1, 0, 2)])

            # triple diagonal
            a = np.arange(27)
            a.shape = (3, 3, 3)

            b = np.einsum("iii->i", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[i, i, i] for i in range(3)])

            b = np.einsum(a, [0, 0, 0], [0], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, [a[i, i, i] for i in range(3)])

            # swap axes
            a = np.arange(24)
            a.shape = (2, 3, 4)

            b = np.einsum("ijk->jik", a, optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, a.swapaxes(0, 1))

            b = np.einsum(a, [0, 1, 2], [1, 0, 2], optimize=do_opt)
            assert_(b.base is a)
            assert_equal(b, a.swapaxes(0, 1))

    def check_einsum_sums(self, dtype, do_opt=False):
        # Check various sums.  Does many sizes to exercise unrolled loops.

        # sum(a, axis=-1)
        for n in range(1, 17):
            a = np.arange(n, dtype=dtype)
            assert_equal(np.einsum("i->", a, optimize=do_opt),
                         np.sum(a, axis=-1).astype(dtype))
            assert_equal(np.einsum(a, [0], [], optimize=do_opt),
                         np.sum(a, axis=-1).astype(dtype))

        for n in range(1, 17):
            a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
            assert_equal(np.einsum("...i->...", a, optimize=do_opt),
                         np.sum(a, axis=-1).astype(dtype))
            assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt),
                         np.sum(a, axis=-1).astype(dtype))

        # sum(a, axis=0)
        for n in range(1, 17):
            a = np.arange(2*n, dtype=dtype).reshape(2, n)
            assert_equal(np.einsum("i...->...", a, optimize=do_opt),
                         np.sum(a, axis=0).astype(dtype))
            assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
                         np.sum(a, axis=0).astype(dtype))

        for n in range(1, 17):
            a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
            assert_equal(np.einsum("i...->...", a, optimize=do_opt),
                         np.sum(a, axis=0).astype(dtype))
            assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
                         np.sum(a, axis=0).astype(dtype))

        # trace(a)
        for n in range(1, 17):
            a = np.arange(n*n, dtype=dtype).reshape(n, n)
            assert_equal(np.einsum("ii", a, optimize=do_opt),
                         np.trace(a).astype(dtype))
            assert_equal(np.einsum(a, [0, 0], optimize=do_opt),
                         np.trace(a).astype(dtype))

        # multiply(a, b)
        assert_equal(np.einsum("..., ...", 3, 4), 12)  # scalar case
        for n in range(1, 17):
            a = np.arange(3 * n, dtype=dtype).reshape(3, n)
            b = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
            assert_equal(np.einsum("..., ...", a, b, optimize=do_opt),
                         np.multiply(a, b))
            assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis], optimize=do_opt),
                         np.multiply(a, b))

        # inner(a,b)
        for n in range(1, 17):
            a = np.arange(2 * 3 * n, dtype=dtype).reshape(2, 3, n)
            b = np.arange(n, dtype=dtype)
            assert_equal(np.einsum("...i, ...i", a, b, optimize=do_opt), np.inner(a, b))
            assert_equal(np.einsum(a, [Ellipsis, 0], b, [Ellipsis, 0], optimize=do_opt),
                         np.inner(a, b))

        for n in range(1, 11):
            a = np.arange(n * 3 * 2, dtype=dtype).reshape(n, 3, 2)
            b = np.arange(n, dtype=dtype)
            assert_equal(np.einsum("i..., i...", a, b, optimize=do_opt),
                         np.inner(a.T, b.T).T)
            assert_equal(np.einsum(a, [0, Ellipsis], b, [0, Ellipsis], optimize=do_opt),
                         np.inner(a.T, b.T).T)

        # outer(a,b)
        for n in range(1, 17):
            a = np.arange(3, dtype=dtype)+1
            b = np.arange(n, dtype=dtype)+1
            assert_equal(np.einsum("i,j", a, b, optimize=do_opt),
                         np.outer(a, b))
            assert_equal(np.einsum(a, [0], b, [1], optimize=do_opt),
                         np.outer(a, b))

        # Suppress the complex warnings for the 'as f8' tests
        with suppress_warnings() as sup:
            sup.filter(np.ComplexWarning)

            # matvec(a,b) / a.dot(b) where a is matrix, b is vector
            for n in range(1, 17):
                a = np.arange(4*n, dtype=dtype).reshape(4, n)
                b = np.arange(n, dtype=dtype)
                assert_equal(np.einsum("ij, j", a, b, optimize=do_opt),
                             np.dot(a, b))
                assert_equal(np.einsum(a, [0, 1], b, [1], optimize=do_opt),
                             np.dot(a, b))

                c = np.arange(4, dtype=dtype)
                np.einsum("ij,j", a, b, out=c,
                          dtype='f8', casting='unsafe', optimize=do_opt)
                assert_equal(c,
                             np.dot(a.astype('f8'),
                                    b.astype('f8')).astype(dtype))
                c[...] = 0
                np.einsum(a, [0, 1], b, [1], out=c,
                          dtype='f8', casting='unsafe', optimize=do_opt)
                assert_equal(c,
                             np.dot(a.astype('f8'),
                                    b.astype('f8')).astype(dtype))

            for n in range(1, 17):
                a = np.arange(4*n, dtype=dtype).reshape(4, n)
                b = np.arange(n, dtype=dtype)
                assert_equal(np.einsum("ji,j", a.T, b.T, optimize=do_opt),
                             np.dot(b.T, a.T))
                assert_equal(np.einsum(a.T, [1, 0], b.T, [1], optimize=do_opt),
Loading ...