Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
scikit-learn / decomposition / tests / test_truncated_svd.py
Size: Mime:
"""Test truncated SVD transformer."""

import numpy as np
import scipy.sparse as sp

from sklearn.decomposition import TruncatedSVD
from sklearn.utils import check_random_state
from sklearn.utils.testing import (assert_array_almost_equal, assert_equal,
                                   assert_raises, assert_almost_equal,
                                   assert_greater, assert_array_less)


# Make an X that looks somewhat like a small tf-idf matrix.
# XXX newer versions of SciPy have scipy.sparse.rand for this.
shape = 60, 55
n_samples, n_features = shape
rng = check_random_state(42)
# np.prod is the supported spelling; the np.product alias is deprecated and
# no longer exists in NumPy 2.0.
X = rng.randint(-100, 20, np.prod(shape)).reshape(shape)
# Negative draws are clipped to zero, so most entries are dropped when the
# dense array is converted to CSR.
X = sp.csr_matrix(np.maximum(X, 0), dtype=np.float64)
# Only the stored (non-zero) entries are rescaled to 1 + log(value).
X.data[:] = 1 + np.log(X.data)
# toarray() is the explicit equivalent of the deprecated ``.A`` shorthand.
Xdense = X.toarray()


def test_algorithms():
    """Check that the ARPACK and randomized solvers give matching results."""
    arpack = TruncatedSVD(30, algorithm="arpack")
    randomized = TruncatedSVD(30, algorithm="randomized", random_state=42)

    # Only the six leading columns of each projection are compared.
    leading = slice(None, 6)
    proj_arpack = arpack.fit_transform(X)[:, leading]
    proj_randomized = randomized.fit_transform(X)[:, leading]
    assert_array_almost_equal(proj_arpack, proj_randomized)

    # Components are compared by magnitude (presumably because the sign of a
    # component may differ between solvers -- confirm).
    abs_arpack = np.abs(arpack.components_)
    abs_randomized = np.abs(randomized.components_)

    # The first nine components are held to the default precision; the
    # remaining ones only have to agree to three decimals.
    head, tail = slice(None, 9), slice(9, None)
    assert_array_almost_equal(abs_arpack[head], abs_randomized[head])
    assert_array_almost_equal(abs_arpack[tail], abs_randomized[tail],
                              decimal=3)


def test_attributes():
    """Check ``n_components`` and the shape of ``components_`` after fit."""
    for requested in (10, 25, 41):
        fitted = TruncatedSVD(requested).fit(X)
        expected_shape = (requested, n_features)
        assert_equal(fitted.n_components, requested)
        assert_equal(fitted.components_.shape, expected_shape)


def test_too_many_components():
    """Fitting with ``n_components >= n_features`` must raise ValueError."""
    # Every solver is paired with both the boundary value and one past it.
    cases = [(solver, k)
             for solver in ("arpack", "randomized")
             for k in (n_features, n_features + 1)]
    for solver, k in cases:
        estimator = TruncatedSVD(n_components=k, algorithm=solver)
        assert_raises(ValueError, estimator.fit, X)


def test_sparse_formats():
    """Check fit_transform/transform on a dense array and sparse formats.

    Each name in the loop is turned into a ``X.to<name>()`` conversion, so
    ``"array"`` yields the dense ndarray and the other names yield the
    corresponding sparse format. For every input, both ``fit_transform`` and
    ``transform`` must return an array of shape ``(n_samples, 11)``.
    """
    n_components = 11
    expected_shape = (n_samples, n_components)

    # "array" -> X.toarray() already covers the dense case, so no separate
    # "dense" special case is needed (the old one could never be reached).
    for fmt in ("array", "csr", "csc", "coo", "lil"):
        Xfmt = getattr(X, "to" + fmt)()
        tsvd = TruncatedSVD(n_components=n_components)
        Xtrans = tsvd.fit_transform(Xfmt)
        assert_equal(Xtrans.shape, expected_shape)
        Xtrans = tsvd.transform(Xfmt)
        assert_equal(Xtrans.shape, expected_shape)


