Repository URL to install this package:
|
Version:
6.4.1 ▾
|
# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
import unittest
from unittest import mock
from traits.has_traits import HasTraits
from traits.observation._set_item_observer import SetItemObserver
from traits.observation._testing import (
call_add_or_remove_notifiers,
create_graph,
DummyObservable,
DummyObserver,
DummyNotifier,
)
from traits.trait_set_object import TraitSet
from traits.trait_types import Set
def create_observer(**kwargs):
""" Convenience function for creating SetItemObserver with default values.
"""
values = dict(
notify=True,
optional=False,
)
values.update(kwargs)
return SetItemObserver(**values)
class TestSetItemObserverEqualHash(unittest.TestCase):
""" Test SetItemObserver __eq__, __hash__ and immutability. """
def test_not_equal_notify(self):
observer1 = SetItemObserver(notify=False, optional=False)
observer2 = SetItemObserver(notify=True, optional=False)
self.assertNotEqual(observer1, observer2)
def test_not_equal_optional(self):
observer1 = SetItemObserver(notify=True, optional=True)
observer2 = SetItemObserver(notify=True, optional=False)
self.assertNotEqual(observer1, observer2)
def test_not_equal_different_type(self):
observer1 = SetItemObserver(notify=False, optional=False)
imposter = mock.Mock()
imposter.notify = False
imposter.optional = False
self.assertNotEqual(observer1, imposter)
def test_equal_observers(self):
observer1 = SetItemObserver(notify=False, optional=False)
observer2 = SetItemObserver(notify=False, optional=False)
self.assertEqual(observer1, observer2)
self.assertEqual(hash(observer1), hash(observer2))
def test_slots(self):
observer = SetItemObserver(notify=True, optional=False)
with self.assertRaises(AttributeError):
observer.__dict__
with self.assertRaises(AttributeError):
observer.__weakref__
def test_eval_repr_roundtrip(self):
observer = SetItemObserver(notify=True, optional=False)
self.assertEqual(eval(repr(observer)), observer)
class CustomSet(set):
# This is a set, but not an observable
pass
class CustomTraitSet(TraitSet):
# This can be used with SetItemObserver
pass
class ClassWithSet(HasTraits):
values = Set()
class TestSetItemObserverIterObservable(unittest.TestCase):
""" Test SetItemObserver.iter_observables """
def test_trait_set_iter_observables(self):
instance = ClassWithSet()
observer = create_observer(optional=False)
actual_item, = list(observer.iter_observables(instance.values))
self.assertIs(actual_item, instance.values)
def test_set_but_not_a_trait_set(self):
observer = create_observer(optional=False)
with self.assertRaises(ValueError) as exception_context:
list(observer.iter_observables(CustomSet()))
self.assertIn(
"Expected a TraitSet to be observed, got",
str(exception_context.exception)
)
def test_iter_observables_custom_trait_set(self):
# A subcalss of TraitSet can also be used.
custom_trait_set = CustomTraitSet()
observer = create_observer()
actual_item, = list(observer.iter_observables(custom_trait_set))
self.assertIs(actual_item, custom_trait_set)
def test_not_a_set(self):
observer = create_observer(optional=False)
with self.assertRaises(ValueError) as exception_context:
list(observer.iter_observables(None))
self.assertIn(
"Expected a TraitSet to be observed, got",
str(exception_context.exception)
)
def test_optional_flag_not_a_set(self):
observer = create_observer(optional=True)
actual = list(observer.iter_observables(None))
self.assertEqual(actual, [])
def test_optional_flag_not_an_observable(self):
observer = create_observer(optional=True)
actual = list(observer.iter_observables(CustomSet()))
self.assertEqual(actual, [])
class TestSetItemObserverIterObjects(unittest.TestCase):
""" Test SetItemObserver.iter_objects """
def test_iter_objects_from_set(self):
instance = ClassWithSet()
instance.values = set([1, 2, 3])
observer = create_observer()
actual = list(observer.iter_objects(instance.values))
self.assertCountEqual(actual, [1, 2, 3])
def test_iter_observables_custom_trait_set(self):
# A subcalss of TraitSet can also be used.
custom_trait_set = CustomTraitSet([1, 2, 3])
observer = create_observer()
actual = list(observer.iter_objects(custom_trait_set))
self.assertCountEqual(actual, [1, 2, 3])
def test_iter_objects_sanity_check(self):
# sanity check if the given object is a set
observer = create_observer(optional=False)
with self.assertRaises(ValueError) as exception_context:
list(observer.iter_objects(None))
self.assertIn(
"Expected a TraitSet to be observed, got",
str(exception_context.exception),
)
def test_iter_objects_optional(self):
observer = create_observer(optional=True)
actual = list(observer.iter_objects(None))
self.assertEqual(actual, [])
class TestSetItemObserverNotifications(unittest.TestCase):
""" Integration tests with notifiers and HasTraits. """
def test_notify_set_change(self):
instance = ClassWithSet(values=set())
graph = create_graph(
create_observer(notify=True),
)
handler = mock.Mock()
call_add_or_remove_notifiers(
object=instance.values,
graph=graph,
handler=handler,
)
# when
instance.values.add(1)
# then
((event, ), _), = handler.call_args_list
self.assertEqual(event.added, set([1]))
self.assertEqual(event.removed, set())
def test_maintain_notifier(self):
# Test maintaining downstream notifier
class ChildObserver(DummyObserver):
def iter_observables(self, object):
yield object
instance = ClassWithSet()
instance.values = set()
notifier = DummyNotifier()
child_observer = ChildObserver(notifier=notifier)
graph = create_graph(
create_observer(notify=False, optional=False),
child_observer,
)
handler = mock.Mock()
call_add_or_remove_notifiers(
object=instance.values,
graph=graph,
handler=handler,
)
# when
observable = DummyObservable()
instance.values.add(observable)
# then
self.assertEqual(observable.notifiers, [notifier])
# when
instance.values.remove(observable)
# then
self.assertEqual(observable.notifiers, [])