Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aaronreidsmith / scikit-learn   python

Repository URL to install this package:

Version: 0.22 

/ utils / tests / test_testing.py

import warnings
import unittest
import sys
import os
import atexit

import numpy as np

from scipy import sparse

import pytest

from sklearn.utils.deprecation import deprecated
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils._testing import (
    assert_raises,
    assert_less,
    assert_greater,
    assert_less_equal,
    assert_greater_equal,
    assert_warns,
    assert_no_warnings,
    assert_equal,
    assert_not_equal,
    assert_in,
    assert_not_in,
    set_random_state,
    assert_raise_message,
    ignore_warnings,
    check_docstring_parameters,
    assert_allclose_dense_sparse,
    assert_raises_regex,
    TempMemmap,
    create_memmap_backed_data,
    _delete_folder,
    _convert_container)

from sklearn.utils._testing import SkipTest
from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


@pytest.mark.filterwarnings("ignore",
                            category=FutureWarning)  # 0.24
def test_assert_less():
    # Exercise the deprecated helper on both outcomes; a bare
    # ``assert 0 < 1`` would not test assert_less at all.
    assert_less(0, 1)
    assert_raises(AssertionError, assert_less, 1, 0)


@pytest.mark.filterwarnings("ignore",
                            category=FutureWarning)  # 0.24
def test_assert_greater():
    # Exercise the deprecated helper on both outcomes; a bare
    # ``assert 1 > 0`` would not test assert_greater at all.
    assert_greater(1, 0)
    assert_raises(AssertionError, assert_greater, 0, 1)


@pytest.mark.filterwarnings("ignore",
                            category=FutureWarning)  # 0.24
def test_assert_less_equal():
    # Call the deprecated helper itself with strictly smaller and equal
    # operands (bare ``assert`` statements would not test it), then check
    # that it rejects a larger left operand.
    assert_less_equal(0, 1)
    assert_less_equal(1, 1)
    assert_raises(AssertionError, assert_less_equal, 1, 0)


@pytest.mark.filterwarnings("ignore",
                            category=FutureWarning)  # 0.24
def test_assert_greater_equal():
    # Call the deprecated helper itself with strictly larger and equal
    # operands (bare ``assert`` statements would not test it), then check
    # that it rejects a smaller left operand.
    assert_greater_equal(1, 0)
    assert_greater_equal(1, 1)
    assert_raises(AssertionError, assert_greater_equal, 0, 1)


def test_set_random_state():
    # set_random_state must work on any estimator. For
    # LinearDiscriminantAnalysis (no random state) this is only a smoke
    # test; for the tree the value must actually be stored.
    lda, tree = LinearDiscriminantAnalysis(), DecisionTreeClassifier()
    for estimator in (lda, tree):
        set_random_state(estimator, 3)
    assert tree.random_state == 3


def test_assert_allclose_dense_sparse():
    dense = np.arange(9).reshape(3, 3)
    sparse_csc = sparse.csc_matrix(dense)

    # Dense and sparse inputs alike: a matrix equals itself but not a
    # scaled copy of itself.
    for arr in (dense, sparse_csc):
        with pytest.raises(AssertionError, match="Not equal to tolerance "):
            assert_allclose_dense_sparse(arr, arr * 2)
        assert_allclose_dense_sparse(arr, arr)

    # Mixing a dense operand with a sparse one is rejected.
    with pytest.raises(ValueError, match="Can only compare two sparse"):
        assert_allclose_dense_sparse(dense, sparse_csc)

    # A 1x5 and a 5x5 sparse matrix are reported as not equal.
    identity = sparse.diags(np.ones(5), offsets=0).tocsr()
    row_of_ones = sparse.csr_matrix(np.ones((1, 5)))
    with pytest.raises(AssertionError, match="Arrays are not equal"):
        assert_allclose_dense_sparse(row_of_ones, identity)


