from __future__ import division, print_function, absolute_import
import itertools
import numpy as np
from numpy.testing import assert_, assert_allclose, assert_equal
from pytest import raises as assert_raises
from scipy import linalg
import scipy.linalg._decomp_update as _decomp_update
from scipy.linalg._decomp_update import *
def assert_unitary(a, rtol=None, atol=None, assert_sqr=True):
    """Assert that ``a`` has orthonormal columns (``a^H a`` equals identity).

    Parameters
    ----------
    a : ndarray
        Matrix to check.  ``a.shape[1]`` determines the size of the identity
        matrix that ``a^H a`` is compared against.
    rtol : float, optional
        Relative tolerance for the comparison.  When omitted it is
        ``10 ** -(precision - 2)`` with ``precision`` from
        ``np.finfo(a.dtype)``.
    atol : float, optional
        Absolute tolerance for the comparison.  When omitted it is twice the
        machine epsilon of ``a.dtype``.
    assert_sqr : bool, optional
        If true (the default), also require ``a`` to be square.

    Raises
    ------
    AssertionError
        If ``a`` is required to be square but is not, or if ``a^H a`` is not
        close to the identity.
    """
    # Tolerances are looked up lazily so explicit values bypass np.finfo.
    if rtol is None:
        rtol = 10.0 ** -(np.finfo(a.dtype).precision - 2)
    if atol is None:
        atol = 2 * np.finfo(a.dtype).eps

    if assert_sqr:
        n_rows, n_cols = a.shape[0], a.shape[1]
        assert_(n_rows == n_cols, 'unitary matrices must be square')

    a_h = a.T.conj()
    gram = np.dot(a_h, a)
    identity = np.eye(a.shape[1])
    assert_allclose(gram, identity, rtol=rtol, atol=atol)
def assert_upper_tri(a, rtol=None, atol=None):
    """Assert that every entry strictly below the main diagonal of ``a`` is ~0.

    Parameters
    ----------
    a : ndarray
        Two-dimensional array to check (square or rectangular).
    rtol : float, optional
        Relative tolerance; defaults to ``10 ** -(precision - 2)`` using the
        ``np.finfo`` precision of ``a.dtype``.
    atol : float, optional
        Absolute tolerance; defaults to twice the machine epsilon of
        ``a.dtype``.

    Raises
    ------
    AssertionError
        If any strictly-lower-triangular entry is not close to zero.
    """
    if rtol is None:
        rtol = 10.0 ** -(np.finfo(a.dtype).precision - 2)
    if atol is None:
        atol = 2 * np.finfo(a.dtype).eps

    n_rows, n_cols = a.shape[0], a.shape[1]
    # k=-1 selects the entries strictly below the main diagonal.
    below_diag = np.tri(n_rows, n_cols, k=-1, dtype=np.bool_)
    sub_diagonal = a[below_diag]
    assert_allclose(sub_diagonal, 0.0, rtol=rtol, atol=atol)
def check_qr(q, r, a, rtol, atol, assert_sqr=True):
    """Assert that ``q`` and ``r`` form a valid QR factorization of ``a``.

    Three properties are checked in order: ``q`` has orthonormal columns
    (and is square when ``assert_sqr`` is true), ``r`` is upper triangular,
    and the product ``q r`` reproduces ``a``.

    Parameters
    ----------
    q, r : ndarray
        Candidate factors.
    a : ndarray
        Matrix the factors should multiply back to.
    rtol, atol : float
        Tolerances used for all three checks.
    assert_sqr : bool, optional
        Forwarded to ``assert_unitary``; pass False when ``q`` is not
        expected to be square.
    """
    assert_unitary(q, rtol=rtol, atol=atol, assert_sqr=assert_sqr)
    assert_upper_tri(r, rtol=rtol, atol=atol)
    product = q.dot(r)
    assert_allclose(product, a, rtol=rtol, atol=atol)
