1# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only under
6# the conditions described in the aforementioned license. The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8#
9# Thanks for using Enthought open source!
10
11#  Imports
12
13import unittest
14import warnings
15
16from traits.api import (
17    Any,
18    Bytes,
19    CBytes,
20    CFloat,
21    CInt,
22    ComparisonMode,
23    Color,
24    Delegate,
25    Float,
26    Font,
27    HasTraits,
28    Instance,
29    Int,
30    List,
31    Range,
32    RGBColor,
33    Str,
34    This,
35    Trait,
36    TraitError,
37    TraitList,
38    TraitPrefixList,
39    TraitPrefixMap,
40    Tuple,
41    pop_exception_handler,
42    push_exception_handler,
43)
44from traits.testing.optional_dependencies import requires_traitsui
45
46#  Base unit test classes:
47
48
class BaseTest(object):
    """Mixin providing a generic assignment test for a single ``value`` trait.

    Concrete subclasses must also inherit from ``unittest.TestCase``, create
    ``self.obj`` in ``setUp`` and define these class attributes:

    ``_default_value``
        Expected value of ``obj.value`` before any assignment.
    ``_good_values``
        Values that must be accepted; afterwards ``obj.value`` must equal
        ``self.coerce(value)``.
    ``_mapped_values``
        Expected values of ``obj.value_``, matched by position with
        ``_good_values``. It may be shorter (or empty); positions without an
        entry are not checked.
    ``_bad_values``
        Values whose assignment must raise ``TraitError``.
    """

    def assign(self, value):
        """Assign *value* to the trait under test (callable for assertRaises)."""
        self.obj.value = value

    def coerce(self, value):
        """Return the value ``obj.value`` should hold after assigning *value*.

        The default is the identity; subclasses testing coercing traits
        override it.
        """
        return value

    def test_assignment(self):
        obj = self.obj

        # Validate default value
        self.assertEqual(obj.value, self._default_value)

        # Validate all legal values. subTest reports which value failed.
        for i, value in enumerate(self._good_values):
            with self.subTest(value=value):
                obj.value = value
                self.assertEqual(obj.value, self.coerce(value))

                # If there's a mapped value defined for this position, the
                # shadow trait ``value_`` must hold it.
                if i < len(self._mapped_values):
                    self.assertEqual(obj.value_, self._mapped_values[i])

        # Validate correct behavior for illegal values
        for value in self._bad_values:
            with self.subTest(value=value):
                self.assertRaises(TraitError, self.assign, value)
75
76
class test_base2(unittest.TestCase):
    """Base ``TestCase`` with list-mutation helpers and ``check_values``.

    Subclasses create ``self.obj`` in ``setUp`` and call ``check_values`` for
    each named trait they want to exercise.
    """

    def indexed_assign(self, list, index, value):
        """Perform ``list[index] = value`` (callable form for assertRaises)."""
        list[index] = value

    def indexed_range_assign(self, list, index1, index2, value):
        """Perform ``list[index1:index2] = value``."""
        list[index1:index2] = value

    def extended_slice_assign(self, list, index1, index2, step, value):
        """Perform ``list[index1:index2:step] = value``."""
        list[index1:index2:step] = value

    # This avoids using a method name that contains 'test' so that this is not
    # called by the tester directly.
    def check_values(
        self,
        name,
        default_value,
        good_values,
        bad_values,
        actual_values=None,
        mapped_values=None,
    ):
        """Check default, accepted and rejected values of trait *name*.

        Parameters
        ----------
        name : str
            Trait name on ``self.obj``.
        default_value
            Expected initial value.
        good_values : sequence
            Values that must be accepted.
        bad_values : sequence
            Values whose assignment must raise ``TraitError``.
        actual_values : sequence, optional
            Expected stored value for each good value, by position.
            Defaults to ``good_values`` itself.
        mapped_values : sequence, optional
            Expected value of the shadow trait ``name + "_"`` for each good
            value, by position. Not checked when None.
        """
        obj = self.obj

        # Make sure the default value is correct:
        self.assertEqual(getattr(obj, name), default_value)

        # Iterate over all legal values being tested:
        if actual_values is None:
            actual_values = good_values
        for i, value in enumerate(good_values):
            setattr(obj, name, value)
            self.assertEqual(getattr(obj, name), actual_values[i])
            if mapped_values is not None:
                self.assertEqual(getattr(obj, name + "_"), mapped_values[i])

        # Iterate over all illegal values being tested:
        for value in bad_values:
            self.assertRaises(TraitError, setattr, obj, name, value)
120
121
class AnyTrait(HasTraits):
    """Model whose single ``value`` trait is an ``Any`` trait."""

    value = Any
124
125
class AnyTraitTest(BaseTest, unittest.TestCase):
    """``Any`` stores every value unchanged and rejects nothing."""

    # Initial value expected on a fresh AnyTrait.
    _default_value = None

    # Heterogeneous sample of values that must be stored as-is.
    _good_values = [10.0, b"ten", "ten", [10], {"ten": 10}, (10,), None, 1j]

    # No shadow (``value_``) values are checked.
    _mapped_values = []

    # Nothing is expected to be rejected.
    _bad_values = []

    def setUp(self):
        self.obj = AnyTrait()
135
136
class CoercibleIntTrait(HasTraits):
    """Model with a coercing integer trait (``CInt``) defaulting to 99."""

    value = CInt(99)
139
140
class IntTrait(HasTraits):
    """Model with a strict integer trait (``Int``) defaulting to 99."""

    value = Int(99)
