# Repository URL to install this package:
# Version: 0.15.1
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
from os.path import join, dirname
import numpy as np
from numpy.testing import assert_array_almost_equal, assert_equal, TestCase
from scipy.fftpack.realtransforms import dct, idct, dst, idst
# Reference vectors produced with Matlab.  They live next to this module in
# 'test.npz' under the keys x0..x7 and y0..y7; the tests below use X[k] as
# the transform input and Y[k] as the expected output.
_MATLAB_REF_PATH = join(dirname(__file__), 'test.npz')
MDATA = np.load(_MATLAB_REF_PATH)
X = [MDATA['x%d' % k] for k in range(8)]
Y = [MDATA['y%d' % k] for k in range(8)]
# Reference results computed with FFTW, one archive per precision.
# Layout of each archive:
# * 'sizes' lists every transform size for which data is available;
# * for every type (1, 2, 3, 4) and every size, the entry
#   'dct_<type>_<size>' holds the output of the DCT applied to the input
#   np.linspace(0, size-1, size).  Entries named 'dst_<type>_<size>' are
#   looked up the same way by fftw_dst_ref.
_REF_DIR = dirname(__file__)
FFTWDATA_DOUBLE = np.load(join(_REF_DIR, 'fftw_double_ref.npz'))
FFTWDATA_SINGLE = np.load(join(_REF_DIR, 'fftw_single_ref.npz'))
# The list of sizes is read from the double-precision archive.
FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
def fftw_dct_ref(type, size, dt):
    """Return the FFTW reference input/output pair for a DCT.

    Parameters
    ----------
    type : int
        DCT type; used to build the lookup key ``'dct_<type>_<size>'``.
    size : int
        Transform length.
    dt : dtype
        ``np.double`` or ``np.float32``.  Selects the reference archive
        (FFTWDATA_DOUBLE or FFTWDATA_SINGLE) and is the dtype both returned
        arrays are cast to.

    Returns
    -------
    x : ndarray
        ``np.linspace(0, size-1, size)`` cast to `dt`.
    y : ndarray
        The stored reference output, cast to `dt`.

    Raises
    ------
    ValueError
        If `dt` is neither ``np.double`` nor ``np.float32``.
    """
    # Validate the dtype up front so a bad argument fails with a message
    # that names the offending value instead of a bare ValueError.
    if dt == np.double:
        data = FFTWDATA_DOUBLE
    elif dt == np.float32:
        data = FFTWDATA_SINGLE
    else:
        raise ValueError(
            "unsupported dtype %r: expected np.double or np.float32" % (dt,))
    x = np.linspace(0, size-1, size).astype(dt)
    y = (data['dct_%d_%d' % (type, size)]).astype(dt)
    return x, y
def fftw_dst_ref(type, size, dt):
    """Return the FFTW reference input/output pair for a DST.

    Parameters
    ----------
    type : int
        DST type; used to build the lookup key ``'dst_<type>_<size>'``.
    size : int
        Transform length.
    dt : dtype
        ``np.double`` or ``np.float32``.  Selects the reference archive
        (FFTWDATA_DOUBLE or FFTWDATA_SINGLE) and is the dtype both returned
        arrays are cast to.

    Returns
    -------
    x : ndarray
        ``np.linspace(0, size-1, size)`` cast to `dt`.
    y : ndarray
        The stored reference output, cast to `dt`.

    Raises
    ------
    ValueError
        If `dt` is neither ``np.double`` nor ``np.float32``.
    """
    # Validate the dtype up front so a bad argument fails with a message
    # that names the offending value instead of a bare ValueError.
    if dt == np.double:
        data = FFTWDATA_DOUBLE
    elif dt == np.float32:
        data = FFTWDATA_SINGLE
    else:
        raise ValueError(
            "unsupported dtype %r: expected np.double or np.float32" % (dt,))
    x = np.linspace(0, size-1, size).astype(dt)
    y = (data['dst_%d_%d' % (type, size)]).astype(dt)
    return x, y
