# Repository URL to install this package:
# Version: 0.17.1
import warnings
import unittest
import sys
from nose.tools import assert_raises
from sklearn.utils.testing import (
_assert_less,
_assert_greater,
assert_less_equal,
assert_greater_equal,
assert_warns,
assert_no_warnings,
assert_equal,
set_random_state,
assert_raise_message)
from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
try:
    from nose.tools import assert_less
except ImportError:
    # assert_less could not be imported from nose.tools, so the
    # comparison test below is simply not defined.
    pass
else:
    def test_assert_less():
        # The nose implementation of assert_less and scikit-learn's own
        # _assert_less must agree: both accept 0 < 1 ...
        for check in (assert_less, _assert_less):
            check(0, 1)
        # ... and both raise an AssertionError for 1 < 0.
        for check in (assert_less, _assert_less):
            assert_raises(AssertionError, check, 1, 0)
try:
    from nose.tools import assert_greater
except ImportError:
    # assert_greater could not be imported from nose.tools, so the
    # comparison test below is simply not defined.
    pass
else:
    def test_assert_greater():
        # The nose implementation of assert_greater and scikit-learn's own
        # _assert_greater must agree: both accept 1 > 0 ...
        for check in (assert_greater, _assert_greater):
            check(1, 0)
        # ... and both raise an AssertionError for 0 > 1.
        for check in (assert_greater, _assert_greater):
            assert_raises(AssertionError, check, 0, 1)
def test_assert_less_equal():
    # A strictly smaller and an equal first operand are both accepted.
    for first, second in ((0, 1), (1, 1)):
        assert_less_equal(first, second)
    # A larger first operand must be reported as an AssertionError.
    assert_raises(AssertionError, assert_less_equal, 1, 0)
def test_assert_greater_equal():
    # A strictly larger and an equal first operand are both accepted.
    for first, second in ((1, 0), (1, 1)):
        assert_greater_equal(first, second)
    # A smaller first operand must be reported as an AssertionError.
    assert_raises(AssertionError, assert_greater_equal, 0, 1)
def test_set_random_state():
    seed = 3
    discriminant = LinearDiscriminantAnalysis()
    classifier = DecisionTreeClassifier()

    # Linear Discriminant Analysis doesn't have a random state, so this
    # call is only a smoke test: it must not fail.
    set_random_state(discriminant, seed)

    # The decision tree does expose random_state; it must hold the seed.
    set_random_state(classifier, seed)
    assert_equal(classifier.random_state, seed)
def test_assert_raise_message():
    def _always_raises(text):
        raise ValueError(text)

    def _never_raises():
        pass

    # Expected exception type with the expected message: passes silently.
    assert_raise_message(ValueError, "test", _always_raises, "test")

    # Expected exception type, but a different message is requested:
    # reported as an AssertionError.
    assert_raises(AssertionError, assert_raise_message,
                  ValueError, "something else", _always_raises, "test")

    # A TypeError is requested but a ValueError is raised: the ValueError
    # is what comes out of assert_raise_message.
    assert_raises(ValueError, assert_raise_message,
                  TypeError, "something else", _always_raises, "test")

    # Nothing is raised at all: reported as an AssertionError.
    assert_raises(AssertionError, assert_raise_message,
                  ValueError, "test", _never_raises)

    # multiple exceptions in a tuple
    assert_raises(AssertionError, assert_raise_message,
                  (ValueError, AttributeError), "test", _never_raises)
# This class is inspired by numpy 1.7, with an alteration to check that
# the warning filters are reset after calls to assert_warns.
# This assert_warns behavior is specific to scikit-learn because
# `clean_warning_registry()` is called internally by assert_warns
# and clears all previous filters.
class TestWarns(unittest.TestCase):
    """Checks for assert_warns and assert_no_warnings.

    Both tests run inside ``warnings.catch_warnings()`` so that the
    warning filters in place before the test are restored afterwards,
    whatever the helpers under test do to them.
    """

    def test_warn(self):
        def f():
            warnings.warn("yo")
            return 3

        # catch_warnings restores the pre-test filters on exit, so neither
        # the "ignore" filter installed here nor the reset performed by
        # assert_warns leaks into other tests.
        with warnings.catch_warnings():
            # Test that assert_warns is not impacted by externally set
            # filters and is reset internally.
            # This is because `clean_warning_registry()` is called
            # internally by assert_warns and clears all previous filters.
            warnings.simplefilter("ignore", UserWarning)
            assert_equal(assert_warns(UserWarning, f), 3)

            # Test that the warning filters are empty after assert_warns
            assert_equal(sys.modules['warnings'].filters, [])

            assert_raises(AssertionError, assert_no_warnings, f)
            assert_equal(assert_no_warnings(lambda x: x, 1), 1)

    def test_warn_wrong_warning(self):
        def f():
            warnings.warn("yo", DeprecationWarning)

        # Restore the warning filters through catch_warnings rather than
        # by assigning warnings.filters by hand, so the warnings module
        # does its own bookkeeping on exit.
        with warnings.catch_warnings():
            try:
                # Should raise an AssertionError: f emits a
                # DeprecationWarning, not the requested UserWarning.
                assert_warns(UserWarning, f)
            except AssertionError:
                failed = False
            else:
                failed = True

        if failed:
            raise AssertionError("wrong warning caught by assert_warn")