Learn more » Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Bower components, Debian packages, RPM packages, and NuGet packages.

aaronreidsmith / scipy   python

Repository URL to install this package:

Version: 1.3.3 

/ linalg / tests / test_decomp_update.py

from __future__ import division, print_function, absolute_import

import itertools

import numpy as np
from numpy.testing import assert_, assert_allclose, assert_equal
from pytest import raises as assert_raises
from scipy import linalg
import scipy.linalg._decomp_update as _decomp_update
from scipy.linalg._decomp_update import *

def assert_unitary(a, rtol=None, atol=None, assert_sqr=True):
    """Assert that the columns of ``a`` are orthonormal, i.e. ``a^H a == I``.

    Parameters
    ----------
    a : ndarray
        2-D matrix to check. Its dtype must be inexact (float/complex) when
        either tolerance is left at its default, because the default is
        derived from ``np.finfo(a.dtype)``.
    rtol, atol : float, optional
        Tolerances forwarded to ``assert_allclose``. Defaults are
        ``rtol = 10**-(precision - 2)`` and ``atol = 2*eps`` for ``a.dtype``.
    assert_sqr : bool, optional
        When True (default), also require ``a`` to be square. Pass False for
        matrices that only need orthonormal columns (e.g. an economic ``q``).

    Raises
    ------
    AssertionError
        If ``a`` is required to be square but is not, or if ``a^H a`` does
        not match the identity within the tolerances.
    """
    if rtol is None or atol is None:
        # finfo is consulted only when a default tolerance is actually needed.
        info = np.finfo(a.dtype)
        if rtol is None:
            rtol = 10.0 ** (2 - info.precision)
        if atol is None:
            atol = 2 * info.eps

    if assert_sqr:
        nrows, ncols = a.shape[0], a.shape[1]
        assert_(nrows == ncols, 'unitary matrices must be square')
    gram = a.conj().T.dot(a)
    assert_allclose(gram, np.eye(a.shape[1]), rtol=rtol, atol=atol)

def assert_upper_tri(a, rtol=None, atol=None):
    """Assert that every entry of ``a`` strictly below the main diagonal is ~0.

    Parameters
    ----------
    a : ndarray
        2-D matrix (square or rectangular) to check.
    rtol, atol : float, optional
        Tolerances forwarded to ``assert_allclose``. Defaults are
        ``rtol = 10**-(precision - 2)`` and ``atol = 2*eps`` for ``a.dtype``.

    Raises
    ------
    AssertionError
        If any strictly-lower-triangular entry is not close to zero.
    """
    if rtol is None or atol is None:
        info = np.finfo(a.dtype)
        if rtol is None:
            rtol = 10.0 ** (2 - info.precision)
        if atol is None:
            atol = 2 * info.eps

    nrows, ncols = a.shape[0], a.shape[1]
    # True exactly at positions (i, j) with j < i.
    below_diag = np.tri(nrows, ncols, k=-1, dtype=bool)
    assert_allclose(a[below_diag], 0.0, rtol=rtol, atol=atol)

def check_qr(q, r, a, rtol, atol, assert_sqr=True):
    """Assert that ``(q, r)`` is a valid QR factorization of ``a``.

    Checks, in order: ``q`` has orthonormal columns (and is square unless
    ``assert_sqr`` is False), ``r`` is upper triangular, and ``q @ r``
    reproduces ``a``. All three checks use the same ``rtol``/``atol``.
    """
    assert_unitary(q, rtol=rtol, atol=atol, assert_sqr=assert_sqr)
    assert_upper_tri(r, rtol=rtol, atol=atol)
    reconstructed = np.dot(q, r)
    assert_allclose(reconstructed, a, rtol=rtol, atol=atol)

def make_strided(arrs):
    """Copy each array into a larger zeroed buffer and return strided views.

    Each returned view holds the same values as its input but is
    non-contiguous and starts at a nonzero offset into its buffer, which
    exercises code paths that must honor arbitrary strides.

    Parameters
    ----------
    arrs : iterable of ndarray
        1-D or 2-D arrays.

    Returns
    -------
    list of ndarray
        Views with the same shapes, dtypes and values as ``arrs``.

    Raises
    ------
    ValueError
        If an array is neither 1-D nor 2-D.
    """
    # (step, offset) pairs, handed out one per axis and reused cyclically
    # across all the arrays.
    spacing = itertools.cycle([(3, 7), (2, 2), (3, 4), (4, 2),
                               (5, 4), (2, 3), (2, 1), (4, 5)])
    views = []
    for arr in arrs:
        if arr.ndim not in (1, 2):
            raise ValueError('make_strided only works for ndim = 1 or 2 arrays')
        axis_spacing = [next(spacing) for _ in range(arr.ndim)]
        buf_shape = tuple(step * length + offset
                          for (step, offset), length in zip(axis_spacing,
                                                            arr.shape))
        buf = np.zeros(buf_shape, arr.dtype)
        view = buf[tuple(slice(offset, None, step)
                         for step, offset in axis_spacing)]
        view[...] = arr
        views.append(view)
    return views

