Repository URL to install this package:
|
Version:
0.15.2 ▾
|
import warnings
import unittest
import sys
from nose.tools import assert_raises
from sklearn.utils.testing import (
_assert_less,
_assert_greater,
assert_warns,
assert_no_warnings,
assert_equal,
set_random_state,
assert_raise_message)
from sklearn.tree import DecisionTreeClassifier
from sklearn.lda import LDA
try:
    from nose.tools import assert_less
except ImportError:
    # nose.tools does not provide assert_less here, so there is no
    # reference implementation to compare against: define no test.
    pass
else:
    def test_assert_less():
        # Both nose's assert_less and scikit-learn's fallback _assert_less
        # must accept an ordered pair and reject a reversed one.
        for check in (assert_less, _assert_less):
            check(0, 1)
            assert_raises(AssertionError, check, 1, 0)
try:
    from nose.tools import assert_greater
except ImportError:
    # nose.tools does not provide assert_greater here, so there is no
    # reference implementation to compare against: define no test.
    pass
else:
    def test_assert_greater():
        # Check that the nose implementation of assert_greater gives the
        # same thing as the scikit's
        assert_greater(1, 0)
        _assert_greater(1, 0)
        assert_raises(AssertionError, assert_greater, 0, 1)
        assert_raises(AssertionError, _assert_greater, 0, 1)
def test_set_random_state():
    seed = 3

    # LDA has no random_state attribute: calling set_random_state on it is
    # only a smoke test that it does not fail.
    lda = LDA()
    set_random_state(lda, seed)

    # The tree does expose random_state, so the value must be applied.
    tree = DecisionTreeClassifier()
    set_random_state(tree, seed)
    assert_equal(tree.random_state, seed)
def test_assert_raise_message():
    def raise_value_error(text):
        raise ValueError(text)

    # Expected exception type and matching message: passes silently.
    assert_raise_message(ValueError, "test", raise_value_error, "test")

    # Expected exception type but a different message: must fail.
    assert_raises(AssertionError, assert_raise_message,
                  ValueError, "something else", raise_value_error, "test")

    # Unexpected exception type: the ValueError itself propagates.
    assert_raises(ValueError, assert_raise_message,
                  TypeError, "something else", raise_value_error, "test")
# This class is taken from numpy 1.7
class TestWarns(unittest.TestCase):
    """Tests for the ``assert_warns`` / ``assert_no_warnings`` helpers."""

    def test_warn(self):
        """assert_warns returns f's result and leaves warning filters intact."""
        def f():
            warnings.warn("yo")
            return 3

        before_filters = warnings.filters[:]
        assert_equal(assert_warns(UserWarning, f), 3)
        # Snapshot (not a live reference) so the check below is about
        # assert_warns only, not the assert_no_warnings calls that follow.
        after_filters = warnings.filters[:]

        # f warns, so assert_no_warnings must fail on it ...
        assert_raises(AssertionError, assert_no_warnings, f)
        # ... and pass through the return value of a silent callable.
        assert_equal(assert_no_warnings(lambda x: x, 1), 1)

        # Check that the warnings state is unchanged
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    def test_warn_wrong_warning(self):
        """assert_warns must fail when f emits a different warning class."""
        def f():
            warnings.warn("yo", DeprecationWarning)

        # catch_warnings restores the global warning filters on exit, even
        # if assert_warns leaves them modified or the test fails.
        with warnings.catch_warnings():
            try:
                # Should raise an AssertionError
                assert_warns(UserWarning, f)
            except AssertionError:
                pass
            else:
                raise AssertionError("wrong warning caught by assert_warns")