# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

# Imports

import unittest
import warnings

from traits.api import (
    Any,
    Bytes,
    CBytes,
    CFloat,
    CInt,
    ComparisonMode,
    Color,
    Delegate,
    Float,
    Font,
    HasTraits,
    Instance,
    Int,
    List,
    Range,
    RGBColor,
    Str,
    This,
    Trait,
    TraitError,
    TraitList,
    TraitPrefixList,
    TraitPrefixMap,
    Tuple,
    pop_exception_handler,
    push_exception_handler,
)
from traits.testing.optional_dependencies import requires_traitsui

# Base unit test classes:


class BaseTest(object):
    """ Mixin providing a generic assignment test for a ``value`` trait.

    Concrete subclasses also derive from ``unittest.TestCase`` and must
    provide ``self.obj`` (set in ``setUp``) plus the class attributes
    ``_default_value``, ``_good_values``, ``_mapped_values`` and
    ``_bad_values``.
    """

    def assign(self, value):
        """ Assign *value* to ``self.obj.value`` (used with assertRaises). """
        self.obj.value = value

    def coerce(self, value):
        """ Return the value ``obj.value`` should hold after assigning
        *value*. The default expects the value to be stored unchanged.
        """
        return value

    def test_assignment(self):
        obj = self.obj

        # Validate default value
        self.assertEqual(obj.value, self._default_value)

        # Validate all legal values
        for i, value in enumerate(self._good_values):
            obj.value = value
            self.assertEqual(obj.value, self.coerce(value))

            # If there's a defined mapped value for this index, the shadow
            # ``value_`` trait must hold it.
            if i < len(self._mapped_values):
                self.assertEqual(obj.value_, self._mapped_values[i])

        # Validate correct behavior for illegal values
        for value in self._bad_values:
            self.assertRaises(TraitError, self.assign, value)


class test_base2(unittest.TestCase):
    """ Base class with helpers for checking several traits of one object.

    Subclasses set ``self.obj`` in ``setUp`` and call ``check_values``.
    """

    def indexed_assign(self, list, index, value):
        list[index] = value

    def indexed_range_assign(self, list, index1, index2, value):
        list[index1:index2] = value

    def extended_slice_assign(self, list, index1, index2, step, value):
        list[index1:index2:step] = value

    # This avoids using a method name that contains 'test' so that this is not
    # called by the tester directly.
    def check_values(
        self,
        name,
        default_value,
        good_values,
        bad_values,
        actual_values=None,
        mapped_values=None,
    ):
        """ Check default, legal and illegal assignments to trait *name*.

        Parameters
        ----------
        name : str
            Name of the trait on ``self.obj``.
        default_value
            Expected initial value of the trait.
        good_values : sequence
            Values that must be accepted.
        bad_values : sequence
            Values whose assignment must raise ``TraitError``.
        actual_values : sequence, optional
            Expected stored value for each good value (defaults to
            *good_values* itself).
        mapped_values : sequence, optional
            Expected value of the shadow ``name + "_"`` trait for each good
            value; not checked when omitted.
        """
        obj = self.obj

        # Make sure the default value is correct:
        self.assertEqual(getattr(obj, name), default_value)

        # Iterate over all legal values being tested:
        if actual_values is None:
            actual_values = good_values
        for i, value in enumerate(good_values):
            setattr(obj, name, value)
            self.assertEqual(getattr(obj, name), actual_values[i])
            if mapped_values is not None:
                self.assertEqual(getattr(obj, name + "_"), mapped_values[i])

        # Iterate over all illegal values being tested:
        for value in bad_values:
            self.assertRaises(TraitError, setattr, obj, name, value)


class AnyTrait(HasTraits):
    value = Any


class AnyTraitTest(BaseTest, unittest.TestCase):

    def setUp(self):
        self.obj = AnyTrait()

    _default_value = None
    _good_values = [10.0, b"ten", "ten", [10], {"ten": 10}, (10,), None, 1j]
    _mapped_values = []
    _bad_values = []


class CoercibleIntTrait(HasTraits):
    value = CInt(99)


class IntTrait(HasTraits):
    value = Int(99)