def negate_strides(arrs):
    """Return copies of ``arrs`` whose memory layout runs backwards on every axis.

    Each result compares equal to its input, but is a reversed view of a
    freshly allocated buffer, so all of its strides are negative.

    Raises
    ------
    ValueError
        If an array is neither 1-D nor 2-D.
    """
    flipped_arrays = []
    for src in arrs:
        blank = np.zeros_like(src)
        if blank.ndim not in (1, 2):
            raise ValueError('negate_strides only works for ndim = 1 or 2 arrays')
        reverse_every_axis = (slice(None, None, -1),) * blank.ndim
        dst = blank[reverse_every_axis]
        dst[...] = src
        flipped_arrays.append(dst)
    return flipped_arrays

def nonitemsize_strides(arrs):
    """Return copies of ``arrs`` whose strides are not multiples of the itemsize.

    Each array is stored as field ``'a'`` of a packed record that also
    carries one junk byte. Viewing that field yields an array of the
    original dtype whose element stride is ``itemsize + 1``.
    """
    def _embed(arr):
        # Record layout: the real value followed by a 1-byte filler.
        records = np.zeros(arr.shape, [('a', arr.dtype), ('junk', 'S1')])
        field_view = records.getfield(arr.dtype)
        field_view[...] = arr
        return field_view

    return [_embed(arr) for arr in arrs]


def make_nonnative(arrs):
    """Return copies of ``arrs`` stored with the opposite byte order.

    Values are preserved; only the in-memory byte order of the dtype is
    swapped (``newbyteorder('S')``).
    """
    swapped = []
    for arr in arrs:
        flipped_dtype = arr.dtype.newbyteorder('S')
        swapped.append(arr.astype(flipped_dtype))
    return swapped


class BaseQRdeltas(object):
    """Shared fixture code for the QR update/delete test classes.

    Subclasses provide a ``dtype`` class attribute (a NumPy dtype). Both the
    tolerances and the generated test matrices are derived from it.
    """

    def setup_method(self):
        """Derive ``self.rtol`` and ``self.atol`` from ``self.dtype``."""
        info = np.finfo(self.dtype)
        self.rtol = 10.0 ** (2 - info.precision)
        self.atol = 10 * info.eps

    def generate(self, type, mode='full'):
        """Build a reproducible random matrix and its QR factorization.

        Parameters
        ----------
        type : str
            Shape name: 'sqr' (8x8), 'tall' (12x7), 'fat' (7x12),
            'Mx1' (8x1), '1xN' (1x8) or '1x1'. Unknown names raise KeyError.
        mode : str, optional
            Forwarded to ``scipy.linalg.qr`` (e.g. 'full' or 'economic').

        Returns
        -------
        a, q, r : ndarray
            The matrix (cast to ``self.dtype``) and its QR factors.
        """
        # Reseed on every call so each test starts from identical data.
        np.random.seed(29382)
        rows, cols = {
            'sqr': (8, 8),
            'tall': (12, 7),
            'fat': (7, 12),
            'Mx1': (8, 1),
            '1xN': (1, 8),
            '1x1': (1, 1),
        }[type]
        values = np.random.random((rows, cols))
        if self.dtype.kind == 'c':
            # Complex dtypes get an independently drawn imaginary part.
            values = values + 1j * np.random.random((rows, cols))
        mat = values.astype(self.dtype)
        q, r = linalg.qr(mat, mode=mode)
        return mat, q, r

