Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aaronreidsmith / scipy   python

Repository URL to install this package:

Version: 1.3.3 

/ sparse / tests / test_base.py

#
# Authors: Travis Oliphant, Ed Schofield, Robert Cimrman, Nathan Bell, and others

""" Test functions for sparse matrices. Each class in the "Matrix class
based tests" section becomes a subclass of the classes in the "Generic
tests" section. This is done by the functions in the "Tailored base
class for generic tests" section.

"""

from __future__ import division, print_function, absolute_import

__usage__ = """
Build sparse:
  python setup.py build
Run tests if scipy is installed:
  python -c 'import scipy;scipy.sparse.test()'
Run tests if sparse is not installed:
  python tests/test_base.py
"""

import operator
import contextlib
import functools
from distutils.version import LooseVersion

import numpy as np
from scipy._lib.six import xrange, zip as izip
from numpy import (arange, zeros, array, dot, asarray,
                   vstack, ndarray, transpose, diag, kron, inf, conjugate,
                   int8, ComplexWarning)

import random
from numpy.testing import (assert_equal, assert_array_equal,
        assert_array_almost_equal, assert_almost_equal, assert_,
        assert_allclose)
from pytest import raises as assert_raises
from scipy._lib._numpy_compat import suppress_warnings

import scipy.linalg

import scipy.sparse as sparse
from scipy.sparse import (csc_matrix, csr_matrix, dok_matrix,
        coo_matrix, lil_matrix, dia_matrix, bsr_matrix,
        eye, isspmatrix, SparseEfficiencyWarning, issparse)
from scipy.sparse.sputils import (supported_dtypes, isscalarlike,
                                  get_index_dtype, asmatrix, matrix)
from scipy.sparse.linalg import splu, expm, inv

from scipy._lib._version import NumpyVersion
from scipy._lib.decorator import decorator

import pytest


def assert_in(member, collection, msg=None):
    """Assert that `member` is contained in `collection`.

    If `msg` is None, a default message showing both operands is used.
    """
    if msg is None:
        msg = "%r not found in %r" % (member, collection)
    assert_(member in collection, msg=msg)


def assert_array_equal_dtype(x, y, **kwargs):
    """Assert that `x` and `y` have the same dtype and equal contents.

    Extra keyword arguments are forwarded to
    ``numpy.testing.assert_array_equal``.

    Raises
    ------
    AssertionError
        If the dtypes differ (the message names both) or the contents differ.
    """
    # Name both dtypes so that a failure inside a per-dtype loop says which
    # dtypes disagreed instead of raising an empty AssertionError.
    assert_(x.dtype == y.dtype,
            "dtype mismatch: %r != %r" % (x.dtype, y.dtype))
    assert_array_equal(x, y, **kwargs)


# Only test matmul operator (A @ B) when available (Python 3.5+)
TEST_MATMUL = hasattr(operator, 'matmul')

sup_complex = suppress_warnings()
sup_complex.filter(ComplexWarning)