143
144
class CoercibleIntTest(AnyTraitTest):
    """``CInt`` converts numbers and integer strings/bytes with ``int()``.

    Strings that are not plain integer literals ("10.1", "10L", ...) are
    rejected.
    """

    def setUp(self):
        self.obj = CoercibleIntTrait()

    _default_value = 99
    _good_values = [
        10,
        -10,
        10.1,
        -10.1,
        "10",
        "-10",
        b"10",
        b"-10",
    ]
    _bad_values = [
        "10L",
        "-10L",
        "10.1",
        "-10.1",
        b"10L",
        b"-10L",
        b"10.1",
        b"-10.1",
        "ten",
        b"ten",
        [10],
        {"ten": 10},
        (10,),
        None,
        1j,
    ]

    def coerce(self, value):
        """Expected stored value: the integer conversion of *value*."""
        try:
            return int(value)
        except (TypeError, ValueError):
            # Fall back via float for inputs int() refuses directly
            # (e.g. the string "10.0").
            return int(float(value))
184
185
class IntTest(AnyTraitTest):
    """``Int`` accepts only genuine integers.

    NumPy integer scalars are also tested when NumPy is installed. Floats and
    numeric strings/bytes are rejected rather than coerced.
    """

    def setUp(self):
        self.obj = IntTrait()

    _default_value = 99
    _good_values = [10, -10]
    _bad_values = [
        "ten",
        b"ten",
        [10],
        {"ten": 10},
        (10,),
        None,
        1j,
        10.1,
        -10.1,
        "10L",
        "-10L",
        "10.1",
        "-10.1",
        b"10L",
        b"-10L",
        b"10.1",
        b"-10.1",
        "10",
        "-10",
        b"10",
        b"-10",
    ]

    # NumPy is optional: only exercise NumPy integer scalars when available.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        _good_values.extend(
            [
                np.int64(10),
                np.int64(-10),
                np.int32(10),
                np.int32(-10),
                np.int_(10),
                np.int_(-10),
            ]
        )

    def coerce(self, value):
        """Expected stored value: the integer conversion of *value*."""
        try:
            return int(value)
        except (TypeError, ValueError):
            # Fall back via float for inputs int() refuses directly.
            return int(float(value))
238
239
class CoercibleFloatTrait(HasTraits):
    """Model with a coercing float trait (``CFloat``) defaulting to 99.0."""

    value = CFloat(99.0)
242
243
class FloatTrait(HasTraits):
    """Model with a strict float trait (``Float``) defaulting to 99.0."""

    value = Float(99.0)
246
247
class CoercibleFloatTest(AnyTraitTest):
    """``CFloat`` converts numbers and numeric strings/bytes with ``float()``."""

    _default_value = 99.0

    # Numbers, numeric str and numeric bytes are all accepted.
    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10.1", "-10.1",
        b"10", b"-10", b"10.1", b"-10.1",
    ]

    # Python-2 style long literals, words, containers, None and complex fail.
    _bad_values = [
        "10L", "-10L", b"10L", b"-10L",
        "ten", b"ten",
        [10], {"ten": 10}, (10,),
        None, 1j,
    ]

    def setUp(self):
        self.obj = CoercibleFloatTrait()

    def coerce(self, value):
        return float(value)
283
284
class FloatTest(AnyTraitTest):
    """``Float`` accepts ints and floats but never strings or bytes."""

    _default_value = 99.0

    _good_values = [10, -10, 10.1, -10.1]

    _bad_values = [
        "ten", b"ten",
        [10], {"ten": 10}, (10,),
        None, 1j,
        "10", "-10", "10L", "-10L", "10.1", "-10.1",
        b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1",
    ]

    def setUp(self):
        self.obj = FloatTrait()

    def coerce(self, value):
        return float(value)
315
316
#  Trait whose values are coerced to 'complex' numbers:
318
319
class ImaginaryValueTrait(HasTraits):
    """Model whose ``value`` trait is inferred from a complex default."""

    value = Trait(99.0 - 99.0j)
322
323
class ImaginaryValueTest(AnyTraitTest):
    """Numbers and numeric strings are stored as ``complex(value)``."""

    _default_value = 99.0 - 99.0j

    _good_values = [
        # Real numbers and their string forms.
        10, -10, 10.1, -10.1,
        "10", "-10", "10.1", "-10.1",
        # Complex numbers and their string forms.
        10j, 10 + 10j, 10 - 10j,
        10.1j, 10.1 + 10.1j, 10.1 - 10.1j,
        "10j", "10+10j", "10-10j",
    ]

    _bad_values = [b"10L", "-10L", "ten", [10], {"ten": 10}, (10,), None]

    def setUp(self):
        self.obj = ImaginaryValueTrait()

    def coerce(self, value):
        return complex(value)
352
353
class StringTrait(HasTraits):
    """Model whose ``value`` trait is inferred from a ``str`` default."""

    value = Trait("string")
356
357
class StringTest(AnyTraitTest):
    """Every assigned value is expected to be stored as ``str(value)``."""

    _default_value = "string"

    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10L", "-10L", "10.1", "-10.1",
        "string",
        1j,
        [10], ["ten"], {"ten": 10}, (10,),
        None,
    ]

    _bad_values = []

    def setUp(self):
        self.obj = StringTrait()

    def coerce(self, value):
        return str(value)
387
388
class BytesTrait(HasTraits):
    """Model with a strict ``Bytes`` trait defaulting to ``b"bytes"``."""

    value = Bytes(b"bytes")
391
392
class BytesTest(StringTest):
    """``Bytes`` accepts only bytes objects; str, numbers, bool, None and
    containers are rejected."""

    _default_value = b"bytes"

    _good_values = [b"", b"10", b"-10"]

    _bad_values = [
        10, -10, 10.1,
        [b""], [b"bytes"], [0],
        {b"ten": b"10"},
        (b"",),
        None,
        True,
        "", "string",
    ]

    def setUp(self):
        self.obj = BytesTrait()

    def coerce(self, value):
        return bytes(value)
417
418
class CoercibleBytesTrait(HasTraits):
    """Model with a coercing ``CBytes`` trait defaulting to ``b"bytes"``."""

    value = CBytes(b"bytes")