class _TestDCTBase(TestCase):
    """Checks shared by all DCT test cases.

    Subclasses override ``setUp`` to provide ``rdt`` (dtype), ``dec``
    (number of decimals that must match) and ``type`` (DCT type).
    """

    def setUp(self):
        # Placeholders only; the concrete subclasses supply real values.
        self.type = None
        self.dec = 14
        self.rdt = None

    def test_definition(self):
        """Compare dct() with the FFTW reference output for every size."""
        for size in FFTWDATA_SIZES:
            signal, expected = fftw_dct_ref(self.type, size, self.rdt)
            actual = dct(signal, type=self.type)
            self.assertTrue(actual.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (actual.dtype, self.rdt))
            # NOTE: both sides are divided by np.max(actual) because the
            # tests fail otherwise.  Something like assert_array_approx_equal
            # should really be used here.  The difference is due to fftw
            # using a better algorithm w.r.t. error propagation than the
            # ones from fftpack.
            peak = np.max(actual)
            assert_array_almost_equal(actual / peak, expected / peak,
                    decimal=self.dec, err_msg="Size %d failed" % size)

    def test_axis(self):
        """A 2-D transform must equal the 1-D transform of each slice."""
        n_signals = 2
        for length in [7, 8, 9, 16, 32, 64]:
            batch = np.random.randn(n_signals, length)

            # Default axis: every row is transformed on its own.
            by_row = dct(batch, type=self.type)
            for row in range(n_signals):
                assert_array_almost_equal(
                        by_row[row], dct(batch[row], type=self.type),
                        decimal=self.dec)

            # axis=0 on the transposed data: every column on its own.
            batch = batch.T
            by_col = dct(batch, axis=0, type=self.type)
            for col in range(n_signals):
                assert_array_almost_equal(
                        by_col[:, col], dct(batch[:, col], type=self.type),
                        decimal=self.dec)
class _TestDCTIIBase(_TestDCTBase):
    """Adds a comparison against the Matlab reference data for DCT-II."""

    def test_definition_matlab(self):
        """Test correspondence with matlab (orthonormal mode)."""
        for raw_input, expected in zip(X, Y):
            signal = np.array(raw_input, dtype=self.rdt)
            actual = dct(signal, norm="ortho", type=2)
            self.assertTrue(actual.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (actual.dtype, self.rdt))
            assert_array_almost_equal(actual, expected, decimal=self.dec)
class _TestDCTIIIBase(_TestDCTBase):
    """Adds an orthonormal DCT-II / DCT-III round-trip check."""

    def test_definition_ortho(self):
        """Test orthonormal mode: type 3 applied after type 2 gives x back."""
        for raw_input in X:
            signal = np.array(raw_input, dtype=self.rdt)
            forward = dct(signal, norm='ortho', type=2)
            round_trip = dct(forward, norm="ortho", type=3)
            self.assertTrue(round_trip.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (round_trip.dtype, self.rdt))
            assert_array_almost_equal(round_trip, signal, decimal=self.dec)
class TestDCTIDouble(_TestDCTBase):
    """DCT type 1 in double precision, matched to 10 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.double
        self.dec = 10
class TestDCTIFloat(_TestDCTBase):
    """DCT type 1 in single precision, matched to 5 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.float32
        self.dec = 5
class TestDCTIIDouble(_TestDCTIIBase):
    """DCT type 2 in double precision, matched to 10 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.double
        self.dec = 10
class TestDCTIIFloat(_TestDCTIIBase):
    """DCT type 2 in single precision, matched to 5 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.float32
        self.dec = 5
class TestDCTIIIDouble(_TestDCTIIIBase):
    """DCT type 3 in double precision, matched to 14 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.double
        self.dec = 14
class TestDCTIIIFloat(_TestDCTIIIBase):
    """DCT type 3 in single precision, matched to 5 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.float32
        self.dec = 5
class _TestIDCTBase(TestCase):
    """Checks shared by all inverse-DCT test cases.

    Subclasses override ``setUp`` to provide ``rdt`` (dtype), ``dec``
    (number of decimals that must match) and ``type`` (DCT type).
    """

    def setUp(self):
        # Placeholders only; the concrete subclasses supply real values.
        self.type = None
        self.dec = 14
        self.rdt = None

    def test_definition(self):
        """idct() of the FFTW DCT output must match the reference input."""
        for size in FFTWDATA_SIZES:
            expected, spectrum = fftw_dct_ref(self.type, size, self.rdt)
            recovered = idct(spectrum, type=self.type)
            # Rescale in place: by 2*(size-1) for type 1, by 2*size otherwise.
            divisor = 2 * (size-1) if self.type == 1 else 2 * size
            recovered /= divisor
            self.assertTrue(recovered.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (recovered.dtype, self.rdt))
            # NOTE: both sides are divided by np.max(recovered) because the
            # tests fail otherwise.  Something like assert_array_approx_equal
            # should really be used here.  The difference is due to fftw
            # using a better algorithm w.r.t. error propagation than the
            # ones from fftpack.
            peak = np.max(recovered)
            assert_array_almost_equal(recovered / peak, expected / peak,
                    decimal=self.dec, err_msg="Size %d failed" % size)
