Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aaronreidsmith / scipy   python

Repository URL to install this package:

Version: 1.3.3 

/ linalg / tests / test_blas.py

#
# Created by: Pearu Peterson, April 2002
#
from __future__ import division, print_function, absolute_import


# Free-form build/run notes kept as a module attribute; nothing in the
# visible part of this module reads ``__usage__``.
__usage__ = """
Build linalg:
  python setup.py build
Run tests if scipy is installed:
  python -c 'import scipy;scipy.linalg.test()'
"""

import math

import numpy as np
from numpy.testing import (assert_equal, assert_almost_equal, assert_,
                           assert_array_almost_equal, assert_allclose)
from pytest import raises as assert_raises

from numpy import float32, float64, complex64, complex128, arange, triu, \
                  tril, zeros, tril_indices, ones, mod, diag, append, eye, \
                  nonzero

from numpy.random import rand, seed
from scipy.linalg import _fblas as fblas, get_blas_funcs, toeplitz, solve, \
                                          solve_triangular

# ``_cblas`` is optional: if the extension cannot be imported, ``cblas`` is
# bound to None.  Code in this module that touches ``cblas`` guards for that
# (``cblas is not None`` or ``getattr(cblas, name, None)``).
try:
    from scipy.linalg import _cblas as cblas
except ImportError:
    cblas = None

# Dtype groups: real, complex, and the concatenation of both lists.
REAL_DTYPES = [float32, float64]
COMPLEX_DTYPES = [complex64, complex128]
DTYPES = REAL_DTYPES + COMPLEX_DTYPES


def test_get_blas_funcs():
    """Check the typecode/module selection done by ``get_blas_funcs``.

    Covers, in order: selection from a pair of arrays with different
    dtypes and memory orders, the no-argument default, the ``dtype=``
    keyword (type object and character code), an extended-precision
    complex dtype, and mixed real/complex input arrays.
    """
    # The most generic array of the pair is the C-ordered complex128 one;
    # the complex64 array is Fortran-ordered.
    f1, f2, f3 = get_blas_funcs(
        ('axpy', 'axpy', 'axpy'),
        (np.empty((2, 2), dtype=np.complex64, order='F'),
         np.empty((2, 2), dtype=np.complex128, order='C'))
        )

    # get_blas_funcs will choose libraries depending on most generic
    # array
    for func in (f1, f2, f3):
        assert_equal(func.typecode, 'z')
    if cblas is not None:
        # Only meaningful when the optional cblas extension was importable.
        assert_equal(f1.module_name, 'cblas')
        assert_equal(f2.module_name, 'cblas')

    # check defaults.
    f1 = get_blas_funcs('rotg')
    assert_equal(f1.typecode, 'd')

    # check also dtype interface
    f1 = get_blas_funcs('gemm', dtype=np.complex64)
    assert_equal(f1.typecode, 'c')
    f1 = get_blas_funcs('gemm', dtype='F')
    assert_equal(f1.typecode, 'c')

    # extended precision complex
    # ``np.longcomplex`` was removed in NumPy 2.0; ``np.clongdouble`` names
    # the same type and is also available in older NumPy releases.
    f1 = get_blas_funcs('gemm', dtype=np.clongdouble)
    assert_equal(f1.typecode, 'z')

    # check safe complex upcasting
    f1 = get_blas_funcs('axpy',
                        (np.empty((2, 2), dtype=np.float64),
                         np.empty((2, 2), dtype=np.complex64))
                        )
    assert_equal(f1.typecode, 'z')


def test_get_blas_funcs_alias():
    """Check name aliasing in ``get_blas_funcs``.

    For complex64 the names 'nrm2' and 'dot' must both resolve to routines
    with typecode 'c'; for float64 the names 'dot', 'dotc' and 'dotu' must
    all resolve to the very same function object.
    """
    norm_fn, dot_fn = get_blas_funcs(('nrm2', 'dot'), dtype=np.complex64)
    for resolved in (norm_fn, dot_fn):
        assert resolved.typecode == 'c'

    real_dot, real_dotc, real_dotu = get_blas_funcs(
        ('dot', 'dotc', 'dotu'), dtype=np.float64)
    for alias in (real_dotc, real_dotu):
        assert real_dot is alias


class TestCBLAS1Simple(object):
    """Fixed-input check of the ``axpy`` routines looked up on ``cblas``.

    Lookups use ``getattr(cblas, name, None)``, so a missing routine (or
    ``cblas`` itself being None) results in the case being skipped.
    """

    def test_axpy(self):
        # Expected vectors equal 5*x + y for y = [2, -1, 3].
        cases = (
            ('sd', [1, 2, 3], [7, 9, 18]),
            ('cz', [1, 2j, 3], [7, 10j-1, 18]),
        )
        for prefixes, x, expected in cases:
            for prefix in prefixes:
                axpy = getattr(cblas, prefix + 'axpy', None)
                if axpy is not None:
                    assert_array_almost_equal(axpy(x, [2, -1, 3], a=5),
                                              expected)