def with_64bit_maxval_limit(maxval_limit=None, random=False, fixed_dtype=None,
                            downcast_maxval=None, assert_32bit=False):
    """
    Monkeypatch the maxval threshold at which scipy.sparse switches to
    64-bit index arrays, or make it (pseudo-)random.

    Returns a decorator. While a decorated function runs,
    ``get_index_dtype`` is replaced in a fixed list of scipy.sparse modules.
    When `downcast_maxval` is given, ``downcast_intp_index`` is replaced too.
    On exit every patched attribute is restored to its previous value. If a
    module did not have the attribute before patching, the stub is removed
    again so it cannot leak into later tests.

    The replacement ``get_index_dtype`` follows the first applicable rule,
    in this order: `assert_32bit`, `fixed_dtype`, `random`, and otherwise
    the `maxval_limit` threshold.

    Parameters
    ----------
    maxval_limit : int, optional
        Threshold for the default rule. It selects ``np.int64`` when
        ``maxval`` exceeds it. It also selects ``np.int64`` when an input
        array has a dtype wider than ``np.int32``, unless ``check_contents``
        is set and the array is empty or its integer values lie within
        ``[-maxval_limit, maxval_limit]``. Defaults to 10.
    random : bool, optional
        If true, return ``np.int32`` or ``np.int64`` chosen by a
        ``RandomState`` seeded with 1234 (created once per call to this
        function).
    fixed_dtype : dtype, optional
        If given, always return this dtype.
    downcast_maxval : int, optional
        If given, also replace ``downcast_intp_index`` with a stub. The stub
        raises ``AssertionError("downcast limited")`` when a non-empty
        array's maximum exceeds this value; otherwise it returns the array
        cast to ``np.intp``.
    assert_32bit : bool, optional
        If true, delegate to the real ``get_index_dtype`` and assert that
        the result is a 32-bit integer type.

    """
    if maxval_limit is None:
        maxval_limit = 10

    if assert_32bit:
        def new_get_index_dtype(arrays=(), maxval=None, check_contents=False):
            tp = get_index_dtype(arrays, maxval, check_contents)
            assert_equal(np.iinfo(tp).max, np.iinfo(np.int32).max)
            assert_(tp == np.int32 or tp == np.intc)
            return tp
    elif fixed_dtype is not None:
        def new_get_index_dtype(arrays=(), maxval=None, check_contents=False):
            return fixed_dtype
    elif random:
        counter = np.random.RandomState(seed=1234)

        def new_get_index_dtype(arrays=(), maxval=None, check_contents=False):
            return (np.int32, np.int64)[counter.randint(2)]
    else:
        def new_get_index_dtype(arrays=(), maxval=None, check_contents=False):
            dtype = np.int32
            if maxval is not None and maxval > maxval_limit:
                dtype = np.int64
            for arr in arrays:
                arr = np.asarray(arr)
                if arr.dtype > np.int32:
                    if check_contents:
                        if arr.size == 0:
                            # a bigger type not needed
                            continue
                        elif np.issubdtype(arr.dtype, np.integer):
                            maxval = arr.max()
                            minval = arr.min()
                            if minval >= -maxval_limit and maxval <= maxval_limit:
                                # a bigger type not needed
                                continue
                    dtype = np.int64
            return dtype

    if downcast_maxval is not None:
        def new_downcast_intp_index(arr):
            # arr.max() raises ValueError on a zero-size array, so only check
            # the limit when there is something to check.
            if arr.size and arr.max() > downcast_maxval:
                raise AssertionError("downcast limited")
            return arr.astype(np.intp)

    # Sentinel meaning "the module had no such attribute before patching".
    # It is distinct from any real value, including None.
    missing = object()

    @decorator
    def deco(func, *a, **kw):
        backup = []
        modules = [scipy.sparse.bsr, scipy.sparse.coo, scipy.sparse.csc,
                   scipy.sparse.csr, scipy.sparse.dia, scipy.sparse.dok,
                   scipy.sparse.lil, scipy.sparse.sputils,
                   scipy.sparse.compressed, scipy.sparse.construct]
        try:
            for mod in modules:
                backup.append((mod, 'get_index_dtype',
                               getattr(mod, 'get_index_dtype', missing)))
                setattr(mod, 'get_index_dtype', new_get_index_dtype)
                if downcast_maxval is not None:
                    backup.append((mod, 'downcast_intp_index',
                                   getattr(mod, 'downcast_intp_index', missing)))
                    setattr(mod, 'downcast_intp_index', new_downcast_intp_index)
            return func(*a, **kw)
        finally:
            for mod, name, oldfunc in backup:
                if oldfunc is not missing:
                    setattr(mod, name, oldfunc)
                elif name in vars(mod):
                    # The attribute did not exist before patching: remove the
                    # stub rather than leaving it installed for later tests.
                    delattr(mod, name)

    return deco