class TestIDCTIDouble(_TestIDCTBase):
    """Inverse DCT type 1 in double precision, matched to 10 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.double
        self.dec = 10
class TestIDCTIFloat(_TestIDCTBase):
    """Inverse DCT type 1 in single precision, matched to 4 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.float32
        self.dec = 4
class TestIDCTIIDouble(_TestIDCTBase):
    """Inverse DCT type 2 in double precision, matched to 10 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.double
        self.dec = 10
class TestIDCTIIFloat(_TestIDCTBase):
    """Inverse DCT type 2 in single precision, matched to 5 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.float32
        self.dec = 5
class TestIDCTIIIDouble(_TestIDCTBase):
    """Inverse DCT type 3 in double precision, matched to 14 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.double
        self.dec = 14
class TestIDCTIIIFloat(_TestIDCTBase):
    """Inverse DCT type 3 in single precision, matched to 5 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.float32
        self.dec = 5
class _TestDSTBase(TestCase):
    """Checks shared by all DST test cases.

    Subclasses override ``setUp`` to provide ``rdt``, ``dec`` and ``type``.
    """

    def setUp(self):
        # Placeholders only; the concrete subclasses supply real values.
        self.type = None  # dst type
        self.dec = None  # number of decimals to match
        self.rdt = None  # dtype

    def test_definition(self):
        """Compare dst() with the FFTW reference output for every size."""
        for size in FFTWDATA_SIZES:
            signal, expected = fftw_dst_ref(self.type, size, self.rdt)
            actual = dst(signal, type=self.type)
            self.assertTrue(actual.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (actual.dtype, self.rdt))
            # NOTE: both sides are divided by np.max(actual) because the
            # tests fail otherwise.  Something like assert_array_approx_equal
            # should really be used here.  The difference is due to fftw
            # using a better algorithm w.r.t. error propagation than the
            # ones from fftpack.
            peak = np.max(actual)
            assert_array_almost_equal(actual / peak, expected / peak,
                    decimal=self.dec, err_msg="Size %d failed" % size)
class TestDSTIDouble(_TestDSTBase):
    """DST type 1 in double precision, matched to 14 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.double
        self.dec = 14
class TestDSTIFloat(_TestDSTBase):
    """DST type 1 in single precision, matched to 5 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.float32
        self.dec = 5
class TestDSTIIDouble(_TestDSTBase):
    """DST type 2 in double precision, matched to 14 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.double
        self.dec = 14
class TestDSTIIFloat(_TestDSTBase):
    """DST type 2 in single precision, matched to 6 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.float32
        self.dec = 6
class TestDSTIIIDouble(_TestDSTBase):
    """DST type 3 in double precision, matched to 14 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.double
        self.dec = 14
class TestDSTIIIFloat(_TestDSTBase):
    """DST type 3 in single precision, matched to 7 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.float32
        self.dec = 7
class _TestIDSTBase(TestCase):
    """Checks shared by all inverse-DST test cases.

    Subclasses override ``setUp`` to provide ``rdt``, ``dec`` and ``type``.
    """

    def setUp(self):
        # Placeholders only; the concrete subclasses supply real values.
        self.type = None
        self.dec = None
        self.rdt = None

    def test_definition(self):
        """idst() of the FFTW DST output must match the reference input."""
        for size in FFTWDATA_SIZES:
            expected, spectrum = fftw_dst_ref(self.type, size, self.rdt)
            recovered = idst(spectrum, type=self.type)
            # Rescale in place: by 2*(size+1) for type 1, by 2*size otherwise.
            divisor = 2 * (size+1) if self.type == 1 else 2 * size
            recovered /= divisor
            self.assertTrue(recovered.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (recovered.dtype, self.rdt))
            # NOTE: both sides are divided by np.max(recovered) because the
            # tests fail otherwise.  Something like assert_array_approx_equal
            # should really be used here.  The difference is due to fftw
            # using a better algorithm w.r.t. error propagation than the
            # ones from fftpack.
            peak = np.max(recovered)
            assert_array_almost_equal(recovered / peak, expected / peak,
                    decimal=self.dec, err_msg="Size %d failed" % size)