class CoercibleIntTest(AnyTraitTest):

    def setUp(self):
        self.obj = CoercibleIntTrait()

    _default_value = 99
    _good_values = [10, -10, 10.1, -10.1, "10", "-10", b"10", b"-10"]
    _bad_values = [
        "10L", "-10L", "10.1", "-10.1",
        b"10L", b"-10L", b"10.1", b"-10.1",
        "ten", b"ten", [10], {"ten": 10}, (10,), None, 1j,
    ]

    def coerce(self, value):
        try:
            return int(value)
        except ValueError:
            # e.g. "10.1": not an integer literal, but a valid float.
            return int(float(value))


class IntTest(AnyTraitTest):

    def setUp(self):
        self.obj = IntTrait()

    _default_value = 99
    _good_values = [10, -10]
    _bad_values = [
        "ten", b"ten", [10], {"ten": 10}, (10,), None, 1j,
        10.1, -10.1,
        "10L", "-10L", "10.1", "-10.1",
        b"10L", b"-10L", b"10.1", b"-10.1",
        "10", "-10", b"10", b"-10",
    ]

    # NumPy integer scalars are also legal Int values, when NumPy is present.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        _good_values.extend(
            [
                np.int64(10), np.int64(-10),
                np.int32(10), np.int32(-10),
                np.int_(10), np.int_(-10),
            ]
        )

    def coerce(self, value):
        try:
            return int(value)
        except ValueError:
            # e.g. "10.1": not an integer literal, but a valid float.
            return int(float(value))


class CoercibleFloatTrait(HasTraits):
    value = CFloat(99.0)


class FloatTrait(HasTraits):
    value = Float(99.0)


class CoercibleFloatTest(AnyTraitTest):
    def setUp(self):
        self.obj = CoercibleFloatTrait()

    _default_value = 99.0
    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10.1", "-10.1",
        b"10", b"-10", b"10.1", b"-10.1",
    ]
    _bad_values = [
        "10L", "-10L", b"10L", b"-10L",
        "ten", b"ten", [10], {"ten": 10}, (10,), None, 1j,
    ]

    def coerce(self, value):
        return float(value)


class FloatTest(AnyTraitTest):
    def setUp(self):
        self.obj = FloatTrait()

    _default_value = 99.0
    _good_values = [10, -10, 10.1, -10.1]
    _bad_values = [
        "ten", b"ten", [10], {"ten": 10}, (10,), None, 1j,
        "10", "-10", "10L", "-10L", "10.1", "-10.1",
        b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1",
    ]

    def coerce(self, value):
        return float(value)
# Trait whose default is a complex number; assigned values are expected to
# read back as ``complex(value)`` (see ImaginaryValueTest.coerce):


class ImaginaryValueTrait(HasTraits):
    value = Trait(99.0 - 99.0j)


class ImaginaryValueTest(AnyTraitTest):
    """ Assignments to a Trait whose default is a complex number. """

    _default_value = 99.0 - 99.0j
    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10.1", "-10.1",
        10j, 10 + 10j, 10 - 10j,
        10.1j, 10.1 + 10.1j, 10.1 - 10.1j,
        "10j", "10+10j", "10-10j",
    ]
    _bad_values = [b"10L", "-10L", "ten", [10], {"ten": 10}, (10,), None]

    def setUp(self):
        self.obj = ImaginaryValueTrait()

    def coerce(self, value):
        return complex(value)


class StringTrait(HasTraits):
    value = Trait("string")


class StringTest(AnyTraitTest):
    """ Every tested value is accepted and should read back as ``str()``. """

    _default_value = "string"
    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10L", "-10L", "10.1", "-10.1", "string",
        1j, [10], ["ten"], {"ten": 10}, (10,), None,
    ]
    _bad_values = []

    def setUp(self):
        self.obj = StringTrait()

    def coerce(self, value):
        return str(value)


class BytesTrait(HasTraits):
    value = Bytes(b"bytes")


class BytesTest(StringTest):
    """ A ``Bytes`` trait only accepts ``bytes`` objects. """

    _default_value = b"bytes"
    _good_values = [b"", b"10", b"-10"]
    _bad_values = [
        10, -10, 10.1,
        [b""], [b"bytes"], [0], {b"ten": b"10"}, (b"",),
        None, True, "", "string",
    ]

    def setUp(self):
        self.obj = BytesTrait()

    def coerce(self, value):
        return bytes(value)


class CoercibleBytesTrait(HasTraits):
    value = CBytes(b"bytes")