421
422
class CoercibleBytesTest(StringTest):
    """``CBytes`` stores ``bytes(value)``.

    Accepted: bytes, non-negative ints, and iterables of ints in 0..255.
    Rejected: str, negative or too-large ints, floats, and iterables of
    bytes.
    """

    _default_value = b"bytes"

    _good_values = [
        b"", b"10", b"-10",
        10,
        [10], (10,), {10}, {10: "foo"},
        True,
    ]

    _bad_values = [
        "", "string",
        -10, 10.1,
        [b""], [b"bytes"],
        # Byte values below 0 ...
        [-10], (-10,), {-10: "foo"}, {-10},
        # ... and above 255 cannot be converted.
        [256], (256,), {256: "foo"}, {256},
        {b"ten": b"10"},
        (b"",),
        None,
    ]

    def setUp(self):
        self.obj = CoercibleBytesTrait()

    def coerce(self, value):
        return bytes(value)
462
463
class EnumTrait(HasTraits):
    """Model whose ``value`` trait is an enumeration given as a list; the
    first entry is the default."""

    value = Trait([1, "one", 2, "two", 3, "three", 4.4, "four.four"])
466
467
class EnumTest(AnyTraitTest):
    """Only the enumerated items are accepted."""

    _default_value = 1
    # Exactly the items passed to Trait() in EnumTrait.
    _good_values = [1, "one", 2, "two", 3, "three", 4.4, "four.four"]
    _bad_values = [0, "zero", 4, None]

    def setUp(self):
        self.obj = EnumTrait()
476
477
class MappedTrait(HasTraits):
    """Model whose ``value`` accepts the dict keys; the shadow ``value_``
    holds the corresponding dict value."""

    value = Trait("one", {"one": 1, "two": 2, "three": 3})
480
481
class MappedTest(AnyTraitTest):
    """Keys are accepted and mapped; the mapped values themselves are not."""

    _default_value = "one"
    _good_values = ["one", "two", "three"]
    # value_ after assigning each good value, by position.
    _mapped_values = [1, 2, 3]
    _bad_values = ["four", 1, 2, 3, [1], (1,), {1: 1}, None]

    def setUp(self):
        self.obj = MappedTrait()
490
491
492# Suppress DeprecationWarning from TraitPrefixList instantiation.
with warnings.catch_warnings():
    # Constructing TraitPrefixList emits a DeprecationWarning; silence it for
    # this class definition only.
    warnings.simplefilter("ignore", category=DeprecationWarning)

    class PrefixListTrait(HasTraits):
        """Model whose ``value`` accepts unambiguous prefixes of
        "one", "two" and "three"."""

        value = Trait("one", TraitPrefixList("one", "two", "three"))
498
499
class PrefixListTest(AnyTraitTest):
    """Prefixes are expanded to the full choice; ambiguous or padded
    strings are rejected."""

    _default_value = "one"
    _good_values = [
        "o", "on", "one",
        "tw", "two",
        "th", "thr", "thre", "three",
    ]
    _bad_values = ["t", "one ", " two", 1, None]

    def setUp(self):
        self.obj = PrefixListTrait()

    def coerce(self, value):
        # At most two leading characters identify each choice uniquely.
        completions = {"o": "one", "on": "one", "tw": "two", "th": "three"}
        return completions[value[:2]]
520
521
522# Suppress DeprecationWarning from TraitPrefixMap instantiation.
with warnings.catch_warnings():
    # Constructing TraitPrefixMap emits a DeprecationWarning; silence it for
    # this class definition only.
    warnings.simplefilter("ignore", category=DeprecationWarning)

    class PrefixMapTrait(HasTraits):
        """Model whose ``value`` accepts prefixes of the dict keys and maps
        them to the dict values."""

        value = Trait("one", TraitPrefixMap({"one": 1, "two": 2, "three": 3}))
528
529
class PrefixMapTest(AnyTraitTest):
    """``TraitPrefixMap`` accepts unambiguous key prefixes.

    ``value`` holds the completed key and ``value_`` holds the mapped integer.
    """

    def setUp(self):
        self.obj = PrefixMapTrait()

    _default_value = "one"
    _good_values = [
        "o",
        "on",
        "one",
        "tw",
        "two",
        "th",
        "thr",
        "thre",
        "three",
    ]
    # One mapped value per entry of _good_values. BaseTest only checks the
    # positions that have an entry here, so a short list silently skips
    # values.
    _mapped_values = [1, 1, 1, 2, 2, 3, 3, 3, 3]
    _bad_values = ["t", "one ", " two", 1, None]

    def coerce(self, value):
        """Expected stored value: the full key the prefix expands to."""
        return {"o": "one", "on": "one", "tw": "two", "th": "three"}[value[:2]]
551
552
# This tests a combination of Trait with a default value, a mapping and a
# validation function.
554
def str_cast_to_int(object, name, value):
    """Trait validator: accept only ``str`` values and store their length.

    Parameters
    ----------
    object, name
        Standard validator arguments (unused).
    value
        The value being assigned.

    Returns
    -------
    int
        ``len(value)`` for a string value.

    Raises
    ------
    TraitError
        If *value* is not a ``str``.
    """
    if isinstance(value, str):
        return len(value)
    raise TraitError("Not a string!")
562
563
class TraitWithMappingAndCallable(HasTraits):
    """Model whose ``value`` combines a default, a mapping and a validator
    function (``str_cast_to_int``)."""

    value = Trait(
        "white",
        {"white": 0, "red": 1, (0, 0, 0): 999},
        str_cast_to_int,
    )