class BaseQRdelete(BaseQRdeltas):
    def test_sqr_1_row(self):
        """Deleting any single row of a square full QR yields a valid QR."""
        a, q, r = self.generate('sqr')
        # In full mode r has as many rows as a.
        nrows = a.shape[0]
        for k in range(nrows):
            expected = np.delete(a, k, axis=0)
            q_new, r_new = qr_delete(q, r, k, overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_sqr_p_row(self):
        """Delete blocks of 2-5 consecutive rows from a square full QR.

        Every valid start row is exercised, including ``m - ndel``, where
        the deleted block ends at the last row.
        """
        a, q, r = self.generate('sqr')
        for ndel in range(2, 6):
            # +1: range(m - ndel) stopped one start short and never deleted
            # the trailing block of rows.
            for row in range(a.shape[0] - ndel + 1):
                q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
                a1 = np.delete(a, slice(row, row + ndel), 0)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_sqr_1_col(self):
        """Deleting any single column of a square full QR yields a valid QR."""
        a, q, r = self.generate('sqr')
        ncols = r.shape[1]
        for k in range(ncols):
            expected = np.delete(a, k, axis=1)
            q_new, r_new = qr_delete(q, r, k, which='col', overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_sqr_p_col(self):
        """Delete blocks of 2-5 consecutive columns from a square full QR.

        Every valid start column is exercised, including ``n - ndel``, where
        the deleted block ends at the last column.
        """
        a, q, r = self.generate('sqr')
        for ndel in range(2, 6):
            # +1 so the trailing block of columns is covered as well.
            for col in range(a.shape[1] - ndel + 1):
                q1, r1 = qr_delete(q, r, col, ndel, which='col',
                                   overwrite_qr=False)
                a1 = np.delete(a, slice(col, col + ndel), 1)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_tall_1_row(self):
        """Deleting any single row of a tall full QR yields a valid QR."""
        a, q, r = self.generate('tall')
        # In full mode r has as many rows as a.
        nrows = a.shape[0]
        for k in range(nrows):
            expected = np.delete(a, k, axis=0)
            q_new, r_new = qr_delete(q, r, k, overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_tall_p_row(self):
        """Delete blocks of 2-5 consecutive rows from a tall full QR.

        Every valid start row is exercised, including ``m - ndel``.
        """
        a, q, r = self.generate('tall')
        for ndel in range(2, 6):
            # +1 so the block ending at the last row is covered.
            for row in range(a.shape[0] - ndel + 1):
                q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
                a1 = np.delete(a, slice(row, row + ndel), 0)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_tall_1_col(self):
        """Deleting any single column of a tall full QR yields a valid QR."""
        a, q, r = self.generate('tall')
        ncols = r.shape[1]
        for k in range(ncols):
            expected = np.delete(a, k, axis=1)
            q_new, r_new = qr_delete(q, r, k, which='col', overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_tall_p_col(self):
        """Delete blocks of 2-5 consecutive columns from a tall full QR.

        Every valid start column is exercised, including ``n - ndel``.
        """
        a, q, r = self.generate('tall')
        for ndel in range(2, 6):
            # +1 so the block ending at the last column is covered.
            for col in range(a.shape[1] - ndel + 1):
                q1, r1 = qr_delete(q, r, col, ndel, which='col',
                                   overwrite_qr=False)
                a1 = np.delete(a, slice(col, col + ndel), 1)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_fat_1_row(self):
        """Deleting any single row of a fat full QR yields a valid QR."""
        a, q, r = self.generate('fat')
        # In full mode r has as many rows as a.
        nrows = a.shape[0]
        for k in range(nrows):
            expected = np.delete(a, k, axis=0)
            q_new, r_new = qr_delete(q, r, k, overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_fat_p_row(self):
        """Delete blocks of 2-5 consecutive rows from a fat full QR.

        Every valid start row is exercised, including ``m - ndel``.
        """
        a, q, r = self.generate('fat')
        for ndel in range(2, 6):
            # +1 so the block ending at the last row is covered.
            for row in range(a.shape[0] - ndel + 1):
                q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
                a1 = np.delete(a, slice(row, row + ndel), 0)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_fat_1_col(self):
        """Deleting any single column of a fat full QR yields a valid QR."""
        a, q, r = self.generate('fat')
        ncols = r.shape[1]
        for k in range(ncols):
            expected = np.delete(a, k, axis=1)
            q_new, r_new = qr_delete(q, r, k, which='col', overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_fat_p_col(self):
        """Delete blocks of 2-5 consecutive columns from a fat full QR.

        Every valid start column is exercised, including ``n - ndel``.
        """
        a, q, r = self.generate('fat')
        for ndel in range(2, 6):
            # +1 so the block ending at the last column is covered.
            for col in range(a.shape[1] - ndel + 1):
                q1, r1 = qr_delete(q, r, col, ndel, which='col',
                                   overwrite_qr=False)
                a1 = np.delete(a, slice(col, col + ndel), 1)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_economic_1_row(self):
        """Delete each single row of a tall economic QR.

        This test always starts and ends with an economic decomposition, so
        ``check_qr`` does not require ``q`` to be square.
        """
        a, q, r = self.generate('tall', 'economic')
        # Iterate over the rows of ``a``, not ``r``: an economic ``r`` has only
        # as many rows as ``a`` has columns, so range(r.shape[0]) silently
        # skipped the trailing rows of the tall matrix.
        for row in range(a.shape[0]):
            q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
            a1 = np.delete(a, row, 0)
            check_qr(q1, r1, a1, self.rtol, self.atol, False)

    # Deleting p rows from an economic decomposition can end in three shapes:
    #   economic - p rows = economic (still tall)
    #   economic - p rows = square
    #   economic - p rows = fat
    def base_economic_p_row_xxx(self, ndel):
        """Delete every block of ``ndel`` consecutive rows from a tall economic QR.

        Shared driver for the ``test_economic_p_row_*`` tests; ``ndel``
        decides whether the result stays tall, becomes square or becomes fat.
        Every valid start row is exercised, including ``m - ndel``.
        """
        a, q, r = self.generate('tall', 'economic')
        # +1 so the block ending at the last row is covered.
        for row in range(a.shape[0] - ndel + 1):
            q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
            a1 = np.delete(a, slice(row, row + ndel), 0)
            check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_economic_p_row_economic(self):
        """(12, 7) minus 3 rows leaves (9, 7): the result stays economic."""
        self.base_economic_p_row_xxx(ndel=3)

    def test_economic_p_row_sqr(self):
        """(12, 7) minus 5 rows leaves (7, 7): the result becomes square."""
        self.base_economic_p_row_xxx(ndel=5)

    def test_economic_p_row_fat(self):
        """(12, 7) minus 7 rows leaves (5, 7): the result becomes fat."""
        self.base_economic_p_row_xxx(ndel=7)

    def test_economic_1_col(self):
        """Deleting any single column of a tall economic QR yields a valid QR."""
        a, q, r = self.generate('tall', 'economic')
        ncols = r.shape[1]
        for k in range(ncols):
            expected = np.delete(a, k, axis=1)
            q_new, r_new = qr_delete(q, r, k, which='col', overwrite_qr=False)
            # Economic q is not square, so skip the squareness check.
            check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

    def test_economic_p_col(self):
        """Delete blocks of 2-5 consecutive columns from a tall economic QR.

        Every valid start column is exercised, including ``n - ndel``.
        """
        a, q, r = self.generate('tall', 'economic')
        for ndel in range(2, 6):
            # +1 so the block ending at the last column is covered.
            for col in range(a.shape[1] - ndel + 1):
                q1, r1 = qr_delete(q, r, col, ndel, which='col',
                                   overwrite_qr=False)
                a1 = np.delete(a, slice(col, col + ndel), 1)
                # Economic q is not square.
                check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_Mx1_1_row(self):
        """Deleting any single row of an Mx1 full QR yields a valid QR."""
        a, q, r = self.generate('Mx1')
        # In full mode r has as many rows as a.
        nrows = a.shape[0]
        for k in range(nrows):
            expected = np.delete(a, k, axis=0)
            q_new, r_new = qr_delete(q, r, k, overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_Mx1_p_row(self):
        """Delete blocks of 2-5 consecutive rows from an Mx1 full QR.

        Every valid start row is exercised, including ``m - ndel``.
        """
        a, q, r = self.generate('Mx1')
        for ndel in range(2, 6):
            # +1 so the block ending at the last row is covered.
            for row in range(a.shape[0] - ndel + 1):
                q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
                a1 = np.delete(a, slice(row, row + ndel), 0)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_1xN_1_col(self):
        """Deleting any single column of a 1xN full QR yields a valid QR."""
        a, q, r = self.generate('1xN')
        ncols = r.shape[1]
        for k in range(ncols):
            expected = np.delete(a, k, axis=1)
            q_new, r_new = qr_delete(q, r, k, which='col', overwrite_qr=False)
            check_qr(q_new, r_new, expected, self.rtol, self.atol)

    def test_1xN_p_col(self):
        """Delete blocks of 2-5 consecutive columns from a 1xN full QR.

        Every valid start column is exercised, including ``n - ndel``.
        """
        a, q, r = self.generate('1xN')
        for ndel in range(2, 6):
            # +1 so the block ending at the last column is covered.
            for col in range(a.shape[1] - ndel + 1):
                q1, r1 = qr_delete(q, r, col, ndel, which='col',
                                   overwrite_qr=False)
                a1 = np.delete(a, slice(col, col + ndel), 1)
                check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_Mx1_economic_1_row(self):
        """Delete each single row of an Mx1 economic QR."""
        a, q, r = self.generate('Mx1', 'economic')
        # Iterate over the rows of ``a``: the economic ``r`` of an Mx1 matrix
        # has a single row, so range(r.shape[0]) only ever deleted row 0.
        for row in range(a.shape[0]):
            q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
            a1 = np.delete(a, row, 0)
            check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_Mx1_economic_p_row(self):
        """Delete blocks of 2-5 consecutive rows from an Mx1 economic QR.

        Every valid start row is exercised, including ``m - ndel``.
        """
        a, q, r = self.generate('Mx1', 'economic')
        for ndel in range(2, 6):
            # +1 so the block ending at the last row is covered.
            for row in range(a.shape[0] - ndel + 1):
                q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
                a1 = np.delete(a, slice(row, row + ndel), 0)
                check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_delete_last_1_row(self):
        """Deleting the only row of a 1xN QR leaves empty factors.

        Full and economic decompositions are the same for a 1xN matrix, so
        only the default mode is tested.
        """
        a, q, r = self.generate('1xN')
        q1, r1 = qr_delete(q, r, 0, 1, 'row')
        assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
        # assert_equal compares shapes and values but not dtypes, so the
        # dtype= above is not checked; verify it explicitly.
        assert_(q1.dtype == q.dtype)
        assert_(r1.dtype == r.dtype)

    def test_delete_last_p_row(self):
        """Deleting all rows of a tall QR (full and economic) leaves empty factors."""
        a, q, r = self.generate('tall', 'full')
        q1, r1 = qr_delete(q, r, 0, a.shape[0], 'row')
        assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
        # assert_equal ignores dtype, so check it explicitly.
        assert_(q1.dtype == q.dtype)
        assert_(r1.dtype == r.dtype)

        a, q, r = self.generate('tall', 'economic')
        q1, r1 = qr_delete(q, r, 0, a.shape[0], 'row')
        assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
        assert_(q1.dtype == q.dtype)
        assert_(r1.dtype == r.dtype)

    def test_delete_last_1_col(self):
        """Deleting the only column of an Mx1 QR leaves an empty ``r``.

        Economic mode: ``q1`` has zero columns and ``r1`` is (0, 0).
        Full mode: ``q1`` keeps the shape and dtype of ``q`` and must still be
        unitary, while ``r1`` has zero columns.
        """
        a, q, r = self.generate('Mx1', 'economic')
        q1, r1 = qr_delete(q, r, 0, 1, 'col')
        assert_equal(q1, np.ndarray(shape=(q.shape[0], 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, 0), dtype=r.dtype))
        # assert_equal ignores dtype, so check it explicitly.
        assert_(q1.dtype == q.dtype)
        assert_(r1.dtype == r.dtype)

        a, q, r = self.generate('Mx1', 'full')
        q1, r1 = qr_delete(q, r, 0, 1, 'col')
        assert_unitary(q1)
        assert_(q1.dtype == q.dtype)
        assert_(q1.shape == q.shape)
        assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))
        assert_(r1.dtype == r.dtype)

    def test_delete_last_p_col(self):
        """Deleting every column of a tall QR leaves an empty ``r``.

        Full mode: ``q1`` keeps the shape and dtype of ``q`` and must still be
        unitary, while ``r1`` has zero columns.
        Economic mode: ``q1`` has zero columns and ``r1`` is (0, 0).
        """
        a, q, r = self.generate('tall', 'full')
        q1, r1 = qr_delete(q, r, 0, a.shape[1], 'col')
        assert_unitary(q1)
        assert_(q1.dtype == q.dtype)
        assert_(q1.shape == q.shape)
        assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))
        # assert_equal ignores dtype, so check it explicitly.
        assert_(r1.dtype == r.dtype)

        a, q, r = self.generate('tall', 'economic')
        q1, r1 = qr_delete(q, r, 0, a.shape[1], 'col')
        assert_equal(q1, np.ndarray(shape=(q.shape[0], 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, 0), dtype=r.dtype))
        assert_(q1.dtype == q.dtype)
        assert_(r1.dtype == r.dtype)

    def test_delete_1x1_row_col(self):
        a, q, r = self.generate('1x1')
        q1, r1 = qr_delete(q, r, 0, 1, 'row')
        assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))

        a, q, r = self.generate('1x1')
        q1, r1 = qr_delete(q, r, 0, 1, 'col')
        assert_unitary(q1)
Loading ...