def test_assert_raises_msg():
    # The body never raises ValueError, so assert_raises must fail with an
    # AssertionError that carries the custom ``msg`` text.
    with assert_raises_regex(AssertionError, 'Hello world'), \
            assert_raises(ValueError, msg='Hello world'):
        pass


def test_assert_raise_message():
    def _raise_value_error(text):
        raise ValueError(text)

    def _do_nothing():
        pass

    # Expected type and expected message: passes silently.
    assert_raise_message(ValueError, "test", _raise_value_error, "test")

    # Expected type but the message lacks the expected text.
    with pytest.raises(AssertionError):
        assert_raise_message(ValueError, "something else",
                             _raise_value_error, "test")

    # An exception of a type that was not asked for escapes unchanged.
    with pytest.raises(ValueError):
        assert_raise_message(TypeError, "something else",
                             _raise_value_error, "test")

    # Nothing raised at all fails, for a single expected type ...
    with pytest.raises(AssertionError):
        assert_raise_message(ValueError, "test", _do_nothing)

    # ... and for a tuple of accepted exception types.
    with pytest.raises(AssertionError):
        assert_raise_message((ValueError, AttributeError), "test",
                             _do_nothing)


def test_ignore_warning():
    # Check that ignore_warnings works as expected when used to wrap a
    # function directly, as a decorator and as a context manager.

    # Emits a single DeprecationWarning.
    def _warning_function():
        warnings.warn("deprecation warning", DeprecationWarning)

    # Emits a DeprecationWarning followed by a warning of the default
    # category of warnings.warn (UserWarning).
    def _multiple_warning_function():
        warnings.warn("deprecation warning", DeprecationWarning)
        warnings.warn("deprecation warning")

    # Check ignore_warnings used as a plain function wrapper
    assert_no_warnings(ignore_warnings(_warning_function))
    assert_no_warnings(ignore_warnings(_warning_function,
                                       category=DeprecationWarning))
    assert_warns(DeprecationWarning, ignore_warnings(_warning_function,
                                                     category=UserWarning))
    assert_warns(UserWarning,
                 ignore_warnings(_multiple_warning_function,
                                 category=FutureWarning))
    assert_warns(DeprecationWarning,
                 ignore_warnings(_multiple_warning_function,
                                 category=UserWarning))
    assert_no_warnings(ignore_warnings(_warning_function,
                                       category=(DeprecationWarning,
                                                 UserWarning)))

    # Check ignore_warnings used as a decorator, with and without a
    # category filter
    @ignore_warnings
    def decorator_no_warning():
        _warning_function()
        _multiple_warning_function()

    @ignore_warnings(category=(DeprecationWarning, UserWarning))
    def decorator_no_warning_multiple():
        _multiple_warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_warning():
        _warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_warning():
        _warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_multiple_warning():
        _multiple_warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_multiple_warning():
        _multiple_warning_function()

    assert_no_warnings(decorator_no_warning)
    assert_no_warnings(decorator_no_warning_multiple)
    assert_no_warnings(decorator_no_deprecation_warning)
    assert_warns(DeprecationWarning, decorator_no_user_warning)
    assert_warns(UserWarning, decorator_no_deprecation_multiple_warning)
    assert_warns(DeprecationWarning, decorator_no_user_multiple_warning)

    # Check ignore_warnings used as a context manager, with the same
    # combinations of categories as above
    def context_manager_no_warning():
        with ignore_warnings():
            _warning_function()

    def context_manager_no_warning_multiple():
        with ignore_warnings(category=(DeprecationWarning, UserWarning)):
            _multiple_warning_function()

    def context_manager_no_deprecation_warning():
        with ignore_warnings(category=DeprecationWarning):
            _warning_function()

    def context_manager_no_user_warning():
        with ignore_warnings(category=UserWarning):
            _warning_function()

    def context_manager_no_deprecation_multiple_warning():
        with ignore_warnings(category=DeprecationWarning):
            _multiple_warning_function()

    def context_manager_no_user_multiple_warning():
        with ignore_warnings(category=UserWarning):
            _multiple_warning_function()

    assert_no_warnings(context_manager_no_warning)
    assert_no_warnings(context_manager_no_warning_multiple)
    assert_no_warnings(context_manager_no_deprecation_warning)
    assert_warns(DeprecationWarning, context_manager_no_user_warning)
    assert_warns(UserWarning, context_manager_no_deprecation_multiple_warning)
    assert_warns(DeprecationWarning, context_manager_no_user_multiple_warning)

    # Check that passing a warning class as the first positional argument
    # (instead of ``category=...``) raises a ValueError pointing the user to
    # the ``category`` keyword. This is checked both for the direct call
    # and for the decorator form.
    warning_class = UserWarning
    match = "'obj' should be a callable.+you should use 'category=UserWarning'"

    with pytest.raises(ValueError, match=match):
        silence_warnings_func = ignore_warnings(warning_class)(
            _warning_function)
        silence_warnings_func()

    with pytest.raises(ValueError, match=match):
        @ignore_warnings(warning_class)
        def test():
            pass