class CoercibleBytesTest(StringTest):
    """ A ``CBytes`` trait stores ``bytes(value)`` for convertible values. """

    _default_value = b"bytes"
    _good_values = [
        b"", b"10", b"-10",
        10, [10], (10,), {10}, {10: "foo"}, True,
    ]
    _bad_values = [
        "", "string", -10, 10.1,
        [b""], [b"bytes"],
        # Negative and too-large integers cannot be byte values.
        [-10], (-10,), {-10: "foo"}, {-10},
        [256], (256,), {256: "foo"}, {256},
        {b"ten": b"10"}, (b"",), None,
    ]

    def setUp(self):
        self.obj = CoercibleBytesTrait()

    def coerce(self, value):
        return bytes(value)


class EnumTrait(HasTraits):
    value = Trait([1, "one", 2, "two", 3, "three", 4.4, "four.four"])


class EnumTest(AnyTraitTest):
    """ A Trait built from a list only accepts the listed values. """

    _default_value = 1
    _good_values = [1, "one", 2, "two", 3, "three", 4.4, "four.four"]
    _bad_values = [0, "zero", 4, None]

    def setUp(self):
        self.obj = EnumTrait()


class MappedTrait(HasTraits):
    value = Trait("one", {"one": 1, "two": 2, "three": 3})


class MappedTest(AnyTraitTest):
    """ A mapped Trait accepts only the keys; ``value_`` holds the mapping. """

    _default_value = "one"
    _good_values = ["one", "two", "three"]
    _mapped_values = [1, 2, 3]
    _bad_values = ["four", 1, 2, 3, [1], (1,), {1: 1}, None]

    def setUp(self):
        self.obj = MappedTrait()


# Suppress DeprecationWarning from TraitPrefixList instantiation.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=DeprecationWarning)

    class PrefixListTrait(HasTraits):
        value = Trait("one", TraitPrefixList("one", "two", "three"))


class PrefixListTest(AnyTraitTest):
    """ Unambiguous prefixes are expanded to the full listed word. """

    _default_value = "one"
    _good_values = ["o", "on", "one", "tw", "two", "th", "thr", "thre", "three"]
    _bad_values = ["t", "one ", " two", 1, None]

    def setUp(self):
        self.obj = PrefixListTrait()

    def coerce(self, value):
        # The first two characters are enough to identify the full word
        # for every good value above.
        full_words = {"o": "one", "on": "one", "tw": "two", "th": "three"}
        return full_words[value[:2]]


# Suppress DeprecationWarning from TraitPrefixMap instantiation.
with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)

    class PrefixMapTrait(HasTraits):
        value = Trait("one", TraitPrefixMap({"one": 1, "two": 2, "three": 3}))


class PrefixMapTest(AnyTraitTest):
    def setUp(self):
        self.obj = PrefixMapTrait()

    _default_value = "one"
    _good_values = [
        "o", "on", "one", "tw", "two", "th", "thr", "thre", "three",
    ]
    # One mapped value per good value, so that the shadow ``value_`` trait is
    # checked for every prefix (including the full word "three").
    _mapped_values = [1, 1, 1, 2, 2, 3, 3, 3, 3]
    _bad_values = ["t", "one ", " two", 1, None]

    def coerce(self, value):
        return {"o": "one", "on": "one", "tw": "two", "th": "three"}[value[:2]]


# This tests a combination of Trait, a default, a mapping and a function.

def str_cast_to_int(object, name, value):
    """ A function that validates the value is a str and then converts
    it to an int using its length.

    Raises
    ------
    TraitError
        If *value* is not a ``str``.
    """
    if not isinstance(value, str):
        raise TraitError("Not a string!")
    return len(value)


class TraitWithMappingAndCallable(HasTraits):

    value = Trait(
        "white",
        {"white": 0, "red": 1, (0, 0, 0): 999},
        str_cast_to_int,
    )


