Repository URL to install this package:
|
Version:
6.4.1 ▾
|
# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
import sys
import unittest.mock
import warnings
import weakref
from traits.api import HasTraits
from traits.constants import (
ComparisonMode, DefaultValue, TraitKind, MAXIMUM_DEFAULT_VALUE_TYPE
)
from traits.ctrait import CTrait
from traits.trait_errors import TraitError
from traits.trait_types import Any, Int, List
def getter():
""" Trivial getter. """
return True
def setter(value):
""" Trivial setter. """
pass
def validator(value):
""" Trivial validator. """
return value
class TestCTrait(unittest.TestCase):
""" Tests for the CTrait class. """
def test_initial_default_value(self):
trait = CTrait(TraitKind.trait)
self.assertEqual(
trait.default_value(), (DefaultValue.constant, None),
)
def test_set_and_get_default_value(self):
trait = CTrait(TraitKind.trait)
trait.set_default_value(DefaultValue.constant, 2.3)
self.assertEqual(trait.default_value(), (DefaultValue.constant, 2.3))
trait.set_default_value(DefaultValue.list_copy, [1, 2, 3])
self.assertEqual(
trait.default_value(), (DefaultValue.list_copy, [1, 2, 3])
)
def test_validate_default_value_for_callable_and_args(self):
bad_values = [
None,
123,
(int, (2,)),
(int, 2, 3, 4),
]
trait = CTrait(TraitKind.trait)
for value in bad_values:
with self.subTest(value=value):
with self.assertRaises(ValueError):
trait.set_default_value(
DefaultValue.callable_and_args, value)
def test_default_value_for_set_is_deprecated(self):
trait = CTrait(TraitKind.trait)
with warnings.catch_warnings(record=True) as warn_msgs:
warnings.simplefilter("always", DeprecationWarning)
trait.default_value(DefaultValue.constant, 3.7)
self.assertEqual(len(warn_msgs), 1)
warn_msg = warn_msgs[0]
self.assertIn(
"default_value method with arguments is deprecated",
str(warn_msg.message),
)
_, _, this_module = __name__.rpartition(".")
self.assertIn(this_module, warn_msg.filename)
def test_bad_default_value_type(self):
trait = CTrait(TraitKind.trait)
with self.assertRaises(ValueError):
trait.set_default_value(-1, None)
with self.assertRaises(ValueError):
trait.set_default_value(MAXIMUM_DEFAULT_VALUE_TYPE + 1, None)
def test_is_property(self):
trait = CTrait(TraitKind.trait)
self.assertFalse(trait.is_property)
trait.property_fields = (getter, setter, validator)
self.assertTrue(trait.is_property)
with self.assertRaises(AttributeError):
trait.is_property = False
def test_get_set_property(self):
trait = CTrait(TraitKind.trait)
# Get the property, ensure None
self.assertIsNone(trait.property_fields)
def value_get(self):
return self.__dict__.get("_value", 0)
def value_set(self, value):
old_value = self.__dict__.get("_value", 0)
if value != old_value:
self._value = value
self.trait_property_changed("value", old_value, value)
# Set the callables
trait.property_fields = (value_get, value_set, None)
fget, fset, validate = trait.property_fields
self.assertIs(fget, value_get)
self.assertIs(fset, value_set)
self.assertIsNone(validate)
# Ensure that _get_property does not accept arguments.
with self.assertRaises(TypeError):
trait._get_property(fget)
def test_modify_delegate(self):
trait = CTrait(TraitKind.trait)
self.assertFalse(trait.modify_delegate)
trait.modify_delegate = True
self.assertTrue(trait.modify_delegate)
def test_setattr_original_value(self):
trait = CTrait(TraitKind.trait)
self.assertFalse(trait.setattr_original_value)
trait.setattr_original_value = True
self.assertTrue(trait.setattr_original_value)
def test_post_setattr_original_value(self):
trait = CTrait(TraitKind.trait)
self.assertFalse(trait.post_setattr_original_value)
trait.post_setattr_original_value = True
self.assertTrue(trait.post_setattr_original_value)
def test_is_mapped(self):
trait = CTrait(TraitKind.trait)
self.assertFalse(trait.is_mapped)
trait.is_mapped = True
self.assertTrue(trait.is_mapped)
def test_default_comparison_mode(self):
trait = CTrait(TraitKind.trait)
self.assertIsInstance(trait.comparison_mode, ComparisonMode)
self.assertEqual(trait.comparison_mode, ComparisonMode.equality)
def test_invalid_comparison_mode(self):
trait = CTrait(TraitKind.trait)
# comparison modes other than {0,1,2}
# are invalid
with self.assertRaises(ValueError):
trait.comparison_mode = -1
with self.assertRaises(ValueError):
trait.comparison_mode = 3
def test_comparison_mode_unchanged_if_invalid(self):
trait = CTrait(TraitKind.trait)
default_comparison_mode = trait.comparison_mode
self.assertNotEqual(default_comparison_mode, ComparisonMode.none)
trait.comparison_mode = ComparisonMode.none
with self.assertRaises(ValueError):
trait.comparison_mode = -1
self.assertEqual(trait.comparison_mode, ComparisonMode.none)
def test_comparison_mode_int(self):
trait = CTrait(TraitKind.trait)
trait.comparison_mode = 0
self.assertIsInstance(trait.comparison_mode, ComparisonMode)
self.assertEqual(trait.comparison_mode, ComparisonMode.none)
trait.comparison_mode = 1
self.assertIsInstance(trait.comparison_mode, ComparisonMode)
self.assertEqual(trait.comparison_mode, ComparisonMode.identity)
trait.comparison_mode = 2
self.assertIsInstance(trait.comparison_mode, ComparisonMode)
self.assertEqual(trait.comparison_mode, ComparisonMode.equality)
def test_comparison_mode_enum(self):
trait = CTrait(TraitKind.trait)
trait.comparison_mode = ComparisonMode.none
self.assertIsInstance(trait.comparison_mode, ComparisonMode)
self.assertEqual(trait.comparison_mode, ComparisonMode.none)
trait.comparison_mode = ComparisonMode.identity
self.assertIsInstance(trait.comparison_mode, ComparisonMode)
self.assertEqual(trait.comparison_mode, ComparisonMode.identity)
trait.comparison_mode = ComparisonMode.equality
self.assertIsInstance(trait.comparison_mode, ComparisonMode)
self.assertEqual(trait.comparison_mode, ComparisonMode.equality)
def test_assign_post_setattr_none(self):
old_value = "old_value"
new_value = "new_value"
def post_setattr_func(obj, name, value):
obj.output_variable = value
trait = CTrait(TraitKind.trait)
class TestClass(HasTraits):
atr = trait
trait.post_setattr = post_setattr_func
self.assertIsNotNone(trait.post_setattr)
obj = TestClass()
obj.atr = old_value
self.assertEqual(old_value, obj.output_variable)
trait.post_setattr = None
self.assertIsNone(trait.post_setattr)
obj.atr = new_value
self.assertEqual(old_value, obj.output_variable)
trait.post_setattr = post_setattr_func
obj.atr = old_value
self.assertEqual(old_value, obj.output_variable)
with self.assertRaises(ValueError):
trait.post_setattr = "Invalid"
def test_unsafe_set_value(self):
# Regression test for enthought/traits#832. The test below causes
# a segfault (on at least some systems) before the fix.
def get_handler_refcount():
sys.getrefcount(tr.handler)
# Anything that we can create a weakref to works here.
weakrefable_object = {1, 2, 3}
tr = CTrait(0)
tr.handler = Any(weakrefable_object)
finalizer = weakref.finalize(weakrefable_object, get_handler_refcount)
del weakrefable_object
# Reassigning the handler should trigger the finaliser.
self.assertTrue(finalizer.alive)
tr.handler = None
self.assertFalse(finalizer.alive)
self.assertIsNone(tr.handler)
def test_invalid_initialization(self):
with self.assertRaises(TraitError):
CTrait(max(TraitKind) + 1)
def test_initialization_with_keywords_fails(self):
with self.assertRaises(TraitError):
CTrait(kind=0)
def test_default_initialization(self):
ctrait = CTrait()
validate = unittest.mock.MagicMock(return_value="baz")
ctrait.set_validate(validate)
class Foo(HasTraits):
bar = ctrait
bar_changed = List
def _bar_changed(self, new):
self.bar_changed.append(new)
foo = Foo()
self.assertEqual(len(foo.bar_changed), 0)
foo.bar = 1
validate.assert_called_once_with(foo, "bar", 1)
self.assertEqual(foo.bar, "baz")
self.assertEqual(len(foo.bar_changed), 1)
self.assertEqual(foo.bar_changed[0], "baz")
def test_failed_attribute_access(self):
# Double-underscore names are special-cased.
non_dunder_names = [
"non_existent",
"__non_existent",
"non_existent__",
"_non_existent_",
"__a__b_",
"_",
]
dunder_names = [
"__package__",
"__a__",
"____",
"___",
"__",
]
ctrait = CTrait(0)
for name in non_dunder_names:
with self.subTest(name=name):
self.assertIsNone(getattr(ctrait, name))
for name in dunder_names:
with self.subTest(name=name):
with self.assertRaises(AttributeError):
getattr(ctrait, name)
def test_exception_from_attribute_access(self):
# Regression test for enthought/traits#946.
# Danger: we're (temporarily) mutating global state here! Check that
# we're not touching an attribute that actually exists.
self.assertFalse(hasattr(CTrait, "badattr_test"))
CTrait.badattr_test = property(lambda self: 1 / 0)
try:
ctrait = CTrait(0)
with self.assertRaises(ZeroDivisionError):
ctrait.badattr_test
finally:
del CTrait.badattr_test
class TestCTraitNotifiers(unittest.TestCase):
""" Test calling trait notifiers and object notifiers. """
def test_notifiers_empty(self):
class Foo(HasTraits):
x = Int()
foo = Foo(x=1)
x_ctrait = foo.trait("x")
self.assertEqual(x_ctrait._notifiers(True), [])
def test_notifiers_on_trait(self):
class Foo(HasTraits):
x = Int()
def _x_changed(self):
pass
foo = Foo(x=1)
x_ctrait = foo.trait("x")
tnotifiers = x_ctrait._notifiers(True)
self.assertEqual(len(tnotifiers), 1)
notifier, = tnotifiers
self.assertEqual(notifier.handler, Foo._x_changed)