Repository URL to install this package:
|
Version:
6.4.1 ▾
|
# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
import unittest
from unittest import mock
from traits.has_traits import HasTraits
from traits.observation._testing import (
call_add_or_remove_notifiers,
create_graph,
DummyNotifier,
DummyObserver,
)
from traits.observation._trait_added_observer import (
_RestrictedNamedTraitObserver,
TraitAddedObserver,
)
from traits.trait_types import Str
def create_observer(**kwargs):
    """ Build a TraitAddedObserver with test-friendly defaults.

    The defaults are a fresh ``mock.Mock()`` as ``match_func`` and
    ``optional=False``. Any keyword argument supplied overrides the
    default of the same name.

    Returns
    -------
    TraitAddedObserver
    """
    defaults = {"match_func": mock.Mock(), "optional": False}
    return TraitAddedObserver(**{**defaults, **kwargs})
class DummyMatchFunc:
    """ A callable to be used as TraitAddedObserver.match_func

    Every call ignores its arguments and returns the fixed
    ``return_value`` given at construction. Two instances compare equal,
    and hash equally, when their ``return_value`` are equal. This lets
    tests build distinct but equal match functions.

    Parameters
    ----------
    return_value : object
        Value returned by every call. It must be hashable, because
        ``__hash__`` hashes it.
    """

    def __init__(self, return_value):
        self.return_value = return_value

    def __call__(self, name, trait):
        return self.return_value

    def __eq__(self, other):
        # For foreign types, defer to the other operand (ultimately falling
        # back to identity) rather than raising AttributeError on objects
        # that have no ``return_value`` attribute.
        if not isinstance(other, DummyMatchFunc):
            return NotImplemented
        return self.return_value == other.return_value

    def __hash__(self):
        return hash(self.return_value)
class TestTraitAddedObserverEqualHash(unittest.TestCase):
""" Tests for TraitAddedObserver __eq__ and __hash__ methods.
"""
def test_not_equal_match_func(self):
observer1 = TraitAddedObserver(match_func=mock.Mock(), optional=True)
observer2 = TraitAddedObserver(match_func=mock.Mock(), optional=True)
self.assertNotEqual(observer1, observer2)
def test_not_equal_optional(self):
match_func = mock.Mock()
observer1 = TraitAddedObserver(match_func=match_func, optional=False)
observer2 = TraitAddedObserver(match_func=match_func, optional=True)
self.assertNotEqual(observer1, observer2)
def test_equal_match_func_optional(self):
# If two match_func compare equally and optional is the same
# then they are the same.
observer1 = TraitAddedObserver(
match_func=DummyMatchFunc(return_value=True),
optional=False,
)
observer2 = TraitAddedObserver(
match_func=DummyMatchFunc(return_value=True),
optional=False,
)
self.assertEqual(observer1, observer2)
self.assertEqual(hash(observer1), hash(observer2))
def test_not_equal_type(self):
match_func = mock.Mock()
observer1 = TraitAddedObserver(
match_func=match_func,
optional=False,
)
imposter = mock.Mock()
imposter.match_func = match_func
imposter.optional = False
self.assertNotEqual(observer1, imposter)
def test_notify_is_false(self):
observer = create_observer()
self.assertFalse(
observer.notify,
"TraitAddedObserver.notify should be always false.",
)
class TestRestrictedNamedTraitObserverEqualityHash(unittest.TestCase):
""" Test _RestrictedNamedTraitObserver.__eq__ and __hash__ """
def test_equality_name_and_observer(self):
wrapped_observer = DummyObserver()
observer1 = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=wrapped_observer)
observer2 = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=wrapped_observer)
self.assertEqual(observer1, observer2)
self.assertEqual(hash(observer1), hash(observer2))
def test_not_equal_name(self):
wrapped_observer = DummyObserver()
observer1 = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=wrapped_observer)
observer2 = _RestrictedNamedTraitObserver(
name="other", wrapped_observer=wrapped_observer)
self.assertNotEqual(observer1, observer2)
def test_not_equal_observer(self):
observer1 = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=DummyObserver())
observer2 = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=DummyObserver())
self.assertNotEqual(observer1, observer2)
class TestRestrictedNamedTraitObserverWithWrappedObserver(unittest.TestCase):
""" Test the quantities inherited from the wrapped observer."""
def test_notify_inherited(self):
wrapped_observer = DummyObserver(notify=False)
observer = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=wrapped_observer)
self.assertEqual(observer.notify, wrapped_observer.notify)
def test_notifier_inherited(self):
notifier = DummyNotifier()
wrapped_observer = DummyObserver(notifier=notifier)
observer = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=wrapped_observer)
self.assertEqual(
observer.get_notifier(None, None, None),
notifier,
)
def test_maintainer_inherited(self):
maintainer = DummyNotifier()
wrapped_observer = DummyObserver(maintainer=maintainer)
observer = _RestrictedNamedTraitObserver(
name="name", wrapped_observer=wrapped_observer)
self.assertEqual(
observer.get_maintainer(None, None, None, None),
maintainer,
)
# -----------------------------------
# Integration tests with HasTraits
# -----------------------------------
class DummyHasTraitsClass(HasTraits):
    """ Minimal HasTraits subclass that declares no traits of its own;
    used as the observed object in the integration tests. """

    def dummy_method(self):
        """ Do nothing; used as a bound-method handler in the tests. """