class TestIDSTIDouble(_TestIDSTBase):
    """Inverse DST type 1 in double precision, matched to 12 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.double
        self.dec = 12
class TestIDSTIFloat(_TestIDSTBase):
    """Inverse DST type 1 in single precision, matched to 4 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 1
        self.rdt = np.float32
        self.dec = 4
class TestIDSTIIDouble(_TestIDSTBase):
    """Inverse DST type 2 in double precision, matched to 14 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.double
        self.dec = 14
class TestIDSTIIFloat(_TestIDSTBase):
    """Inverse DST type 2 in single precision, matched to 6 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 2
        self.rdt = np.float32
        self.dec = 6
class TestIDSTIIIDouble(_TestIDSTBase):
    """Inverse DST type 3 in double precision, matched to 14 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.double
        self.dec = 14
class TestIDSTIIIFloat(_TestIDSTBase):
    """Inverse DST type 3 in single precision, matched to 6 decimals."""

    def setUp(self):
        # transform type, working dtype, decimals that must agree
        self.type = 3
        self.rdt = np.float32
        self.dec = 6
class TestOverwrite(object):
    """Check input overwrite behavior of dct, idct, dst and idst.

    Only the "no spurious overwrite" direction is asserted: whenever the
    routine has not been given permission to clobber its input, the input
    must come back unchanged.  ``overwrite_x=True`` merely *allows* the
    routine to destroy the input, so nothing is asserted in that case; in
    particular the result no longer depends on which axis is transformed,
    since that is an implementation detail of the backend.
    """

    real_dtypes = [np.float32, np.float64]

    def _check(self, x, routine, type, fftsize, axis, norm, overwrite_x,
               should_overwrite, **kw):
        """Run `routine` on a copy of `x` and verify the copy is untouched.

        The comparison is skipped when `should_overwrite` is true.  Extra
        keyword arguments are accepted and ignored.
        """
        x2 = x.copy()
        routine(x2, type, fftsize, axis, norm, overwrite_x=overwrite_x)

        if not should_overwrite:
            sig = "%s(%s%r, %r, axis=%r, overwrite_x=%r)" % (
                routine.__name__, x.dtype, x.shape, fftsize, axis, overwrite_x)
            assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)

    def _check_1d(self, routine, dtype, shape, axis, overwritable_dtypes):
        """Exercise `routine` along `axis` for types 1-3 and both norms."""
        # A private generator yields the same data as seeding the global
        # one with 1234, without disturbing other tests' random state.
        rng = np.random.RandomState(1234)
        if np.issubdtype(dtype, np.complexfloating):
            data = rng.randn(*shape) + 1j*rng.randn(*shape)
        else:
            data = rng.randn(*shape)
        data = data.astype(dtype)

        for transform_type in [1, 2, 3]:
            for overwrite_x in [True, False]:
                for norm in [None, 'ortho']:
                    if transform_type == 1 and norm == 'ortho':
                        continue

                    # The input may only be clobbered when the caller asked
                    # for it and the dtype can be transformed in place.
                    should_overwrite = (overwrite_x
                                        and dtype in overwritable_dtypes)
                    self._check(data, routine, transform_type, None, axis,
                                norm, overwrite_x, should_overwrite)

    def _check_routine(self, routine):
        """Run the overwrite checks for `routine` on 1-D and 2-D inputs."""
        overwritable = self.real_dtypes
        for dtype in self.real_dtypes:
            self._check_1d(routine, dtype, (16,), -1, overwritable)
            self._check_1d(routine, dtype, (16, 2), 0, overwritable)
            self._check_1d(routine, dtype, (2, 16), 1, overwritable)

    def test_dct(self):
        self._check_routine(dct)

    def test_idct(self):
        self._check_routine(idct)

    def test_dst(self):
        self._check_routine(dst)

    def test_idst(self):
        self._check_routine(idst)
if __name__ == "__main__":
    # Run directly as a script: hand this module to numpy's test runner.
    np.testing.run_module_suite()