class TestTraitWithMappingAndCallable(unittest.TestCase):
    """ Test that demonstrates a usage of Trait where TraitMap is used but it
    cannot be replaced with Map. The callable causes the key value to be
    changed to match the mapped value.

    e.g. this would not work:

        value = Union(
            Map({"white": 0, "red": 1, (0,0,0): 999}),
            NewTraitType(),
            default_value="white",
        )

    where NewTraitType is a subclass of TraitType whose ``validate`` method
    simply calls str_cast_to_int.
    """

    def test_trait_default(self):
        obj = TraitWithMappingAndCallable()

        # The value is not 'white' any more: it became len("white") == 5.
        self.assertEqual(obj.value, 5)
        self.assertEqual(obj.value_, 5)

    def test_trait_set_value_use_callable(self):
        obj = TraitWithMappingAndCallable(value="red")

        # The value is not 'red' any more (it became len("red") == 3):
        # the callable is used, not the mapping.
        self.assertEqual(obj.value, 3)
        self.assertEqual(obj.value_, 3)

    def test_trait_set_value_use_mapping(self):
        obj = TraitWithMappingAndCallable(value=(0, 0, 0))

        # Now this uses the mapping, and the value is the original one.
        self.assertEqual(obj.value, (0, 0, 0))
        self.assertEqual(obj.value_, 999)


# Classes without an explicit base class (historically "old-style" classes
# in Python 2; in Python 3 they are ordinary classes):


class OTraitTest1:
    pass


class OTraitTest2(OTraitTest1):
    pass


class OTraitTest3(OTraitTest2):
    pass


class OBadTraitTest:
    pass


otrait_test1 = OTraitTest1()


class OldInstanceTrait(HasTraits):
    value = Trait(otrait_test1)


class OldInstanceTest(AnyTraitTest):
    def setUp(self):
        self.obj = OldInstanceTrait()

    _default_value = otrait_test1
    _good_values = [
        otrait_test1,
        OTraitTest1(),
        OTraitTest2(),
        OTraitTest3(),
        None,
    ]
    _bad_values = [
        0,
        0.0,
        0j,
        OTraitTest1,
        OTraitTest2,
        OBadTraitTest(),
        b"bytes",
        "string",
        [otrait_test1],
        (otrait_test1,),
        {"data": otrait_test1},
    ]


# Classes explicitly deriving from ``object`` (historically "new-style"
# classes in Python 2):
class NTraitTest1(object):
    pass


class NTraitTest2(NTraitTest1):
    pass


class NTraitTest3(NTraitTest2):
    pass


class NBadTraitTest:
    pass


ntrait_test1 = NTraitTest1()


class NewInstanceTrait(HasTraits):
    value = Trait(ntrait_test1)


class NewInstanceTest(AnyTraitTest):
    def setUp(self):
        self.obj = NewInstanceTrait()

    _default_value = ntrait_test1
    _good_values = [
        ntrait_test1,
        NTraitTest1(),
        NTraitTest2(),
        NTraitTest3(),
        None,
    ]
    _bad_values = [
        0,
        0.0,
        0j,
        NTraitTest1,
        NTraitTest2,
        NBadTraitTest(),
        b"bytes",
        "string",
        [ntrait_test1],
        (ntrait_test1,),
        {"data": ntrait_test1},
    ]


class FactoryClass(HasTraits):
    pass


class ConsumerClass(HasTraits):
    x = Instance(FactoryClass, ())


class ConsumerSubclass(ConsumerClass):
    x = FactoryClass()


embedded_instance_trait = Trait(
    "", Str, Instance("traits.has_traits.HasTraits")
)


class Dummy(HasTraits):
    x = embedded_instance_trait
    xl = List(embedded_instance_trait)


class RegressionTest(unittest.TestCase):
    """ Check that fixed bugs stay fixed.
    """

    def test_factory_subclass_no_segfault(self):
        """ Test that we can provide an instance as a default in the definition
        of a subclass.
        """
        # There used to be a bug where this would segfault.
        obj = ConsumerSubclass()
        obj.x

    def test_trait_compound_instance(self):
        """ Test that a deferred Instance() embedded in a TraitCompound handler
        and then a list will not replace the validate method for the outermost
        trait.
        """
        # Pass through an instance in order to make the instance trait resolve
        # the class.
        d = Dummy()
        d.xl = [HasTraits()]
        d.x = "OK"


# Trait(using a function) that must be an odd integer:


def odd_integer(object, name, value):
    """ Validate that *value* is an odd, integer-valued number.

    Returns ``int(value)`` when ``value % 2 == 1``; raises ``TraitError``
    for anything else, including values that are not numbers at all.
    """
    try:
        # Rejects values that cannot be interpreted as a number.
        float(value)
        if (value % 2) == 1:
            return int(value)
    except (TypeError, ValueError):
        # e.g. float(None) / float(1j) raise TypeError, float("x") raises
        # ValueError, and "1" % 2 raises TypeError: all simply invalid.
        pass
    raise TraitError