def test_inverse_transform():
    """Check that inverse_transform approximately reconstructs X.

    Runs once per solver. The reconstruction is compared against the dense
    version of ``X`` to one decimal place.
    """
    X_expected = X.toarray()
    for algo in ("arpack", "randomized"):
        # We need a lot of components for the reconstruction to be "almost
        # equal" in all positions. XXX Test means or sums instead?
        # The solver under test must be passed explicitly; otherwise every
        # iteration would silently use the estimator's default algorithm.
        tsvd = TruncatedSVD(n_components=52, random_state=42, algorithm=algo)
        Xt = tsvd.fit_transform(X)
        Xinv = tsvd.inverse_transform(Xt)
        assert_array_almost_equal(Xinv, X_expected, decimal=1)


def test_integers():
    """Check that an int64 version of X can be fitted and transformed."""
    estimator = TruncatedSVD(n_components=6)
    projected = estimator.fit_transform(X.astype(np.int64))
    expected_shape = (n_samples, estimator.n_components)
    assert_equal(projected.shape, expected_shape)


def test_explained_variance():
    """Check ``explained_variance_ratio_`` across solvers, sizes and inputs.

    Eight estimators are fitted: {arpack, randomized} x {10, 20} components
    x {sparse, dense} input. The following properties are then asserted:

    * the first 10 ratios of a 20-component fit match the 10-component fit;
    * 20 components explain more variance than 10;
    * every ratio is positive and the ratios sum to less than 1;
    * sparse and dense inputs give the same ratios;
    * the ratios equal the per-column variance of the transformed data
      divided by the total variance of X.
    """
    input_kinds = ("sparse", "dense")
    algorithms = ("arpack", "randomized")
    sizes = (10, 20)

    # Densify once and reuse, instead of calling X.toarray() for every fit.
    X_dense = X.toarray()
    inputs = {"sparse": X, "dense": X_dense}

    # Fit every combination. Results are keyed by
    # (input kind, algorithm, n_components) and hold (estimator, transformed).
    fitted = {}
    for kind in input_kinds:
        for n_components in sizes:
            for algorithm in algorithms:
                params = {"algorithm": algorithm}
                # Only the randomized solver is given a fixed seed.
                if algorithm == "randomized":
                    params["random_state"] = 42
                svd = TruncatedSVD(n_components, **params)
                transformed = svd.fit_transform(inputs[kind])
                fitted[kind, algorithm, n_components] = (svd, transformed)

    for kind in input_kinds:
        for algorithm in algorithms:
            svd_10 = fitted[kind, algorithm, 10][0]
            svd_20 = fitted[kind, algorithm, 20][0]

            # Assert the first 10 ratios agree between the two fits
            assert_array_almost_equal(
                svd_10.explained_variance_ratio_,
                svd_20.explained_variance_ratio_[:10],
                decimal=5,
            )

            # Assert that 20 components has higher explained variance than 10
            assert_greater(
                svd_20.explained_variance_ratio_.sum(),
                svd_10.explained_variance_ratio_.sum(),
            )

    for svd, _ in fitted.values():
        # Assert that all the values are greater than 0
        assert_array_less(0.0, svd.explained_variance_ratio_)
        # Assert that total explained variance is less than 1
        assert_array_less(svd.explained_variance_ratio_.sum(), 1.0)

    # Compare sparse vs. dense
    for algorithm in algorithms:
        for n_components in sizes:
            svd_sparse = fitted["sparse", algorithm, n_components][0]
            svd_dense = fitted["dense", algorithm, n_components][0]
            assert_array_almost_equal(svd_sparse.explained_variance_ratio_,
                                      svd_dense.explained_variance_ratio_)

    # Test that explained_variance is correct. The total variance of X does
    # not depend on the estimator, so it is computed once outside the loop.
    total_variance = np.var(X_dense, axis=0).sum()
    for svd, transformed in fitted.values():
        variances = np.var(transformed, axis=0)
        true_explained_variance_ratio = variances / total_variance

        assert_array_almost_equal(
            svd.explained_variance_ratio_,
            true_explained_variance_ratio,
        )