def todense(a):
    """Return `a` in dense form.

    ndarray instances and anything ``isscalarlike`` accepts are returned
    unchanged (the same object). Anything else is converted by calling its
    ``todense()`` method.
    """
    passthrough = isinstance(a, np.ndarray) or isscalarlike(a)
    if not passthrough:
        return a.todense()
    return a


class BinopTester(object):
    # Custom type to test binary operations on sparse matrices.
    #
    # Every forward operator (+, -, *, @) answers "matrix on the right" and
    # every reflected operator answers "matrix on the left". A test can
    # therefore tell which operand's method Python dispatched to.

    def _matrix_on_right(self, mat):
        return "matrix on the right"

    def _matrix_on_left(self, mat):
        return "matrix on the left"

    __add__ = __sub__ = __mul__ = __matmul__ = _matrix_on_right
    __radd__ = __rsub__ = __rmul__ = __rmatmul__ = _matrix_on_left

class BinopTester_with_shape(object):
    # Custom type to test binary operations on sparse matrices
    # with object which has shape attribute.
    #
    # Forward operators (+, -, *, @) answer "matrix on the right" and the
    # reflected ones answer "matrix on the left".
    #
    # NOTE(review): ``shape`` and ``ndim`` are ordinary methods, not
    # properties, so ``obj.shape`` evaluates to a bound method rather than a
    # tuple. Confirm this is what the tests that use this class intend.

    def __init__(self,shape):
        self._shape = shape

    def shape(self):
        return self._shape

    def ndim(self):
        dims = self._shape
        return len(dims)

    def _matrix_on_right(self, mat):
        return "matrix on the right"

    def _matrix_on_left(self, mat):
        return "matrix on the left"

    __add__ = __sub__ = __mul__ = __matmul__ = _matrix_on_right
    __radd__ = __rsub__ = __rmul__ = __rmatmul__ = _matrix_on_left


#------------------------------------------------------------------------------
# Generic tests
#------------------------------------------------------------------------------