class OddIntegerTrait(HasTraits):
    value = Trait(99, odd_integer)


class OddIntegerTest(AnyTraitTest):
    def setUp(self):
        self.obj = OddIntegerTrait()

    _default_value = 99
    _good_values = [
        1, 3, 5, 7, 9, 999999999,
        1.0, 3.0, 5.0, 7.0, 9.0, 999999999.0,
        -1, -3, -5, -7, -9, -999999999,
        -1.0, -3.0, -5.0, -7.0, -9.0, -999999999.0,
    ]
    _bad_values = [0, 2, -2, 1j, None, "1", [1], (1,), {1: 1}]


class NotifierTraits(HasTraits):
    # Every change to value1/value2 is counted twice: once by the generic
    # _anytrait_changed handler and once by the specific _valueN_changed one.
    value1 = Int
    value2 = Int
    value1_count = Int
    value2_count = Int

    def _anytrait_changed(self, trait_name, old, new):
        if trait_name == "value1":
            self.value1_count += 1
        elif trait_name == "value2":
            self.value2_count += 1

    def _value1_changed(self, old, new):
        self.value1_count += 1

    def _value2_changed(self, old, new):
        self.value2_count += 1


class NotifierTests(unittest.TestCase):
    def setUp(self):
        obj = self.obj = NotifierTraits()
        obj.value1 = 0
        obj.value2 = 0
        obj.value1_count = 0
        obj.value2_count = 0

    def tearDown(self):
        obj = self.obj
        obj.on_trait_change(self.on_value1_changed, "value1", remove=True)
        obj.on_trait_change(self.on_value2_changed, "value2", remove=True)
        obj.on_trait_change(self.on_anytrait_changed, remove=True)

    def on_anytrait_changed(self, object, trait_name, old, new):
        if trait_name == "value1":
            self.obj.value1_count += 1
        elif trait_name == "value2":
            self.obj.value2_count += 1

    def on_value1_changed(self):
        self.obj.value1_count += 1

    def on_value2_changed(self):
        self.obj.value2_count += 1

    def test_simple(self):
        obj = self.obj

        obj.value1 = 1
        self.assertEqual(obj.value1_count, 2)
        self.assertEqual(obj.value2_count, 0)

        obj.value2 = 1
        self.assertEqual(obj.value1_count, 2)
        self.assertEqual(obj.value2_count, 2)

    def test_complex(self):
        obj = self.obj

        obj.on_trait_change(self.on_value1_changed, "value1")
        obj.value1 = 1
        self.assertEqual(obj.value1_count, 3)
        self.assertEqual(obj.value2_count, 0)

        obj.on_trait_change(self.on_value2_changed, "value2")
        obj.value2 = 1
        self.assertEqual(obj.value1_count, 3)
        self.assertEqual(obj.value2_count, 3)

        obj.on_trait_change(self.on_anytrait_changed)

        obj.value1 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 3)

        # Assigning the same value again fires no notifications.
        obj.value1 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 3)

        obj.value2 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 7)

        obj.on_trait_change(self.on_value1_changed, "value1", remove=True)
        obj.value1 = 3
        self.assertEqual(obj.value1_count, 10)
        self.assertEqual(obj.value2_count, 7)

        obj.on_trait_change(self.on_value2_changed, "value2", remove=True)
        obj.value2 = 3
        self.assertEqual(obj.value1_count, 10)
        self.assertEqual(obj.value2_count, 10)

        obj.on_trait_change(self.on_anytrait_changed, remove=True)

        obj.value1 = 4
        self.assertEqual(obj.value1_count, 12)
        self.assertEqual(obj.value2_count, 10)

        obj.value2 = 4
        self.assertEqual(obj.value1_count, 12)
        self.assertEqual(obj.value2_count, 12)


class RaisesArgumentlessRuntimeError(HasTraits):
    x = Int(0)

    def _x_changed(self):
        raise RuntimeError


class TestRuntimeError(unittest.TestCase):
    def setUp(self):
        push_exception_handler(lambda *args: None, reraise_exceptions=True)

    def tearDown(self):
        pop_exception_handler()

    def test_runtime_error(self):
        f = RaisesArgumentlessRuntimeError()
        self.assertRaises(RuntimeError, setattr, f, "x", 5)


class DelegatedFloatTrait(HasTraits):
    value = Trait(99.0)