571
572
class TestTraitWithMappingAndCallable(unittest.TestCase):
    """ Test that demonstrates a usage of Trait where TraitMap is used but it
    cannot be replaced with Map. The callable changes the key value so that it
    matches the mapped value.

    For example, this would not work:

        value = Union(
            Map({"white": 0, "red": 1, (0,0,0): 999}),
            NewTraitType(),
            default_value="white",
        )

        where NewTraitType is a subclass of TraitType whose ``validate``
        simply calls str_cast_to_int.
    """

    def test_trait_default(self):
        model = TraitWithMappingAndCallable()

        # The default "white" went through the callable: len("white") == 5.
        for attr in ("value", "value_"):
            self.assertEqual(getattr(model, attr), 5)

    def test_trait_set_value_use_callable(self):
        model = TraitWithMappingAndCallable(value="red")

        # Even though "red" is a mapping key, the callable wins: len == 3.
        for attr in ("value", "value_"):
            self.assertEqual(getattr(model, attr), 3)

    def test_trait_set_value_use_mapping(self):
        model = TraitWithMappingAndCallable(value=(0, 0, 0))

        # A non-str key is handled by the mapping; the value is unchanged.
        self.assertEqual(model.value, (0, 0, 0))
        self.assertEqual(model.value_, 999)
611
612
# Classes declared without an explicit base ("old-style" classes in Python 2):
614
615
class OTraitTest1:
    """Root of the class hierarchy used by ``OldInstanceTrait``."""


class OTraitTest2(OTraitTest1):
    """Child of OTraitTest1."""


class OTraitTest3(OTraitTest2):
    """Grandchild of OTraitTest1."""


class OBadTraitTest:
    """Unrelated class whose instances must be rejected."""


otrait_test1 = OTraitTest1()
633
634
class OldInstanceTrait(HasTraits):
    """Model whose ``value`` trait is inferred from an OTraitTest1
    instance default."""

    value = Trait(otrait_test1)
637
638
class OldInstanceTest(AnyTraitTest):
    """Instances of OTraitTest1 or its subclasses, and None, are accepted.

    The classes themselves, unrelated instances and containers are rejected.
    """

    _default_value = otrait_test1

    _good_values = [
        otrait_test1,
        OTraitTest1(), OTraitTest2(), OTraitTest3(),
        None,
    ]

    _bad_values = [
        0, 0.0, 0j,
        OTraitTest1, OTraitTest2,
        OBadTraitTest(),
        b"bytes", "string",
        [otrait_test1], (otrait_test1,), {"data": otrait_test1},
    ]

    def setUp(self):
        self.obj = OldInstanceTrait()
664
665
# Classes declared with an explicit ``object`` base ("new-style" classes in
# Python 2):
class NTraitTest1(object):
    """Root of the class hierarchy used by ``NewInstanceTrait``."""


class NTraitTest2(NTraitTest1):
    """Child of NTraitTest1."""


class NTraitTest3(NTraitTest2):
    """Grandchild of NTraitTest1."""


class NBadTraitTest:
    """Unrelated class whose instances must be rejected."""


ntrait_test1 = NTraitTest1()
684
685
class NewInstanceTrait(HasTraits):
    """Model whose ``value`` trait is inferred from an NTraitTest1
    instance default."""

    value = Trait(ntrait_test1)
688
689
class NewInstanceTest(AnyTraitTest):
    """Instances of NTraitTest1 or its subclasses, and None, are accepted.

    The classes themselves, unrelated instances and containers are rejected.
    """

    _default_value = ntrait_test1

    _good_values = [
        ntrait_test1,
        NTraitTest1(), NTraitTest2(), NTraitTest3(),
        None,
    ]

    _bad_values = [
        0, 0.0, 0j,
        NTraitTest1, NTraitTest2,
        NBadTraitTest(),
        b"bytes", "string",
        [ntrait_test1], (ntrait_test1,), {"data": ntrait_test1},
    ]

    def setUp(self):
        self.obj = NewInstanceTrait()
715
716
class FactoryClass(HasTraits):
    """Empty HasTraits class used as an Instance trait's default."""


class ConsumerClass(HasTraits):
    # The empty args tuple requests a default FactoryClass instance built with
    # no arguments (per the traits ``Instance`` API).
    x = Instance(FactoryClass, ())


class ConsumerSubclass(ConsumerClass):
    # Redefines ``x`` with a bare instance; accessing it once segfaulted
    # (see RegressionTest.test_factory_subclass_no_segfault).
    x = FactoryClass()
727
728
# Compound trait: a string (default "") or a HasTraits instance, with the
# Instance class named by a dotted path and therefore resolved lazily.
embedded_instance_trait = Trait(
    "",
    Str,
    Instance("traits.has_traits.HasTraits"),
)


class Dummy(HasTraits):
    """Uses the compound trait both directly and as a List item trait."""

    x = embedded_instance_trait
    xl = List(embedded_instance_trait)
737
738
class RegressionTest(unittest.TestCase):
    """ Check that fixed bugs stay fixed.
    """

    def test_factory_subclass_no_segfault(self):
        """ Test that we can provide an instance as a default in the definition
        of a subclass.
        """
        # Merely reading the attribute used to crash the interpreter.
        consumer = ConsumerSubclass()
        consumer.x

    def test_trait_compound_instance(self):
        """ Test that a deferred Instance() embedded in a TraitCompound handler
        and then a list will not replace the validate method for the outermost
        trait.
        """
        dummy = Dummy()
        # Assigning through the List forces the deferred Instance to resolve
        # its class ...
        dummy.xl = [HasTraits()]
        # ... after which the outer compound trait must still accept a str.
        dummy.x = "OK"
761
762
#  Trait (using a validation function) that must be an odd integer:
764
765
def odd_integer(object, name, value):
    """Trait validator accepting odd integral numbers and storing them as int.

    Parameters
    ----------
    object, name
        Standard validator arguments (unused).
    value
        The value being assigned.

    Returns
    -------
    int
        ``int(value)`` when *value* is numeric and ``value % 2 == 1``
        (this holds for negative odd numbers too, since Python's ``%`` takes
        the sign of the divisor).

    Raises
    ------
    TraitError
        For anything else: even numbers, non-numbers, and values whose
        numeric probing fails.
    """
    try:
        # float() rejects non-numeric values (None, containers, complex).
        float(value)
        # Numeric strings pass float() but fail here with TypeError.
        if (value % 2) == 1:
            return int(value)
    except (TypeError, ValueError, ArithmeticError):
        # Not a usable number; fall through to the rejection below. A bare
        # ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        pass
    raise TraitError