def make_strided(arrs):
    """Return copies of ``arrs`` stored as non-contiguous strided views.

    Each input is copied into a slice of a larger zero-filled buffer, so the
    returned arrays hold the same values but have non-unit element steps and
    a non-zero starting offset.  The ``(step, offset)`` pairs are drawn from
    a fixed table in round-robin order: one pair per 1-D array, two per 2-D
    array (rows, then columns).

    Parameters
    ----------
    arrs : iterable of ndarray
        Arrays with ``ndim`` 1 or 2.

    Returns
    -------
    list of ndarray
        Strided views, in the same order as ``arrs``.

    Raises
    ------
    ValueError
        If an array has ``ndim`` other than 1 or 2.
    """
    # (step, offset) pairs consumed cyclically.
    layout_table = [(3, 7), (2, 2), (3, 4), (4, 2),
                    (5, 4), (2, 3), (2, 1), (4, 5)]
    table_len = len(layout_table)
    cursor = 0
    strided = []
    for arr in arrs:
        if arr.ndim == 1:
            step, offset = layout_table[cursor % table_len]
            cursor += 1
            backing = np.zeros(step * arr.shape[0] + offset, arr.dtype)
            view = backing[offset::step]
        elif arr.ndim == 2:
            row_step, row_off = layout_table[cursor % table_len]
            col_step, col_off = layout_table[(cursor + 1) % table_len]
            cursor += 2
            backing_shape = (row_step * arr.shape[0] + row_off,
                             col_step * arr.shape[1] + col_off)
            backing = np.zeros(backing_shape, arr.dtype)
            view = backing[row_off::row_step, col_off::col_step]
        else:
            raise ValueError('make_strided only works for ndim = 1 or 2 arrays')
        # The view has exactly arr's shape, so this is an element-wise copy.
        view[...] = arr
        strided.append(view)
    return strided
def negate_strides(arrs):
    """Return copies of ``arrs`` whose memory is traversed with negative strides.

    For each input a zero-filled array of the same shape is allocated, a
    fully reversed view of it is taken, and the input values are copied into
    that view.  The results compare equal to the inputs but every axis has a
    negative stride.

    Parameters
    ----------
    arrs : iterable of ndarray
        Arrays with ``ndim`` 1 or 2.

    Returns
    -------
    list of ndarray

    Raises
    ------
    ValueError
        If an array has ``ndim`` other than 1 or 2.
    """
    reversed_views = []
    for arr in arrs:
        storage = np.zeros_like(arr)
        if storage.ndim == 2:
            flipped = storage[::-1, ::-1]
        elif storage.ndim == 1:
            flipped = storage[::-1]
        else:
            raise ValueError('negate_strides only works for ndim = 1 or 2 arrays')
        flipped[...] = arr
        reversed_views.append(flipped)
    return reversed_views
def nonitemsize_strides(arrs):
    """Return copies of ``arrs`` whose strides are not multiples of the itemsize.

    Each array is copied into the first field of a structured array whose
    records carry one extra ``'S1'`` byte, and that field is returned as a
    view.  Consecutive elements are therefore separated by
    ``itemsize + 1`` bytes.

    Parameters
    ----------
    arrs : iterable of ndarray

    Returns
    -------
    list of ndarray
        Field views with the same shape, dtype and values as the inputs.
    """
    views = []
    for arr in arrs:
        elem_dtype = arr.dtype
        # The trailing one-byte field pads every record by a single byte.
        record_dtype = [('a', elem_dtype), ('junk', 'S1')]
        records = np.zeros(arr.shape, record_dtype)
        payload = records.getfield(elem_dtype)
        payload[...] = arr
        views.append(payload)
    return views
def make_nonnative(arrs):
    """Return copies of ``arrs`` converted to the byte-swapped variant of their dtype.

    Parameters
    ----------
    arrs : iterable of ndarray

    Returns
    -------
    list of ndarray
        One converted copy per input, in the same order.
    """
    swapped = []
    for arr in arrs:
        opposite_order = arr.dtype.newbyteorder()
        swapped.append(arr.astype(opposite_order))
    return swapped
class BaseQRdeltas(object):
    """Shared fixture logic for the QR update test classes.

    The methods read ``self.dtype``, which this class does not set itself;
    it has to be provided before they are called.
    """

    def setup_method(self):
        """Derive comparison tolerances from the precision of ``self.dtype``."""
        info = np.finfo(self.dtype)
        self.rtol = 10.0 ** -(info.precision - 2)
        self.atol = 10 * info.eps

    def generate(self, type, mode='full'):
        """Build a reproducible random matrix and its QR factorization.

        Parameters
        ----------
        type : {'sqr', 'tall', 'fat', 'Mx1', '1xN', '1x1'}
            Name of the matrix shape to generate.
        mode : str, optional
            Passed through to ``linalg.qr`` as its ``mode`` argument.

        Returns
        -------
        a, q, r : ndarray
            The matrix (cast to ``self.dtype``) and the factors returned by
            ``linalg.qr``.

        Raises
        ------
        KeyError
            If ``type`` is not one of the known shape names.
        """
        # Fixed seed: every call yields the same matrix for a given shape.
        np.random.seed(29382)
        known_shapes = {
            'sqr': (8, 8),
            'tall': (12, 7),
            'fat': (7, 12),
            'Mx1': (8, 1),
            '1xN': (1, 8),
            '1x1': (1, 1),
        }
        shape = known_shapes[type]
        a = np.random.random(shape)
        # Complex dtypes get an independently drawn imaginary part.
        if np.iscomplexobj(self.dtype.type(1)):
            imag_part = np.random.random(shape)
            a = a + 1j * imag_part
        a = a.astype(self.dtype)
        q, r = linalg.qr(a, mode=mode)
        return a, q, r