class DelegateTrait(HasTraits):
    value = Delegate("delegate")
    delegate = Trait(DelegatedFloatTrait())


class DelegateTrait2(DelegateTrait):
    delegate = Trait(DelegateTrait())


class DelegateTrait3(DelegateTrait):
    delegate = Trait(DelegateTrait2())


class DelegateTests(unittest.TestCase):
    def test_delegation(self):
        obj = DelegateTrait3()

        self.assertEqual(obj.value, 99.0)
        parent1 = obj.delegate
        parent2 = parent1.delegate
        parent3 = parent2.delegate
        parent3.value = 3.0
        self.assertEqual(obj.value, 3.0)
        parent2.value = 2.0
        self.assertEqual(obj.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        parent1.value = 1.0
        self.assertEqual(obj.value, 1.0)
        self.assertEqual(parent2.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        obj.value = 0.0
        self.assertEqual(obj.value, 0.0)
        self.assertEqual(parent1.value, 1.0)
        self.assertEqual(parent2.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        del obj.value
        self.assertEqual(obj.value, 1.0)
        del parent1.value
        self.assertEqual(obj.value, 2.0)
        self.assertEqual(parent1.value, 2.0)
        del parent2.value
        self.assertEqual(obj.value, 3.0)
        self.assertEqual(parent1.value, 3.0)
        self.assertEqual(parent2.value, 3.0)
        del parent3.value
        # Uncommenting the following line allows
        # the last assertions to pass. However, this
        # may not be intended behavior, so keeping
        # the line commented.
        # NOTE(review): the test is not marked as an expected failure, so
        # the assertions below presumably pass as written — confirm whether
        # the comment above is stale.
        # del parent2.value
        self.assertEqual(obj.value, 99.0)
        self.assertEqual(parent1.value, 99.0)
        self.assertEqual(parent2.value, 99.0)
        self.assertEqual(parent3.value, 99.0)
# Complex (i.e. 'composite') Traits tests:

# Make a TraitCompound handler that does not have a fast_validate so we can
# check for a particular regression.
slow = Trait(1, Range(1, 3), Range(-3, -1))
try:
    del slow.handler.fast_validate
except AttributeError:
    pass


# Suppress DeprecationWarnings from TraitPrefixList and TraitPrefixMap
with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)

    class complex_value(HasTraits):
        num1 = Trait(1, Range(1, 5), Range(-5, -1))
        num2 = Trait(
            1,
            Range(1, 5),
            TraitPrefixList("one", "two", "three", "four", "five"),
        )
        num3 = Trait(
            1,
            Range(1, 5),
            TraitPrefixMap(
                {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5}
            ),
        )
        num4 = Trait(1, Trait(1, Tuple, slow), 10)
        num5 = Trait(1, 10, Trait(1, Tuple, slow))


class test_complex_value(test_base2):
    def setUp(self):
        self.obj = complex_value()

    def test_num1(self):
        self.check_values(
            "num1",
            1,
            [1, 2, 3, 4, 5, -1, -2, -3, -4, -5],
            [
                0, 6, -6, "0", "6", "-6", 0.0, 6.0, -6.0,
                [1], (1,), {1: 1}, None,
            ],
            [1, 2, 3, 4, 5, -1, -2, -3, -4, -5],
        )

    def test_enum_exceptions(self):
        """ Check that enumerated values can be combined with nested
        TraitCompound handlers.
        """
        self.check_values(
            "num4", 1, [1, 2, 3, -3, -2, -1, 10, ()], [0, 4, 5, -5, -4, 11]
        )
        self.check_values(
            "num5", 1, [1, 2, 3, -3, -2, -1, 10, ()], [0, 4, 5, -5, -4, 11]
        )


class test_list_value(test_base2):
    def setUp(self):
        with self.assertWarns(DeprecationWarning):

            class list_value(HasTraits):
                # Trait definitions:
                list1 = Trait([2], TraitList(Trait([1, 2, 3, 4]), maxlen=4))
                list2 = Trait(
                    [2], TraitList(Trait([1, 2, 3, 4]), minlen=1, maxlen=4)
                )
                alist = List()

        self.obj = list_value()
        self.last_event = None

    def tearDown(self):
        del self.last_event

    def del_range(self, list, index1, index2):
        del list[index1:index2]

    def del_extended_slice(self, list, index1, index2, step):
        del list[index1:index2:step]

    def check_list(self, list):
        """ Exercise a TraitList whose items must be in [1, 2, 3, 4] and
        whose length may not exceed 4, starting from the value [2].
        """
        self.assertEqual(list, [2])
        self.assertEqual(len(list), 1)
        list.append(3)
        self.assertEqual(len(list), 2)
        list[1] = 2
        self.assertEqual(list[1], 2)
        self.assertEqual(len(list), 2)
        list[0] = 1
        self.assertEqual(list[0], 1)
        self.assertEqual(len(list), 2)
        self.assertRaises(TraitError, self.indexed_assign, list, 0, 5)
        self.assertRaises(TraitError, list.append, 5)
        self.assertRaises(TraitError, list.extend, [1, 2, 3])
        list.extend([3, 4])
        self.assertEqual(list, [1, 2, 3, 4])
        self.assertRaises(TraitError, list.append, 1)
        self.assertRaises(
            ValueError, self.extended_slice_assign, list, 0, 4, 2, [4, 5, 6]
        )
        del list[1]
        self.assertEqual(list, [1, 3, 4])
        del list[0]
        self.assertEqual(list, [3, 4])
        list[:0] = [1, 2]
        self.assertEqual(list, [1, 2, 3, 4])
        self.assertRaises(
            TraitError, self.indexed_range_assign, list, 0, 0, [1]
        )
        del list[0:3]
        self.assertEqual(list, [4])
        self.assertRaises(
            TraitError, self.indexed_range_assign, list, 0, 0, [4, 5]
        )

    def test_list1(self):
        self.check_list(self.obj.list1)

    def test_list2(self):
        self.check_list(self.obj.list2)
        self.assertRaises(TraitError, self.del_range, self.obj.list2, 0, 1)
        self.assertRaises(
            TraitError, self.del_extended_slice, self.obj.list2, 4, -5, -1
        )

    def assertLastTraitListEventEqual(self, index, removed, added):
        self.assertEqual(self.last_event.index, index)
        self.assertEqual(self.last_event.removed, removed)
        self.assertEqual(self.last_event.added, added)

    def test_trait_list_event(self):
        """ Record TraitListEvent behavior.
        """
        self.obj.alist = [1, 2, 3, 4]
        self.obj.on_trait_change(self._record_trait_list_event, "alist_items")
        del self.obj.alist[0]
        self.assertLastTraitListEventEqual(0, [1], [])
        self.obj.alist.append(5)
        self.assertLastTraitListEventEqual(3, [], [5])
        self.obj.alist[0:2] = [6, 7]
        self.assertLastTraitListEventEqual(0, [2, 3], [6, 7])
        self.obj.alist[:2] = [4, 5]
        self.assertLastTraitListEventEqual(0, [6, 7], [4, 5])
        self.obj.alist[0:2:1] = [8, 9]
        self.assertLastTraitListEventEqual(0, [4, 5], [8, 9])
        self.obj.alist[0:2:1] = [8, 9]
        # If list values stay the same, a new TraitListEvent will be generated.
        self.assertLastTraitListEventEqual(0, [8, 9], [8, 9])
        old_event = self.last_event
        self.obj.alist[4:] = []
        # If no structural change, NO new TraitListEvent will be generated.
        self.assertIs(self.last_event, old_event)
        self.obj.alist[0:4:2] = [10, 11]
        self.assertLastTraitListEventEqual(
            slice(0, 3, 2), [8, 4], [10, 11]
        )
        del self.obj.alist[1:4:2]
        self.assertLastTraitListEventEqual(slice(1, 4, 2), [9, 5], [])
        self.obj.alist = [1, 2, 3, 4]
        del self.obj.alist[2:4]
        self.assertLastTraitListEventEqual(2, [3, 4], [])
        self.obj.alist[:0] = [5, 6, 7, 8]
        self.assertLastTraitListEventEqual(0, [], [5, 6, 7, 8])
        del self.obj.alist[:2]
        self.assertLastTraitListEventEqual(0, [5, 6], [])
        del self.obj.alist[0:2]
        self.assertLastTraitListEventEqual(0, [7, 8], [])
        del self.obj.alist[:]
        self.assertLastTraitListEventEqual(0, [1, 2], [])

    def _record_trait_list_event(self, object, name, old, new):
        self.last_event = new


class ThisDummy(HasTraits):
    allows_none = This()
    disallows_none = This(allow_none=False)


class TestThis(unittest.TestCase):
    def test_this_none(self):
        d = ThisDummy()
        self.assertIsNone(d.allows_none)
        d.allows_none = None
        d.allows_none = ThisDummy()
        self.assertIsNotNone(d.allows_none)
        d.allows_none = None
        self.assertIsNone(d.allows_none)

        # Still starts out as None, unavoidably.
        self.assertIsNone(d.disallows_none)
        d.disallows_none = ThisDummy()
        self.assertIsNotNone(d.disallows_none)
        with self.assertRaises(TraitError):
            d.disallows_none = None
        self.assertIsNotNone(d.disallows_none)

    def test_this_other_class(self):
        d = ThisDummy()
        with self.assertRaises(TraitError):
            d.allows_none = object()
        self.assertIsNone(d.allows_none)


class ComparisonModeTests(unittest.TestCase):
    def _check_event_counts(self, has_bar, expected_counts):
        """ Check the change notifications fired by assignments to ``bar``.

        Listens for changes to ``has_bar.bar``, then assigns, in order: a
        list, the *same* list object again, an equal but distinct list, and
        an unequal list. After the i-th assignment the total number of
        notifications received must equal ``expected_counts[i]``.
        """
        events = []
        has_bar.on_trait_change(lambda: events.append(None), "bar")

        some_list = [1, 2, 3]
        new_values = [some_list, some_list, [1, 2, 3], [4, 5, 6]]
        # Guard against zip() silently truncating a mistyped expectation.
        self.assertEqual(len(expected_counts), len(new_values))

        self.assertEqual(len(events), 0)
        for step, (value, expected) in enumerate(
            zip(new_values, expected_counts)
        ):
            has_bar.bar = value
            self.assertEqual(
                len(events), expected, msg="after assignment {}".format(step)
            )

    def _make_rich_compare_class(self, rich_compare):
        """ Define a HasTraits class whose ``bar`` trait uses the deprecated
        ``rich_compare`` metadata, check that exactly one DeprecationWarning
        attributed to this module is emitted, and return the class.
        """
        with warnings.catch_warnings(record=True) as warn_msgs:
            warnings.simplefilter("always", DeprecationWarning)

            class OldRichCompare(HasTraits):
                bar = Trait(rich_compare=rich_compare)

        # Check for a DeprecationWarning.
        self.assertEqual(len(warn_msgs), 1)
        warn_msg = warn_msgs[0]
        self.assertIs(warn_msg.category, DeprecationWarning)
        self.assertIn(
            "'rich_compare' metadata has been deprecated",
            str(warn_msg.message)
        )
        _, _, this_module = __name__.rpartition(".")
        self.assertIn(this_module, warn_msg.filename)
        return OldRichCompare

    def test_comparison_mode_none(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.none)

        # Every assignment notifies, even of the very same object.
        self._check_event_counts(HasComparisonMode(), [1, 2, 3, 4])

    def test_comparison_mode_identity(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.identity)

        # Only assigning a different object notifies.
        self._check_event_counts(HasComparisonMode(), [1, 1, 2, 3])

    def test_comparison_mode_equality(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.equality)

        # Only assigning an unequal value notifies.
        self._check_event_counts(HasComparisonMode(), [1, 1, 1, 2])

    def test_rich_compare_false(self):
        OldRichCompare = self._make_rich_compare_class(rich_compare=False)

        # Behaviour matches comparison_mode=ComparisonMode.identity.
        self._check_event_counts(OldRichCompare(), [1, 1, 2, 3])

    def test_rich_compare_true(self):
        OldRichCompare = self._make_rich_compare_class(rich_compare=True)

        # Behaviour matches comparison_mode=ComparisonMode.equality.
        self._check_event_counts(OldRichCompare(), [1, 1, 1, 2])


@requires_traitsui
class TestDeprecatedTraits(unittest.TestCase):

    def test_color_deprecated(self):
        with self.assertWarnsRegex(DeprecationWarning, "'Color' in 'traits'"):
            Color()

    def test_rgb_color_deprecated(self):
        with self.assertWarnsRegex(DeprecationWarning,
                                   "'RGBColor' in 'traits'"):
            RGBColor()

    def test_font_deprecated(self):
        with self.assertWarnsRegex(DeprecationWarning, "'Font' in 'traits'"):
            Font()