class TestWarns(unittest.TestCase):
    """Checks for the ``assert_warns`` / ``assert_no_warnings`` helpers."""

    def test_warn(self):
        def f():
            warnings.warn("yo")
            return 3

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            filters_orig = warnings.filters[:]
            # assert_warns must hand back the wrapped function's result.
            assert assert_warns(UserWarning, f) == 3
            # test that assert_warns doesn't have side effects on warnings
            # filters
            assert warnings.filters == filters_orig

        # f does warn, so assert_no_warnings must fail on it ...
        assert_raises(AssertionError, assert_no_warnings, f)
        # ... while a silent callable passes and its result is returned.
        assert assert_no_warnings(lambda x: x, 1) == 1

    def test_warn_wrong_warning(self):
        def f():
            warnings.warn("yo", FutureWarning)

        # Asking for a UserWarning while the callable only emits a
        # FutureWarning must make assert_warns fail. catch_warnings()
        # snapshots the global warnings filter state and restores it on
        # exit. This replaces saving ``warnings.filters`` by hand and
        # rebinding it in a ``finally`` clause.
        with warnings.catch_warnings():
            with self.assertRaises(AssertionError,
                                   msg="wrong warning caught by assert_warn"):
                assert_warns(UserWarning, f)


# Tests for docstrings:

# Docstring fixture for test_check_docstring_parameters: the Parameters
# section lists exactly the signature's parameters, in order, so no problem
# is expected to be reported. The docstring text is test data -- keep it
# verbatim.
def f_ok(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


# Docstring fixture: uses the non-standard section name "Results" instead
# of "Returns"; checking it is expected to raise RuntimeError mentioning
# 'Unknown section Results'. The docstring text is test data -- keep it
# verbatim.
def f_bad_sections(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Results
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


# Docstring fixture: the signature is (b, a) while the docstring documents
# a then b, so a parameter-order mismatch is expected at index 0. The
# docstring text is test data -- keep it verbatim.
def f_bad_order(b, a):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


# Docstring fixture: documents an extra parameter ``c`` that is not in the
# signature, so "first extra item: c" is expected. The docstring text is
# test data -- keep it verbatim.
def f_too_many_param_docstring(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : int
        Parameter b
    c : int
        Parameter c

    Returns
    -------
    d : list
        Parameter c
    """
    d = a + b
    return d


# Docstring fixture: ``b`` is in the signature but not documented, so
# "first missing item: b" is expected, unless the check is told to
# ignore 'b'. The docstring text is test data -- keep it verbatim.
def f_missing(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


# Docstring fixture for the "name : type" formatting checks. Entries a, b
# and d have no space before the colon and c has an empty type spec after
# its colon; each is expected to produce one message. Entry e has neither a
# colon nor a type and is expected to produce none. The docstring text is
# test data -- keep it verbatim.
def f_check_param_definition(a, b, c, d, e):
    """Function f

    Parameters
    ----------
    a: int
        Parameter a
    b:
        Parameter b
    c :
        Parameter c
    d:int
        Parameter d
    e
        No typespec is allowed without colon
    """
    # ``e`` is intentionally unused; only the signature matters here.
    return a + b + c + d


# Docstring fixtures exercising methods (``self`` is dropped from the
# signature before comparison). Docstrings are test data -- keep them
# verbatim, and keep f_missing without one.
class Klass:
    # No docstring at all: both X and y are expected to be reported as
    # missing.
    def f_missing(self, X, y):
        pass

    # Uses the misspelled header "Parameter" (plus a "Results" section);
    # checking it is expected to raise RuntimeError mentioning
    # 'Unknown section Parameter'.
    def f_bad_sections(self, X, y):
        """Function f

        Parameter
        ----------
        a : int
            Parameter a
        b : float
            Parameter b

        Results
        -------
        c : list
            Parameter c
        """
        pass


class MockEst:
    # Minimal stand-in estimator handed to MockMetaEstimator as its
    # delegate: each method echoes its input, except score, which returns
    # a constant.

    def __init__(self):
        """MockEstimator"""

    def fit(self, X, y):
        # Nothing is learned; the training data is echoed back.
        return X

    def predict(self, X):
        return X

    def predict_proba(self, X):
        return X

    def score(self, X):
        return 1.0


# Docstring fixture for test_check_docstring_parameters: a meta-estimator
# whose methods are wrapped by if_delegate_has_method and/or deprecated and
# carry deliberately imperfect docstrings. The docstrings, including the
# exact underline lengths, are test data -- keep them verbatim.
class MockMetaEstimator:
    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        # Attribute named by ``if_delegate_has_method(delegate='delegate')``
        # on the methods below.
        self.delegate = delegate

    # Documents ``y`` while the signature takes ``X``: a name mismatch
    # ('X' != 'y') is expected.
    @if_delegate_has_method(delegate=('delegate'))
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    # Delegated *and* deprecated. The "Parameters" underline is one dash
    # short; the expected report lists no documented parameters at all
    # (presumably because the section is then not recognised).
    @if_delegate_has_method(delegate=('delegate'))
    @deprecated("Testing a deprecated delegated method")
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    # Same one-dash-short underline as ``score``: X is expected to be
    # reported as missing.
    @if_delegate_has_method(delegate=('delegate'))
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    # Despite its docstring, ``fit`` IS passed to check_docstring_parameters
    # by the docstring test, which expects both X and y to be reported as
    # missing.
    @deprecated('Testing deprecated function with wrong params')
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""
        """Incorrect docstring but should not be tested"""


def test_check_docstring_parameters():
    # check_docstring_parameters relies on numpydoc to parse docstrings.
    try:
        import numpydoc  # noqa
    except ImportError:
        raise SkipTest(
            "numpydoc is required to test the docstrings")

    # Well-formed docstrings, or ones whose gaps are explicitly ignored,
    # produce no report.
    incorrect = check_docstring_parameters(f_ok)
    assert incorrect == []
    incorrect = check_docstring_parameters(f_ok, ignore=['b'])
    assert incorrect == []
    incorrect = check_docstring_parameters(f_missing, ignore=['b'])
    assert incorrect == []

    # Unknown section headers raise instead of being reported.
    assert_raise_message(RuntimeError, 'Unknown section Results',
                         check_docstring_parameters, f_bad_sections)
    assert_raise_message(RuntimeError, 'Unknown section Parameter',
                         check_docstring_parameters, Klass.f_bad_sections)

    # Malformed "name : type" entries yield one message each.
    incorrect = check_docstring_parameters(f_check_param_definition)
    assert (
        incorrect == [
            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('a: int')",

            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('b:')",

            "sklearn.utils.tests.test_testing.f_check_param_definition "
            "Parameter 'c :' has an empty type spec. Remove the colon",

            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('d:int')",
        ])

    mock_meta = MockMetaEstimator(delegate=MockEst())

    # Each checked callable sits next to the exact report expected for it.
    # Pairing them explicitly (rather than zipping two parallel lists)
    # means no case can be silently dropped if the lists drift apart.
    cases = [
        (f_bad_order,
         ["In function: sklearn.utils.tests.test_testing.f_bad_order",
          "There's a parameter name mismatch in function docstring w.r.t."
          " function signature, at index 0 diff: 'b' != 'a'",
          "Full diff:",
          "- ['b', 'a']",
          "+ ['a', 'b']"]),

        (f_too_many_param_docstring,
         ["In function: "
          "sklearn.utils.tests.test_testing.f_too_many_param_docstring",
          "Parameters in function docstring have more items w.r.t. function"
          " signature, first extra item: c",
          "Full diff:",
          "- ['a', 'b']",
          "+ ['a', 'b', 'c']",
          "?          +++++"]),

        (f_missing,
         ["In function: sklearn.utils.tests.test_testing.f_missing",
          "Parameters in function docstring have less items w.r.t. function"
          " signature, first missing item: b",
          "Full diff:",
          "- ['a', 'b']",
          "+ ['a']"]),

        (Klass.f_missing,
         ["In function: sklearn.utils.tests.test_testing.Klass.f_missing",
          "Parameters in function docstring have less items w.r.t. function"
          " signature, first missing item: X",
          "Full diff:",
          "- ['X', 'y']",
          "+ []"]),

        (mock_meta.predict,
         ["In function: "
          "sklearn.utils.tests.test_testing.MockMetaEstimator.predict",
          "There's a parameter name mismatch in function docstring w.r.t."
          " function signature, at index 0 diff: 'X' != 'y'",
          "Full diff:",
          "- ['X']",
          "?   ^",
          "+ ['y']",
          "?   ^"]),

        (mock_meta.predict_proba,
         ["In function: "
          "sklearn.utils.tests.test_testing.MockMetaEstimator."
          "predict_proba",
          "Parameters in function docstring have less items w.r.t. function"
          " signature, first missing item: X",
          "Full diff:",
          "- ['X']",
          "+ []"]),

        (mock_meta.score,
         ["In function: "
          "sklearn.utils.tests.test_testing.MockMetaEstimator.score",
          "Parameters in function docstring have less items w.r.t. function"
          " signature, first missing item: X",
          "Full diff:",
          "- ['X']",
          "+ []"]),

        (mock_meta.fit,
         ["In function: "
          "sklearn.utils.tests.test_testing.MockMetaEstimator.fit",
          "Parameters in function docstring have less items w.r.t. function"
          " signature, first missing item: X",
          "Full diff:",
          "- ['X', 'y']",
          "+ []"]),
    ]

    for func, expected in cases:
        incorrect = check_docstring_parameters(func)
        assert incorrect == expected, (
            '\n"%s"\n != \n"%s"' % (expected, incorrect))


class RegistrationCounter:
    """Callable stand-in for ``atexit.register`` that counts its calls.

    On every call it also checks that the registered callback's ``func``
    attribute is ``_delete_folder`` (presumably the callback is a
    ``functools.partial`` around it).
    """

    def __init__(self):
        # Number of times this object has been invoked so far.
        self.nb_calls = 0

    def __call__(self, to_register_func):
        # Count first, so the call is recorded even if the check fails.
        self.nb_calls = self.nb_calls + 1
        registered = to_register_func.func
        assert registered is _delete_folder


def check_memmap(input_array, mmap_data, mmap_mode='r'):
    """Assert that ``mmap_data`` is a memmap holding ``input_array``.

    Parameters
    ----------
    input_array : array-like
        Reference values the memmap must contain.
    mmap_data : object
        Object under test; must be an ``np.memmap`` instance.
    mmap_mode : str, default='r'
        Mode the memmap was opened with. Only ``'r'`` is expected to give a
        non-writeable array; every other mode must be writeable.
    """
    assert isinstance(mmap_data, np.memmap)
    expect_read_only = mmap_mode == 'r'
    assert mmap_data.flags.writeable is (not expect_read_only)
    np.testing.assert_array_equal(input_array, mmap_data)


def test_tempmemmap(monkeypatch):
    # Intercept atexit.register so every TempMemmap context can be counted.
    counter = RegistrationCounter()
    monkeypatch.setattr(atexit, 'register', counter)

    source = np.ones(3)
    # First pass relies on the default mode, second uses an explicit 'r+'.
    mode_variants = [{}, {'mmap_mode': 'r+'}]
    for n_expected, mode_kwargs in enumerate(mode_variants, start=1):
        with TempMemmap(source, **mode_kwargs) as data:
            check_memmap(source, data, **mode_kwargs)
            temp_folder = os.path.dirname(data.filename)
        # The folder is expected to be gone on exit; this is not checked
        # on Windows (os.name == 'nt').
        if os.name != 'nt':
            assert not os.path.exists(temp_folder)
        assert counter.nb_calls == n_expected


def test_create_memmap_backed_data(monkeypatch):
    # Intercept atexit.register to count one registration per call.
    counter = RegistrationCounter()
    monkeypatch.setattr(atexit, 'register', counter)

    base = np.ones(3)

    # Single array, default mode.
    data = create_memmap_backed_data(base)
    check_memmap(base, data)
    assert counter.nb_calls == 1

    # return_folder=True additionally returns the backing folder.
    data, folder = create_memmap_backed_data(base, return_folder=True)
    check_memmap(base, data)
    assert folder == os.path.dirname(data.filename)
    assert counter.nb_calls == 2

    # Explicit read/write mode.
    data = create_memmap_backed_data(base, mmap_mode='r+')
    check_memmap(base, data, mmap_mode='r+')
    assert counter.nb_calls == 3

    # A list of arrays comes back as a list of memmaps, with a single
    # additional registration.
    arrays = [base, base + 1, base + 2]
    mapped_arrays = create_memmap_backed_data(arrays)
    for expected, mapped in zip(arrays, mapped_arrays):
        check_memmap(expected, mapped)
    assert counter.nb_calls == 4


# TODO: remove together with the deprecated assert_* helpers in 0.24
@pytest.mark.parametrize('helper, args', [
    (assert_equal, (0, 0)),
    (assert_not_equal, (0, 1)),
    (assert_greater, (1, 0)),
    (assert_greater_equal, (1, 0)),
    (assert_less, (0, 1)),
    (assert_less_equal, (0, 1)),
    (assert_in, (0, [0])),
    (assert_not_in, (0, [1]))])
def test_deprecated_helpers(helper, args):
    # Each legacy helper must still accept passing arguments but emit a
    # FutureWarning announcing its removal. The parameter is named
    # ``helper`` rather than ``callable`` to avoid shadowing the builtin.
    msg = ('is deprecated in version 0.22 and will be removed in version '
           '0.24. Please use "assert" instead')
    with pytest.warns(FutureWarning, match=msg):
        helper(*args)


@pytest.mark.parametrize(
    "constructor_name, container_type",
    [('list', list),
     ('tuple', tuple),
     ('array', np.ndarray),
     ('sparse', sparse.csr_matrix),
     # pandas types are resolved lazily inside the test. Calling
     # pytest.importorskip here, at collection time, would skip this whole
     # test module when pandas is not installed.
     ('dataframe', lambda: pytest.importorskip('pandas').DataFrame),
     ('series', lambda: pytest.importorskip('pandas').Series),
     ('index', lambda: pytest.importorskip('pandas').Index),
     ('slice', slice)]
)
def test_convert_container(constructor_name, container_type):
    """Check that _convert_container builds the requested container type."""
    if not isinstance(container_type, type):
        # Lazy entry: import the optional dependency now (or skip only
        # this parametrization if it is missing).
        container_type = container_type()
    container = [0, 1]
    assert isinstance(_convert_container(container, constructor_name),
                      container_type)