774
775
class OddIntegerTrait(HasTraits):
    """Model whose ``value`` defaults to 99 and is validated by
    ``odd_integer``."""

    value = Trait(99, odd_integer)
778
779
class OddIntegerTest(AnyTraitTest):
    """Positive and negative odd numbers (int or integral float) are
    accepted; even numbers and non-numbers are rejected."""

    _default_value = 99

    _good_values = [
        # positive ints
        1, 3, 5, 7, 9, 999999999,
        # positive integral floats
        1.0, 3.0, 5.0, 7.0, 9.0, 999999999.0,
        # negative ints
        -1, -3, -5, -7, -9, -999999999,
        # negative integral floats
        -1.0, -3.0, -5.0, -7.0, -9.0, -999999999.0,
    ]

    _bad_values = [0, 2, -2, 1j, None, "1", [1], (1,), {1: 1}]

    def setUp(self):
        self.obj = OddIntegerTrait()
812
813
class NotifierTraits(HasTraits):
    """Counts changes to value1/value2 through two static handlers each.

    The handlers are the catch-all ``_anytrait_changed`` and the per-trait
    ``_<name>_changed``.
    """

    value1 = Int
    value2 = Int
    value1_count = Int
    value2_count = Int

    def _anytrait_changed(self, trait_name, old, new):
        # Only value1/value2 are counted; any other trait name is ignored.
        if trait_name not in ("value1", "value2"):
            return
        counter = trait_name + "_count"
        setattr(self, counter, getattr(self, counter) + 1)

    def _value1_changed(self, old, new):
        self.value1_count = self.value1_count + 1

    def _value2_changed(self, old, new):
        self.value2_count = self.value2_count + 1
831
832
class NotifierTests(unittest.TestCase):
    """Check how static and dynamically added change handlers combine.

    Each NotifierTraits change to value1/value2 already bumps the matching
    counter twice (``_anytrait_changed`` + ``_<name>_changed``). The methods
    below, when registered with ``on_trait_change``, add one more bump each.
    """

    def setUp(self):
        obj = self.obj = NotifierTraits()
        obj.value1 = 0
        obj.value2 = 0
        # Reset the counters so every test starts from zero.
        obj.value1_count = 0
        obj.value2_count = 0

    def tearDown(self):
        # Unregister any dynamic handlers a test may have left attached.
        obj = self.obj
        obj.on_trait_change(self.on_value1_changed, "value1", remove=True)
        obj.on_trait_change(self.on_value2_changed, "value2", remove=True)
        obj.on_trait_change(self.on_anytrait_changed, remove=True)

    def on_anytrait_changed(self, object, trait_name, old, new):
        """Dynamic catch-all handler: +1 to the matching counter."""
        if trait_name == "value1":
            self.obj.value1_count += 1
        elif trait_name == "value2":
            self.obj.value2_count += 1

    def on_value1_changed(self):
        """Dynamic value1 handler: +1 to value1_count."""
        self.obj.value1_count += 1

    def on_value2_changed(self):
        """Dynamic value2 handler: +1 to value2_count."""
        self.obj.value2_count += 1

    def test_simple(self):
        obj = self.obj

        # Only the two static handlers fire: +2 per change.
        obj.value1 = 1
        self.assertEqual(obj.value1_count, 2)
        self.assertEqual(obj.value2_count, 0)

        obj.value2 = 1
        self.assertEqual(obj.value1_count, 2)
        self.assertEqual(obj.value2_count, 2)

    def test_complex(self):
        obj = self.obj

        # Static handlers + dynamic per-trait handler: +3 per change.
        obj.on_trait_change(self.on_value1_changed, "value1")
        obj.value1 = 1
        self.assertEqual(obj.value1_count, 3)
        self.assertEqual(obj.value2_count, 0)

        obj.on_trait_change(self.on_value2_changed, "value2")
        obj.value2 = 1
        self.assertEqual(obj.value1_count, 3)
        self.assertEqual(obj.value2_count, 3)

        # Adding the dynamic catch-all handler: +4 per change.
        obj.on_trait_change(self.on_anytrait_changed)

        obj.value1 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 3)

        # Re-assigning an equal value triggers no handlers.
        obj.value1 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 3)

        obj.value2 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 7)

        # Removing the per-trait dynamic handlers: back to +3 per change.
        obj.on_trait_change(self.on_value1_changed, "value1", remove=True)
        obj.value1 = 3
        self.assertEqual(obj.value1_count, 10)
        self.assertEqual(obj.value2_count, 7)

        obj.on_trait_change(self.on_value2_changed, "value2", remove=True)
        obj.value2 = 3
        self.assertEqual(obj.value1_count, 10)
        self.assertEqual(obj.value2_count, 10)

        # Removing the dynamic catch-all: only the static +2 remains.
        obj.on_trait_change(self.on_anytrait_changed, remove=True)

        obj.value1 = 4
        self.assertEqual(obj.value1_count, 12)
        self.assertEqual(obj.value2_count, 10)

        obj.value2 = 4
        self.assertEqual(obj.value1_count, 12)
        self.assertEqual(obj.value2_count, 12)
916
917
class RaisesArgumentlessRuntimeError(HasTraits):
    """Raises the bare ``RuntimeError`` class (no arguments) on every
    change to ``x``."""

    x = Int(0)

    def _x_changed(self):
        raise RuntimeError