class TestTraitAddedObserverIterObservables(unittest.TestCase):
""" Test sanity checks in iter_observables. """
def test_iter_observables_get_trait_added_ctrait(self):
observer = create_observer()
instance = DummyHasTraitsClass()
actual, = list(observer.iter_observables(instance))
self.assertEqual(actual, instance._trait("trait_added", 2))
def test_iter_observables_ignore_incompatible_object_if_optional(self):
observer = create_observer(optional=True)
actual = list(observer.iter_observables(None))
self.assertEqual(actual, [])
def test_iter_observables_error_incompatible_object_if_required(self):
observer = create_observer(optional=False)
with self.assertRaises(ValueError) as exception_cm:
list(observer.iter_observables(None))
self.assertIn(
"Unable to observe 'trait_added'", str(exception_cm.exception))
class TestTraitAddedObserverIterObjects(unittest.TestCase):
""" Test iter_objects yields nothing. """
def test_iter_objects_yields_nothing(self):
observer = create_observer()
actual = list(observer.iter_objects(None))
self.assertEqual(actual, [])
class TestTraitAddedObserverNotifications(unittest.TestCase):
""" Test the core logic for maintaining downstream observers
when a trait is added.
"""
def setUp(self):
def match_func(name, trait):
return name.startswith("good_")
self.observer = TraitAddedObserver(
match_func=match_func,
optional=False,
)
def test_maintainer_trait_added(self):
# Test the maintainer is added for the trait_added event.
instance = DummyHasTraitsClass()
notifier = DummyNotifier()
maintainer = DummyNotifier()
graph = create_graph(
self.observer,
DummyObserver(
notify=True,
notifier=notifier,
maintainer=maintainer,
),
DummyObserver(), # to get maintainer in
)
call_add_or_remove_notifiers(
object=instance,
handler=instance.dummy_method,
target=instance,
graph=graph,
remove=False,
)
# when
instance.add_trait("good_name", Str())
# then
# the maintainer will have added a notifier because notify flag
# is set to true on the single observer being maintained.
notifiers = instance._trait("good_name", 2)._notifiers(True)
self.assertIn(notifier, notifiers)
self.assertIn(maintainer, notifiers)
# when
instance.add_trait("bad_name", Str())
# then
notifiers = instance._trait("bad_name", 2)._notifiers(True)
self.assertNotIn(notifier, notifiers)
self.assertNotIn(maintainer, notifiers)
def test_maintainer_keep_notify_flag(self):
# Test the maintainer will maintain the notify flag for the root
# observer in the subgraph.
instance = DummyHasTraitsClass()
notifier = DummyNotifier()
graph = create_graph(
self.observer,
DummyObserver(
notify=False,
notifier=notifier,
),
)
handler = mock.Mock()
call_add_or_remove_notifiers(
object=instance,
handler=handler,
target=instance,
graph=graph,
remove=False,
)
# when
instance.add_trait("good_name", Str())
# then
# notify flag is set to false, so there are no notifiers added.
notifiers = instance._trait("good_name", 2)._notifiers(True)
self.assertNotIn(notifier, notifiers)