class BaseQRdelete(BaseQRdeltas):
def test_sqr_1_row(self):
    """Delete each single row of the 'sqr' matrix and validate the updated QR."""
    a, q, r = self.generate('sqr')
    for k in range(r.shape[0]):
        q_upd, r_upd = qr_delete(q, r, k, overwrite_qr=False)
        expected = np.delete(a, k, axis=0)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_sqr_p_row(self):
    """Delete blocks of 2 to 5 consecutive rows from the 'sqr' matrix."""
    a, q, r = self.generate('sqr')
    n_rows = a.shape[0]
    for block in range(2, 6):
        # NOTE(review): start offsets stop at n_rows - block - 1, so the
        # final row is never part of a deleted block -- confirm intent.
        for start in range(n_rows - block):
            q_upd, r_upd = qr_delete(q, r, start, block, overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=0)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_sqr_1_col(self):
    """Delete each single column of the 'sqr' matrix and validate the updated QR."""
    a, q, r = self.generate('sqr')
    for k in range(r.shape[1]):
        q_upd, r_upd = qr_delete(q, r, k, which='col', overwrite_qr=False)
        expected = np.delete(a, k, axis=1)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_sqr_p_col(self):
    """Delete blocks of 2 to 5 consecutive columns from the 'sqr' matrix."""
    a, q, r = self.generate('sqr')
    n_cols = r.shape[1]
    for block in range(2, 6):
        # NOTE(review): start offsets stop at n_cols - block - 1, so the
        # final column is never part of a deleted block -- confirm intent.
        for start in range(n_cols - block):
            q_upd, r_upd = qr_delete(q, r, start, block, which='col',
                                     overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=1)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_tall_1_row(self):
    """Delete single rows of the 'tall' matrix (full QR) one at a time."""
    a, q, r = self.generate('tall')
    # Row indices are taken from the row count of r.
    for k in range(r.shape[0]):
        q_upd, r_upd = qr_delete(q, r, k, overwrite_qr=False)
        expected = np.delete(a, k, axis=0)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_tall_p_row(self):
    """Delete blocks of 2 to 5 consecutive rows from the 'tall' matrix."""
    a, q, r = self.generate('tall')
    n_rows = a.shape[0]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_rows - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_rows - block):
            q_upd, r_upd = qr_delete(q, r, start, block, overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=0)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_tall_1_col(self):
    """Delete single columns of the 'tall' matrix (full QR) one at a time."""
    a, q, r = self.generate('tall')
    for k in range(r.shape[1]):
        q_upd, r_upd = qr_delete(q, r, k, which='col', overwrite_qr=False)
        expected = np.delete(a, k, axis=1)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_tall_p_col(self):
    """Delete blocks of 2 to 5 consecutive columns from the 'tall' matrix."""
    a, q, r = self.generate('tall')
    n_cols = r.shape[1]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_cols - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_cols - block):
            q_upd, r_upd = qr_delete(q, r, start, block, which='col',
                                     overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=1)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_fat_1_row(self):
    """Delete single rows of the 'fat' matrix (full QR) one at a time."""
    a, q, r = self.generate('fat')
    for k in range(r.shape[0]):
        q_upd, r_upd = qr_delete(q, r, k, overwrite_qr=False)
        expected = np.delete(a, k, axis=0)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_fat_p_row(self):
    """Delete blocks of 2 to 5 consecutive rows from the 'fat' matrix."""
    a, q, r = self.generate('fat')
    n_rows = a.shape[0]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_rows - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_rows - block):
            q_upd, r_upd = qr_delete(q, r, start, block, overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=0)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_fat_1_col(self):
    """Delete single columns of the 'fat' matrix (full QR) one at a time."""
    a, q, r = self.generate('fat')
    for k in range(r.shape[1]):
        q_upd, r_upd = qr_delete(q, r, k, which='col', overwrite_qr=False)
        expected = np.delete(a, k, axis=1)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_fat_p_col(self):
    """Delete blocks of 2 to 5 consecutive columns from the 'fat' matrix."""
    a, q, r = self.generate('fat')
    n_cols = r.shape[1]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_cols - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_cols - block):
            q_upd, r_upd = qr_delete(q, r, start, block, which='col',
                                     overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=1)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_economic_1_row(self):
    """Delete single rows from the economic QR of the 'tall' matrix."""
    # this test always starts and ends with an economic decomp.
    a, q, r = self.generate('tall', 'economic')
    # NOTE(review): the loop bound is r.shape[0]; for an economic
    # factorization that is presumably smaller than a.shape[0], so the
    # trailing rows of a would not be exercised -- confirm intent.
    for k in range(r.shape[0]):
        q_upd, r_upd = qr_delete(q, r, k, overwrite_qr=False)
        expected = np.delete(a, k, axis=0)
        # q is not required to be square here.
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol,
                 assert_sqr=False)