923
924
class TestRuntimeError(unittest.TestCase):
    """An exception raised by a change handler reaches the code that made
    the assignment."""

    def setUp(self):
        # Silence the handler's reporting but ask for the exception to be
        # re-raised.
        push_exception_handler(lambda *args: None, reraise_exceptions=True)

    def tearDown(self):
        pop_exception_handler()

    def test_runtime_error(self):
        model = RaisesArgumentlessRuntimeError()
        with self.assertRaises(RuntimeError):
            model.x = 5
935
936
class DelegatedFloatTrait(HasTraits):
    """End of the delegation chain: owns a concrete ``value`` (99.0)."""

    value = Trait(99.0)


class DelegateTrait(HasTraits):
    """``value`` is delegated to the object held in ``delegate``."""

    value = Delegate("delegate")
    delegate = Trait(DelegatedFloatTrait())


class DelegateTrait2(DelegateTrait):
    """One more link: ``delegate`` defaults to a DelegateTrait instance."""

    delegate = Trait(DelegateTrait())


class DelegateTrait3(DelegateTrait):
    """Another link: ``delegate`` defaults to a DelegateTrait2 instance."""

    delegate = Trait(DelegateTrait2())
952
953
class DelegateTests(unittest.TestCase):
    """Exercise a three-level ``Delegate`` chain.

    The chain is DelegateTrait3 -> DelegateTrait2 -> DelegateTrait ->
    DelegatedFloatTrait.
    """

    def test_delegation(self):
        obj = DelegateTrait3()

        # Initially every level reads through to the final 99.0.
        self.assertEqual(obj.value, 99.0)
        parent1 = obj.delegate
        parent2 = parent1.delegate
        parent3 = parent2.delegate
        # Setting a value on a link is seen by everything delegating to it,
        # without affecting links further down the chain.
        parent3.value = 3.0
        self.assertEqual(obj.value, 3.0)
        parent2.value = 2.0
        self.assertEqual(obj.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        parent1.value = 1.0
        self.assertEqual(obj.value, 1.0)
        self.assertEqual(parent2.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        obj.value = 0.0
        self.assertEqual(obj.value, 0.0)
        self.assertEqual(parent1.value, 1.0)
        self.assertEqual(parent2.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        # Deleting a local value makes that link delegate again.
        del obj.value
        self.assertEqual(obj.value, 1.0)
        del parent1.value
        self.assertEqual(obj.value, 2.0)
        self.assertEqual(parent1.value, 2.0)
        del parent2.value
        self.assertEqual(obj.value, 3.0)
        self.assertEqual(parent1.value, 3.0)
        self.assertEqual(parent2.value, 3.0)
        # Deleting the end-of-chain value restores its 99.0 default.
        del parent3.value
        # Uncommenting the following line allows
        # the last assertions to pass. However, this
        # may not be intended behavior, so keeping
        # the line commented.
        # NOTE(review): parent2.value was already deleted above, so this
        # note may be stale; confirm before relying on it.
        # del parent2.value
        self.assertEqual(obj.value, 99.0)
        self.assertEqual(parent1.value, 99.0)
        self.assertEqual(parent2.value, 99.0)
        self.assertEqual(parent3.value, 99.0)
995
996
#  Complex (i.e. 'composite') trait tests:
998
999# Make a TraitCompound handler that does not have a fast_validate so we can
1000# check for a particular regression.
slow = Trait(1, Range(1, 3), Range(-3, -1))
try:
    # Strip the compound handler's fast_validate so that `slow` is a
    # TraitCompound without one.
    del slow.handler.fast_validate
except AttributeError:
    # Nothing to strip: no deletable fast_validate attribute exists.
    pass
1006
1007
1008# Suppress DeprecationWarnings from TraitPrefixList and TraitPrefixMap
with warnings.catch_warnings():
    # TraitPrefixList / TraitPrefixMap constructors emit DeprecationWarnings;
    # silence them for this class definition only.
    warnings.simplefilter("ignore", category=DeprecationWarning)

    class complex_value(HasTraits):
        """Compound traits mixing ranges, prefix lists/maps and nested
        compound handlers (including ``slow``, which lacks fast_validate)."""

        # Either of two integer ranges.
        num1 = Trait(1, Range(1, 5), Range(-5, -1))
        # A range or a word prefix.
        num2 = Trait(
            1,
            Range(1, 5),
            TraitPrefixList("one", "two", "three", "four", "five"),
        )
        # A range or a mapped word prefix.
        num3 = Trait(
            1,
            Range(1, 5),
            TraitPrefixMap(
                {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5}
            ),
        )
        # Enumerated value 10 combined with a nested compound, in both orders.
        num4 = Trait(1, Trait(1, Tuple, slow), 10)
        num5 = Trait(1, 10, Trait(1, Tuple, slow))
1028
1029
class test_complex_value(test_base2):
    """Validation of the compound traits defined on ``complex_value``."""

    def setUp(self):
        self.obj = complex_value()

    def test_num1(self):
        # Accepted: integers in 1..5 or -5..-1.
        accepted = [1, 2, 3, 4, 5, -1, -2, -3, -4, -5]
        # Rejected: out-of-range ints, strings, floats, containers, None.
        rejected = [
            0, 6, -6,
            "0", "6", "-6",
            0.0, 6.0, -6.0,
            [1], (1,), {1: 1},
            None,
        ]
        self.check_values(
            "num1", 1, accepted, rejected, actual_values=list(accepted)
        )

    def test_enum_exceptions(self):
        """ Check that enumerated values can be combined with nested
        TraitCompound handlers.
        """
        accepted = [1, 2, 3, -3, -2, -1, 10, ()]
        rejected = [0, 4, 5, -5, -4, 11]
        for trait_name in ("num4", "num5"):
            self.check_values(trait_name, 1, accepted, rejected)
1067
1068
class test_list_value(test_base2):
    """Tests for constrained lists (``TraitList``) and TraitListEvent data."""

    def setUp(self):
        # Defining TraitList-based traits is expected to emit a
        # DeprecationWarning.
        with self.assertWarns(DeprecationWarning):

            class list_value(HasTraits):
                # Trait definitions:
                # list1: items from [1, 2, 3, 4], at most 4 items.
                list1 = Trait([2], TraitList(Trait([1, 2, 3, 4]), maxlen=4))
                # list2: as list1 but with at least 1 item.
                list2 = Trait(
                    [2], TraitList(Trait([1, 2, 3, 4]), minlen=1, maxlen=4)
                )
                # alist: unconstrained List used for the event test.
                alist = List()

        self.obj = list_value()
        # Most recent event recorded by _record_trait_list_event.
        self.last_event = None

    def tearDown(self):
        del self.last_event

    def del_range(self, list, index1, index2):
        """Perform ``del list[index1:index2]`` (callable for assertRaises)."""
        del list[index1:index2]

    def del_extended_slice(self, list, index1, index2, step):
        """Perform ``del list[index1:index2:step]``."""
        del list[index1:index2:step]

    def check_list(self, list):
        """Exercise item and length constraints on a fresh ``[2]`` list
        whose items must come from [1, 2, 3, 4] and whose length is at most
        4. Leaves the list as ``[4]``.
        """
        self.assertEqual(list, [2])
        self.assertEqual(len(list), 1)
        list.append(3)
        self.assertEqual(len(list), 2)
        list[1] = 2
        self.assertEqual(list[1], 2)
        self.assertEqual(len(list), 2)
        list[0] = 1
        self.assertEqual(list[0], 1)
        self.assertEqual(len(list), 2)
        # 5 is not an allowed item.
        self.assertRaises(TraitError, self.indexed_assign, list, 0, 5)
        self.assertRaises(TraitError, list.append, 5)
        # Would grow the list to 5 items, exceeding maxlen=4.
        self.assertRaises(TraitError, list.extend, [1, 2, 3])
        list.extend([3, 4])
        self.assertEqual(list, [1, 2, 3, 4])
        # Already at maxlen.
        self.assertRaises(TraitError, list.append, 1)
        # Extended slice of 2 positions cannot take 3 items.
        self.assertRaises(
            ValueError, self.extended_slice_assign, list, 0, 4, 2, [4, 5, 6]
        )
        del list[1]
        self.assertEqual(list, [1, 3, 4])
        del list[0]
        self.assertEqual(list, [3, 4])
        list[:0] = [1, 2]
        self.assertEqual(list, [1, 2, 3, 4])
        # Inserting into a full list exceeds maxlen.
        self.assertRaises(
            TraitError, self.indexed_range_assign, list, 0, 0, [1]
        )
        del list[0:3]
        self.assertEqual(list, [4])
        # 4 + 5 is rejected (5 not allowed).
        self.assertRaises(
            TraitError, self.indexed_range_assign, list, 0, 0, [4, 5]
        )

    def test_list1(self):
        self.check_list(self.obj.list1)

    def test_list2(self):
        self.check_list(self.obj.list2)
        # list2 now holds [4]; removing its only item violates minlen=1,
        # whether via a plain or a reversed extended slice.
        self.assertRaises(TraitError, self.del_range, self.obj.list2, 0, 1)
        self.assertRaises(
            TraitError, self.del_extended_slice, self.obj.list2, 4, -5, -1
        )

    def assertLastTraitListEventEqual(self, index, removed, added):
        """Assert index/removed/added of the most recently recorded event."""
        self.assertEqual(self.last_event.index, index)
        self.assertEqual(self.last_event.removed, removed)
        self.assertEqual(self.last_event.added, added)

    def test_trait_list_event(self):
        """ Record TraitListEvent behavior.
        """
        self.obj.alist = [1, 2, 3, 4]
        self.obj.on_trait_change(self._record_trait_list_event, "alist_items")
        del self.obj.alist[0]
        self.assertLastTraitListEventEqual(0, [1], [])
        self.obj.alist.append(5)
        self.assertLastTraitListEventEqual(3, [], [5])
        self.obj.alist[0:2] = [6, 7]
        self.assertLastTraitListEventEqual(0, [2, 3], [6, 7])
        self.obj.alist[:2] = [4, 5]
        self.assertLastTraitListEventEqual(0, [6, 7], [4, 5])
        # A step-1 extended slice is reported with an integer index.
        self.obj.alist[0:2:1] = [8, 9]
        self.assertLastTraitListEventEqual(0, [4, 5], [8, 9])
        self.obj.alist[0:2:1] = [8, 9]
        # If list values stay the same, a new TraitListEvent will be generated.
        self.assertLastTraitListEventEqual(0, [8, 9], [8, 9])
        old_event = self.last_event
        self.obj.alist[4:] = []
        # If no structural change, NO new TraitListEvent will be generated.
        self.assertIs(self.last_event, old_event)
        # Steps other than 1 are reported with a normalized slice index.
        self.obj.alist[0:4:2] = [10, 11]
        self.assertLastTraitListEventEqual(
            slice(0, 3, 2), [8, 4], [10, 11]
        )
        del self.obj.alist[1:4:2]
        self.assertLastTraitListEventEqual(slice(1, 4, 2), [9, 5], [])
        self.obj.alist = [1, 2, 3, 4]
        del self.obj.alist[2:4]
        self.assertLastTraitListEventEqual(2, [3, 4], [])
        self.obj.alist[:0] = [5, 6, 7, 8]
        self.assertLastTraitListEventEqual(0, [], [5, 6, 7, 8])
        del self.obj.alist[:2]
        self.assertLastTraitListEventEqual(0, [5, 6], [])
        del self.obj.alist[0:2]
        self.assertLastTraitListEventEqual(0, [7, 8], [])
        del self.obj.alist[:]
        self.assertLastTraitListEventEqual(0, [1, 2], [])

    def _record_trait_list_event(self, object, name, old, new):
        """``alist_items`` handler: remember the event object (*new*)."""
        self.last_event = new
1185
1186
class ThisDummy(HasTraits):
    """Two ``This`` traits: one accepting None, one forbidding it."""

    allows_none = This()
    disallows_none = This(allow_none=False)
1190
1191
class TestThis(unittest.TestCase):
    """``This`` accepts instances of the defining class, with configurable
    None handling."""

    def test_this_none(self):
        dummy = ThisDummy()

        # ``This()``: starts as None and None can be assigned at any time.
        self.assertIsNone(dummy.allows_none)
        dummy.allows_none = None
        dummy.allows_none = ThisDummy()
        self.assertIsNotNone(dummy.allows_none)
        dummy.allows_none = None
        self.assertIsNone(dummy.allows_none)

        # ``allow_none=False``: still starts out as None, unavoidably, but
        # assigning None is refused and the old value is kept.
        self.assertIsNone(dummy.disallows_none)
        dummy.disallows_none = ThisDummy()
        self.assertIsNotNone(dummy.disallows_none)
        with self.assertRaises(TraitError):
            dummy.disallows_none = None
        self.assertIsNotNone(dummy.disallows_none)

    def test_this_other_class(self):
        dummy = ThisDummy()
        # An instance of an unrelated class is rejected; value is unchanged.
        with self.assertRaises(TraitError):
            dummy.allows_none = object()
        self.assertIsNone(dummy.allows_none)
1215
1216
class ComparisonModeTests(unittest.TestCase):
    """Change-notification behaviour for each ``ComparisonMode`` and for the
    deprecated ``rich_compare`` metadata.
    """

    def _assert_notification_counts(self, cls, expected_counts):
        """Check cumulative notification counts for a fixed assignment run.

        A fresh ``cls`` instance gets a listener on its ``bar`` trait. Then
        ``bar`` is assigned, in order: a list, the *same* list object again,
        an equal but distinct list, and an unequal list.

        Parameters
        ----------
        cls : type
            HasTraits subclass defining a ``bar`` trait.
        expected_counts : sequence of int
            Expected total number of notifications after each of the four
            assignments, in order.
        """
        obj = cls()
        events = []
        obj.on_trait_change(lambda: events.append(None), "bar")

        some_list = [1, 2, 3]
        values = [some_list, some_list, [1, 2, 3], [4, 5, 6]]
        # Guard against zip() silently truncating a mis-sized expectation.
        self.assertEqual(len(expected_counts), len(values))

        self.assertEqual(len(events), 0)
        for value, expected in zip(values, expected_counts):
            obj.bar = value
            self.assertEqual(len(events), expected)

    def _assert_single_rich_compare_warning(self, warn_msgs):
        """Check that exactly one ``rich_compare`` DeprecationWarning was
        recorded, and that it is attributed to this test module.

        Parameters
        ----------
        warn_msgs : list of warnings.WarningMessage
            Warnings recorded by ``warnings.catch_warnings(record=True)``.
        """
        self.assertEqual(len(warn_msgs), 1)
        warn_msg = warn_msgs[0]
        self.assertIs(warn_msg.category, DeprecationWarning)
        self.assertIn(
            "'rich_compare' metadata has been deprecated",
            str(warn_msg.message)
        )
        # The warning should point at the class definition in this module,
        # not at traits internals.
        _, _, this_module = __name__.rpartition(".")
        self.assertIn(this_module, warn_msg.filename)

    def test_comparison_mode_none(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.none)

        # Every assignment notifies, even re-assigning the same object.
        self._assert_notification_counts(HasComparisonMode, [1, 2, 3, 4])

    def test_comparison_mode_identity(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.identity)

        # Only assigning a different object notifies.
        self._assert_notification_counts(HasComparisonMode, [1, 1, 2, 3])

    def test_comparison_mode_equality(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.equality)

        # Only assigning an unequal value notifies.
        self._assert_notification_counts(HasComparisonMode, [1, 1, 1, 2])

    def test_rich_compare_false(self):
        with warnings.catch_warnings(record=True) as warn_msgs:
            warnings.simplefilter("always", DeprecationWarning)

            class OldRichCompare(HasTraits):
                bar = Trait(rich_compare=False)

        # Check for a DeprecationWarning.
        self._assert_single_rich_compare_warning(warn_msgs)

        # Behaviour matches comparison_mode=ComparisonMode.identity.
        self._assert_notification_counts(OldRichCompare, [1, 1, 2, 3])

    def test_rich_compare_true(self):
        with warnings.catch_warnings(record=True) as warn_msgs:
            warnings.simplefilter("always", DeprecationWarning)

            class OldRichCompare(HasTraits):
                bar = Trait(rich_compare=True)

        # Check for a DeprecationWarning.
        self._assert_single_rich_compare_warning(warn_msgs)

        # Behaviour matches comparison_mode=ComparisonMode.equality.
        self._assert_notification_counts(OldRichCompare, [1, 1, 1, 2])
1347
1348
@requires_traitsui
class TestDeprecatedTraits(unittest.TestCase):
    """Calling the Color, RGBColor and Font factories from ``traits.api``
    must emit a DeprecationWarning naming the factory.
    """

    def test_color_deprecated(self):
        """Color() warns about its deprecation in 'traits'."""
        pattern = "'Color' in 'traits'"
        with self.assertWarnsRegex(DeprecationWarning, pattern):
            Color()

    def test_rgb_color_deprecated(self):
        """RGBColor() warns about its deprecation in 'traits'."""
        pattern = "'RGBColor' in 'traits'"
        with self.assertWarnsRegex(DeprecationWarning, pattern):
            RGBColor()

    def test_font_deprecated(self):
        """Font() warns about its deprecation in 'traits'."""
        pattern = "'Font' in 'traits'"
        with self.assertWarnsRegex(DeprecationWarning, pattern):
            Font()
1364