class TestFBLAS1Simple(object):
    """Small fixed-input checks of routines looked up by name on ``fblas``.

    Apart from ``test_amax``, a routine that ``fblas`` does not expose is
    skipped instead of being reported as a failure.
    """

    @staticmethod
    def _available(prefixes, stem):
        # Lazily yield ``fblas.<prefix><stem>`` for each prefix that resolves.
        for prefix in prefixes:
            routine = getattr(fblas, prefix + stem, None)
            if routine is not None:
                yield routine

    def test_axpy(self):
        # Expected vectors equal 5*x + y for the inputs below.
        for axpy in self._available('sd', 'axpy'):
            assert_array_almost_equal(axpy([1, 2, 3], [2, -1, 3], a=5),
                                      [7, 9, 18])
        for axpy in self._available('cz', 'axpy'):
            assert_array_almost_equal(axpy([1, 2j, 3], [2, -1, 3], a=5),
                                      [7, 10j-1, 18])

    def test_copy(self):
        # The returned vector must equal the first argument.
        for copy in self._available('sd', 'copy'):
            assert_array_almost_equal(copy([3, 4, 5], [8]*3), [3, 4, 5])
        for copy in self._available('cz', 'copy'):
            assert_array_almost_equal(copy([3, 4j, 5+3j], [8]*3),
                                      [3, 4j, 5+3j])

    def test_asum(self):
        # 12 = 3 + 4 + 5
        for asum in self._available('sd', 'asum'):
            assert_almost_equal(asum([3, -4, 5]), 12)
        # 14 = 3 + 4 + (3 + 4): real and imaginary magnitudes are summed.
        for asum in self._available(['sc', 'dz'], 'asum'):
            assert_almost_equal(asum([3j, -4, 3-4j]), 14)

    def test_dot(self):
        # 3*2 - 4*5 + 5*1 = -9
        for dot in self._available('sd', 'dot'):
            assert_almost_equal(dot([3, -4, 5], [2, 5, 1]), -9)

    def test_complex_dotu(self):
        # Plain (unconjugated) product sum of the two vectors.
        for dotu in self._available('cz', 'dotu'):
            assert_almost_equal(dotu([3j, -4, 3-4j], [2, 3, 1]), -9+2j)

    def test_complex_dotc(self):
        # Product sum with the first vector conjugated.
        for dotc in self._available('cz', 'dotc'):
            assert_almost_equal(dotc([3j, -4, 3-4j], [2, 3j, 1]), 3-14j)

    def test_nrm2(self):
        expected = math.sqrt(50)
        for nrm2 in self._available('sd', 'nrm2'):
            assert_almost_equal(nrm2([3, -4, 5]), expected)
        # Several spellings of the complex prefix are tried; the ones that
        # are not attributes of fblas are skipped.
        for nrm2 in self._available(['c', 'z', 'sc', 'dz'], 'nrm2'):
            assert_almost_equal(nrm2([3j, -4, 3-4j]), expected)

    def test_scal(self):
        # real scalar, real vector
        for scal in self._available('sd', 'scal'):
            assert_array_almost_equal(scal(2, [3, -4, 5]), [6, -8, 10])
        # complex scalar, complex vector
        for scal in self._available('cz', 'scal'):
            assert_array_almost_equal(scal(3j, [3j, -4, 3-4j]),
                                      [-9, -12j, 12+9j])
        # real scalar, complex vector
        for scal in self._available(['cs', 'zd'], 'scal'):
            assert_array_almost_equal(scal(3, [3j, -4, 3-4j]),
                                      [9j, -12, 9-12j])

    def test_swap(self):
        # The two returned vectors must be the inputs in exchanged order.
        for swap in self._available('sd', 'swap'):
            first, second = [2, 3, 1], [-2, 3, 7]
            out_first, out_second = swap(first, second)
            assert_array_almost_equal(out_first, second)
            assert_array_almost_equal(out_second, first)
        for swap in self._available('cz', 'swap'):
            first, second = [2, 3j, 1], [-2, 3, 7-3j]
            out_first, out_second = swap(first, second)
            assert_array_almost_equal(out_first, second)
            assert_array_almost_equal(out_second, first)

    def test_amax(self):
        # No ``None`` default in this lookup: a missing routine raises
        # AttributeError instead of being skipped.
        for prefixes, vector in (('sd', [-2, 4, 3]), ('cz', [-5, 4+3j, 6])):
            for prefix in prefixes:
                iamax = getattr(fblas, 'i' + prefix + 'amax')
                assert_equal(iamax(vector), 1)
    # XXX: need tests for rot,rotm,rotg,rotmg