# for economic row deletes
# eco - prow = eco
# eco - prow = sqr
# eco - prow = fat
def base_economic_p_row_xxx(self, ndel):
    """Delete ``ndel`` consecutive rows from the economic QR of the 'tall' matrix.

    Shared driver for the ``test_economic_p_row_*`` cases, which differ only
    in how many rows are removed.

    Parameters
    ----------
    ndel : int
        Number of consecutive rows to delete at each start offset.
    """
    a, q, r = self.generate('tall', 'economic')
    n_starts = a.shape[0] - ndel
    for start in range(n_starts):
        q_upd, r_upd = qr_delete(q, r, start, ndel, overwrite_qr=False)
        expected = np.delete(a, slice(start, start + ndel), axis=0)
        # q is not required to be square here.
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol,
                 assert_sqr=False)
def test_economic_p_row_economic(self):
    """Remove 3 rows so the result still has more rows than columns."""
    # (12, 7) - (3, 7) = (9,7) --> stays economic
    self.base_economic_p_row_xxx(ndel=3)
def test_economic_p_row_sqr(self):
    """Remove 5 rows so the result has as many rows as columns."""
    # (12, 7) - (5, 7) = (7, 7) --> becomes square
    self.base_economic_p_row_xxx(ndel=5)
def test_economic_p_row_fat(self):
    """Remove 7 rows so the result has fewer rows than columns."""
    # (12, 7) - (7,7) = (5, 7) --> becomes fat
    self.base_economic_p_row_xxx(ndel=7)
def test_economic_1_col(self):
    """Delete single columns from the economic QR of the 'tall' matrix."""
    a, q, r = self.generate('tall', 'economic')
    for k in range(r.shape[1]):
        q_upd, r_upd = qr_delete(q, r, k, which='col', overwrite_qr=False)
        expected = np.delete(a, k, axis=1)
        # q is not required to be square here.
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol,
                 assert_sqr=False)
def test_economic_p_col(self):
    """Delete blocks of 2 to 5 columns from the economic QR of the 'tall' matrix."""
    a, q, r = self.generate('tall', 'economic')
    n_cols = r.shape[1]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_cols - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_cols - block):
            q_upd, r_upd = qr_delete(q, r, start, block, which='col',
                                     overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=1)
            # q is not required to be square here.
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol,
                     assert_sqr=False)
def test_Mx1_1_row(self):
    """Delete single rows of the single-column 'Mx1' matrix (full QR)."""
    a, q, r = self.generate('Mx1')
    for k in range(r.shape[0]):
        q_upd, r_upd = qr_delete(q, r, k, overwrite_qr=False)
        expected = np.delete(a, k, axis=0)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_Mx1_p_row(self):
    """Delete blocks of 2 to 5 rows from the single-column 'Mx1' matrix."""
    a, q, r = self.generate('Mx1')
    n_rows = a.shape[0]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_rows - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_rows - block):
            q_upd, r_upd = qr_delete(q, r, start, block, overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=0)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_1xN_1_col(self):
    """Delete single columns of the single-row '1xN' matrix."""
    a, q, r = self.generate('1xN')
    for k in range(r.shape[1]):
        q_upd, r_upd = qr_delete(q, r, k, which='col', overwrite_qr=False)
        expected = np.delete(a, k, axis=1)
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_1xN_p_col(self):
    """Delete blocks of 2 to 5 columns from the single-row '1xN' matrix."""
    a, q, r = self.generate('1xN')
    n_cols = r.shape[1]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_cols - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_cols - block):
            q_upd, r_upd = qr_delete(q, r, start, block, which='col',
                                     overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=1)
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol)
def test_Mx1_economic_1_row(self):
    """Delete single rows from the economic QR of the 'Mx1' matrix."""
    a, q, r = self.generate('Mx1', 'economic')
    # NOTE(review): the loop bound is r.shape[0], not a.shape[0]; for an
    # economic factorization of a single column this presumably covers only
    # the first row -- confirm intent.
    for k in range(r.shape[0]):
        q_upd, r_upd = qr_delete(q, r, k, overwrite_qr=False)
        expected = np.delete(a, k, axis=0)
        # q is not required to be square here.
        check_qr(q_upd, r_upd, expected, self.rtol, self.atol,
                 assert_sqr=False)