# TODO check that spmatrix( ... , copy=X ) is respected
# TODO test prune
# TODO test has_sorted_indices
class _TestCommon(object):
    """test common functionality shared by all sparse formats"""
    math_dtypes = supported_dtypes

    @classmethod
    def init_class(cls):
        """Build the dense/sparse fixtures shared by the tests on `cls`.

        Sets ``dat`` (a 3x4 float matrix) and ``datsp`` (the same data in
        ``cls.spmatrix`` format). Sets ``checked_dtypes`` and, keyed by each
        of those dtypes, ``dat_dtypes`` / ``datsp_dtypes`` holding ``dat``
        cast to that dtype in dense and sparse form. Finally checks that the
        float64 entries match the canonical data.
        """
        canonical = matrix([[1, 0, 0, 2], [3, 0, 1, 0], [0, 2, 0, 0]], 'd')
        cls.dat = canonical
        cls.datsp = cls.spmatrix(canonical)

        # The set union works around numpy#6295, where two np.int64 dtypes
        # do not hash to the same value.
        cls.checked_dtypes = set(supported_dtypes).union(cls.math_dtypes)
        cls.dat_dtypes = {dt: canonical.astype(dt)
                          for dt in cls.checked_dtypes}
        cls.datsp_dtypes = {dt: cls.spmatrix(canonical.astype(dt))
                            for dt in cls.checked_dtypes}

        # The float64 variants must reproduce the canonical data exactly.
        assert_equal(cls.dat, cls.dat_dtypes[np.float64])
        assert_equal(cls.datsp.todense(),
                     cls.datsp_dtypes[np.float64].todense())

    def test_bool(self):
        # For every checked dtype, bool() of the multi-element fixture must
        # raise ValueError. A one-element matrix holding 1 must be truthy
        # and one holding 0 must be falsy.
        if isinstance(self, TestDOK):
            # NOTE(review): the one-element matrices below are built from a
            # rank-1 list; presumably that is what DOK cannot construct.
            pytest.skip("Cannot create a rank <= 2 DOK matrix.")

        for dtype in self.checked_dtypes:
            with assert_raises(ValueError):
                bool(self.datsp_dtypes[dtype])
            # These two checks do not depend on dtype; they are repeated
            # once per dtype.
            assert_(self.spmatrix([1]))
            assert_(not self.spmatrix([0]))

    def test_bool_rollover(self):
        # bool's underlying dtype is 1 byte, check that it does not
        # rollover True -> False at 256.
        dat = matrix([[True, False]])
        datsp = self.spmatrix(dat)

        for _ in range(10):
            datsp = datsp + datsp
            dat = dat + dat
        assert_array_equal(dat, datsp.todense())

    def test_eq(self):
        # Check elementwise == for every checked dtype: sparse vs sparse
        # (this format and the BSR/CSR/CSC/LIL formats), sparse vs dense,
        # and sparse vs scalar. SparseEfficiencyWarning and ComplexWarning
        # are suppressed inside each check.
        sup = suppress_warnings()
        sup.filter(SparseEfficiencyWarning)

        @sup
        @sup_complex
        def check(dtype):
            dense = self.dat_dtypes[dtype]
            spm = self.datsp_dtypes[dtype]
            dense_zeroed = dense.copy()
            dense_zeroed[:, 0] = 0
            spm_zeroed = self.spmatrix(dense_zeroed)
            other_formats = [fmt(dense) for fmt in
                             (bsr_matrix, csr_matrix, csc_matrix, lil_matrix)]
            expected = dense == dense_zeroed

            # sparse/sparse: this class's own format first, then the mixed
            # formats compared against it
            for lhs in [spm] + other_formats:
                assert_array_equal_dtype(expected,
                                         (lhs == spm_zeroed).todense())
            # sparse/dense
            assert_array_equal_dtype(dense == spm_zeroed, spm_zeroed == dense)
            # sparse/scalar, including NaN (which never compares equal)
            for scalar in (0, 1, np.nan):
                assert_array_equal_dtype(dense == scalar,
                                         (spm == scalar).todense())

        if isinstance(self, (TestBSR, TestCSC, TestCSR)):
            for dtype in self.checked_dtypes:
                check(dtype)
        else:
            pytest.skip("Bool comparisons only implemented for BSR, CSC, and CSR.")

    def test_ne(self):
        sup = suppress_warnings()
        sup.filter(SparseEfficiencyWarning)

        @sup
        @sup_complex
        def check(dtype):
            dat = self.dat_dtypes[dtype]
            datsp = self.datsp_dtypes[dtype]
            dat2 = dat.copy()
            dat2[:,0] = 0
            datsp2 = self.spmatrix(dat2)
            datbsr = bsr_matrix(dat)
            datcsc = csc_matrix(dat)
            datcsr = csr_matrix(dat)
            datlil = lil_matrix(dat)

            # sparse/sparse
            assert_array_equal_dtype(dat != dat2, (datsp != datsp2).todense())
            # mix sparse types
            assert_array_equal_dtype(dat != dat2, (datbsr != datsp2).todense())
            assert_array_equal_dtype(dat != dat2, (datcsc != datsp2).todense())
            assert_array_equal_dtype(dat != dat2, (datcsr != datsp2).todense())
            assert_array_equal_dtype(dat != dat2, (datlil != datsp2).todense())
            # sparse/dense
            assert_array_equal_dtype(dat != datsp2, datsp2 != dat)
            # sparse/scalar
            assert_array_equal_dtype(dat != 0, (datsp != 0).todense())
            assert_array_equal_dtype(dat != 1, (datsp != 1).todense())
            assert_array_equal_dtype(0 != dat, (0 != datsp).todense())
            assert_array_equal_dtype(1 != dat, (1 != datsp).todense())
Loading ...