class TestFBLAS2Simple(object):

    def test_gemv(self):
        for p in 'sd':
            f = getattr(fblas, p+'gemv', None)
            if f is None:
                continue
            assert_array_almost_equal(f(3, [[3]], [-4]), [-36])
            assert_array_almost_equal(f(3, [[3]], [-4], 3, [5]), [-21])
        for p in 'cz':
            f = getattr(fblas, p+'gemv', None)
            if f is None:
                continue
            assert_array_almost_equal(f(3j, [[3-4j]], [-4]), [-48-36j])
            assert_array_almost_equal(f(3j, [[3-4j]], [-4], 3, [5j]),
                                      [-48-21j])

    def test_ger(self):

        for p in 'sd':
            f = getattr(fblas, p+'ger', None)
            if f is None:
                continue
            assert_array_almost_equal(f(1, [1, 2], [3, 4]), [[3, 4], [6, 8]])
            assert_array_almost_equal(f(2, [1, 2, 3], [3, 4]),
                                      [[6, 8], [12, 16], [18, 24]])

            assert_array_almost_equal(f(1, [1, 2], [3, 4],
                                        a=[[1, 2], [3, 4]]), [[4, 6], [9, 12]])

        for p in 'cz':
            f = getattr(fblas, p+'geru', None)
            if f is None:
                continue
            assert_array_almost_equal(f(1, [1j, 2], [3, 4]),
                                      [[3j, 4j], [6, 8]])
            assert_array_almost_equal(f(-2, [1j, 2j, 3j], [3j, 4j]),
                                      [[6, 8], [12, 16], [18, 24]])

        for p in 'cz':
            for name in ('ger', 'gerc'):
                f = getattr(fblas, p+name, None)
                if f is None:
                    continue
                assert_array_almost_equal(f(1, [1j, 2], [3, 4]),
                                          [[3j, 4j], [6, 8]])
                assert_array_almost_equal(f(2, [1j, 2j, 3j], [3j, 4j]),
                                          [[6, 8], [12, 16], [18, 24]])

    def test_syr_her(self):
        x = np.arange(1, 5, dtype='d')
        resx = np.triu(x[:, np.newaxis] * x)
        resx_reverse = np.triu(x[::-1, np.newaxis] * x[::-1])

        y = np.linspace(0, 8.5, 17, endpoint=False)

        z = np.arange(1, 9, dtype='d').view('D')
        resz = np.triu(z[:, np.newaxis] * z)
        resz_reverse = np.triu(z[::-1, np.newaxis] * z[::-1])
        rehz = np.triu(z[:, np.newaxis] * z.conj())
        rehz_reverse = np.triu(z[::-1, np.newaxis] * z[::-1].conj())

        w = np.c_[np.zeros(4), z, np.zeros(4)].ravel()

        for p, rtol in zip('sd', [1e-7, 1e-14]):
            f = getattr(fblas, p+'syr', None)
            if f is None:
                continue
            assert_allclose(f(1.0, x), resx, rtol=rtol)
            assert_allclose(f(1.0, x, lower=True), resx.T, rtol=rtol)
            assert_allclose(f(1.0, y, incx=2, offx=2, n=4), resx, rtol=rtol)
            # negative increments imply reversed vectors in blas
            assert_allclose(f(1.0, y, incx=-2, offx=2, n=4),
                            resx_reverse, rtol=rtol)

            a = np.zeros((4, 4), 'f' if p == 's' else 'd', 'F')
            b = f(1.0, x, a=a, overwrite_a=True)
            assert_allclose(a, resx, rtol=rtol)

            b = f(2.0, x, a=a)
            assert_(a is not b)
            assert_allclose(b, 3*resx, rtol=rtol)

            assert_raises(Exception, f, 1.0, x, incx=0)
            assert_raises(Exception, f, 1.0, x, offx=5)
            assert_raises(Exception, f, 1.0, x, offx=-2)
            assert_raises(Exception, f, 1.0, x, n=-2)
            assert_raises(Exception, f, 1.0, x, n=5)
            assert_raises(Exception, f, 1.0, x, lower=2)
            assert_raises(Exception, f, 1.0, x, a=np.zeros((2, 2), 'd', 'F'))

        for p, rtol in zip('cz', [1e-7, 1e-14]):
            f = getattr(fblas, p+'syr', None)
            if f is None:
                continue
            assert_allclose(f(1.0, z), resz, rtol=rtol)
            assert_allclose(f(1.0, z, lower=True), resz.T, rtol=rtol)
            assert_allclose(f(1.0, w, incx=3, offx=1, n=4), resz, rtol=rtol)
            # negative increments imply reversed vectors in blas
            assert_allclose(f(1.0, w, incx=-3, offx=1, n=4),
                            resz_reverse, rtol=rtol)

            a = np.zeros((4, 4), 'F' if p == 'c' else 'D', 'F')
            b = f(1.0, z, a=a, overwrite_a=True)
            assert_allclose(a, resz, rtol=rtol)

            b = f(2.0, z, a=a)
            assert_(a is not b)
            assert_allclose(b, 3*resz, rtol=rtol)

            assert_raises(Exception, f, 1.0, x, incx=0)
            assert_raises(Exception, f, 1.0, x, offx=5)
            assert_raises(Exception, f, 1.0, x, offx=-2)
            assert_raises(Exception, f, 1.0, x, n=-2)
            assert_raises(Exception, f, 1.0, x, n=5)
            assert_raises(Exception, f, 1.0, x, lower=2)
            assert_raises(Exception, f, 1.0, x, a=np.zeros((2, 2), 'd', 'F'))

        for p, rtol in zip('cz', [1e-7, 1e-14]):
            f = getattr(fblas, p+'her', None)
            if f is None:
                continue
Loading ...