def test_Mx1_economic_p_row(self):
    """Delete blocks of 2 to 5 rows from the economic QR of the 'Mx1' matrix."""
    a, q, r = self.generate('Mx1', 'economic')
    n_rows = a.shape[0]
    for block in range(2, 6):
        # NOTE(review): the last valid start offset (n_rows - block) is not
        # visited by this range -- confirm intent.
        for start in range(n_rows - block):
            q_upd, r_upd = qr_delete(q, r, start, block, overwrite_qr=False)
            expected = np.delete(a, slice(start, start + block), axis=0)
            # q is not required to be square here.
            check_qr(q_upd, r_upd, expected, self.rtol, self.atol,
                     assert_sqr=False)
def test_delete_last_1_row(self):
    """Deleting the only row of the '1xN' matrix must leave empty factors."""
    # full and eco are the same for 1xN
    a, q, r = self.generate('1xN')
    q_upd, r_upd = qr_delete(q, r, 0, 1, 'row')
    # Expected results are zero-size arrays, so only the shapes matter.
    empty_q = np.empty((0, 0), dtype=q.dtype)
    empty_r = np.empty((0, r.shape[1]), dtype=r.dtype)
    assert_equal(q_upd, empty_q)
    assert_equal(r_upd, empty_r)
def test_delete_last_p_row(self):
    """Deleting every row of the 'tall' matrix must leave empty factors.

    The same expectation is checked for the full and the economic
    factorization, in that order.
    """
    for mode in ('full', 'economic'):
        a, q, r = self.generate('tall', mode)
        q_upd, r_upd = qr_delete(q, r, 0, a.shape[0], 'row')
        # Expected results are zero-size arrays, so only the shapes matter.
        assert_equal(q_upd, np.empty((0, 0), dtype=q.dtype))
        assert_equal(r_upd, np.empty((0, r.shape[1]), dtype=r.dtype))
def test_delete_last_1_col(self):
    """Deleting the only column of the 'Mx1' matrix, economic then full."""
    # Economic factors: q keeps its rows but loses all columns, r is empty.
    a, q, r = self.generate('Mx1', 'economic')
    q_upd, r_upd = qr_delete(q, r, 0, 1, 'col')
    assert_equal(q_upd, np.empty((q.shape[0], 0), dtype=q.dtype))
    assert_equal(r_upd, np.empty((0, 0), dtype=r.dtype))

    # Full factors: q must stay unitary with unchanged dtype and shape,
    # while r keeps its rows but has no columns left.
    a, q, r = self.generate('Mx1', 'full')
    q_upd, r_upd = qr_delete(q, r, 0, 1, 'col')
    assert_unitary(q_upd)
    assert_(q_upd.dtype == q.dtype)
    assert_(q_upd.shape == q.shape)
    assert_equal(r_upd, np.empty((r.shape[0], 0), dtype=r.dtype))
def test_delete_last_p_col(self):
    """Deleting every column of the 'tall' matrix, full then economic."""
    # Full factors: q must stay unitary with unchanged dtype and shape,
    # while r keeps its rows but has no columns left.
    a, q, r = self.generate('tall', 'full')
    n_cols = a.shape[1]
    q_upd, r_upd = qr_delete(q, r, 0, n_cols, 'col')
    assert_unitary(q_upd)
    assert_(q_upd.dtype == q.dtype)
    assert_(q_upd.shape == q.shape)
    assert_equal(r_upd, np.empty((r.shape[0], 0), dtype=r.dtype))

    # Economic factors: q keeps its rows but loses all columns, r is empty.
    a, q, r = self.generate('tall', 'economic')
    n_cols = a.shape[1]
    q_upd, r_upd = qr_delete(q, r, 0, n_cols, 'col')
    assert_equal(q_upd, np.empty((q.shape[0], 0), dtype=q.dtype))
    assert_equal(r_upd, np.empty((0, 0), dtype=r.dtype))
def test_delete_1x1_row_col(self):
a, q, r = self.generate('1x1')
q1, r1 = qr_delete(q, r, 0, 1, 'row')
assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
a, q, r = self.generate('1x1')
q1, r1 = qr_delete(q, r, 0, 1, 'col')
assert_unitary(q1)
Loading ...