1# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only under
6# the conditions described in the aforementioned license. The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8#
9# Thanks for using Enthought open source!
10
11""" Unit test case for testing HasTraits 'on_trait_change' support.
12"""
13
14import unittest
15
16from traits.api import (
17    Any,
18    Dict,
19    HasTraits,
20    Instance,
21    Int,
22    List,
23    Property,
24    TraitDictEvent,
25    TraitDictObject,
26    TraitError,
27    TraitListEvent,
28    TraitListObject,
29    Undefined,
30    cached_property,
31    on_trait_change,
32    pop_exception_handler,
33    push_exception_handler,
34)
35
36
class ArgCheckBase(HasTraits):
    """Base fixture exposing integer traits for notification tests.

    ``calls`` counts handler invocations and ``tc`` holds the running
    TestCase so that handlers can make assertions.
    """

    # The trait most handlers in this module listen to.
    value = Int(0)
    # int*/tint* traits, some carrying ``test`` metadata, used to exercise
    # name-list and metadata-filtered listener patterns.
    int1 = Int(0, test=True)
    int2 = Int(0)
    int3 = Int(0, test=True)
    tint1 = Int(0)
    tint2 = Int(0, test=True)
    tint3 = Int(0)

    # Number of handler invocations observed so far.
    calls = Int(0)
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any
49
50
class ArgCheckList(ArgCheckBase):
    """ArgCheckBase variant whose ``value`` is a list of ints."""

    # Defaults to [0, 1, 2]; tests replace it or mutate it in place.
    value = List(Int, [0, 1, 2])
54
55
class ArgCheckDict(ArgCheckBase):
    """ArgCheckBase variant whose ``value`` is a dict of int -> int."""

    # Defaults to {0: 0, 1: 1, 2: 2}; tests replace it or add a key.
    value = Dict(Int, Int, {0: 0, 1: 1, 2: 2})
59
60
class ArgCheckSimple(ArgCheckBase):
    """Undecorated handlers of every supported arity (0 to 4 arguments).

    Tests attach them explicitly with ``on_trait_change(handler, "value")``.
    Each handler increments ``calls`` and checks its arguments against the
    current ``value``.
    """

    def arg_check0(self):
        """Zero-argument handler: only counts the call."""
        self.calls += 1

    def arg_check1(self, new):
        """One-argument handler: receives the new value."""
        self.calls += 1
        self.tc.assertEqual(new, self.value)

    def arg_check2(self, name, new):
        """Two-argument handler: receives the trait name and new value."""
        self.calls += 1
        self.tc.assertEqual(name, "value")
        self.tc.assertEqual(new, self.value)

    def arg_check3(self, object, name, new):
        """Three-argument handler: receives object, name and new value."""
        self.calls += 1
        self.tc.assertIs(object, self)
        self.tc.assertEqual(name, "value")
        self.tc.assertEqual(new, self.value)

    def arg_check4(self, object, name, old, new):
        """Four-argument handler: receives object, name, old and new value.

        The test increments ``value`` by one each time, so ``old`` is
        expected to be ``value - 1``.
        """
        self.calls += 1
        self.tc.assertIs(object, self)
        self.tc.assertEqual(name, "value")
        self.tc.assertEqual(old, (self.value - 1))
        self.tc.assertEqual(new, self.value)
86
87
class ArgCheckDecorator(ArgCheckBase):
    """The ArgCheckSimple handlers, attached via ``@on_trait_change``.

    The decorator registers each method for ``value``, so tests need no
    explicit ``on_trait_change`` calls.
    """

    @on_trait_change("value")
    def arg_check0(self):
        """Zero-argument handler: only counts the call."""
        self.calls += 1

    @on_trait_change("value")
    def arg_check1(self, new):
        """One-argument handler: receives the new value."""
        self.calls += 1
        self.tc.assertEqual(new, self.value)

    @on_trait_change("value")
    def arg_check2(self, name, new):
        """Two-argument handler: receives the trait name and new value."""
        self.calls += 1
        self.tc.assertEqual(name, "value")
        self.tc.assertEqual(new, self.value)

    @on_trait_change("value")
    def arg_check3(self, object, name, new):
        """Three-argument handler: receives object, name and new value."""
        self.calls += 1
        self.tc.assertIs(object, self)
        self.tc.assertEqual(name, "value")
        self.tc.assertEqual(new, self.value)

    @on_trait_change("value")
    def arg_check4(self, object, name, old, new):
        """Four-argument handler; ``old`` is expected to be ``value - 1``."""
        self.calls += 1
        self.tc.assertIs(object, self)
        self.tc.assertEqual(name, "value")
        self.tc.assertEqual(old, (self.value - 1))
        self.tc.assertEqual(new, self.value)
118
119
class ArgCheckDecoratorTrailingComma(ArgCheckDecorator):
    """Fixture whose decorator pattern has a trailing comma.

    The test expects instantiating this class to raise TraitError.
    """

    @on_trait_change("int1, int2,")
    def arg_check(self, object, name, old, new):
        # The body is irrelevant: only the malformed pattern is under test.
        pass
124
125
class BaseInstance(HasTraits):
    """Common state for fixtures whose handlers listen to traits of ``ref``.

    The three- and four-argument handlers check the ``exp_*`` values. The
    one- and two-argument handlers check the ``dst_*`` values. The two sets
    differ when ``ref`` itself is replaced: the changed trait is then "ref",
    but the destination trait is still "value".
    """

    #: An instance with a value trait we want to listen to.
    ref = Instance(HasTraits)

    # Invocation counters keyed by handler arity (arg_checkN -> calls[N]).
    calls = Dict({x: 0 for x in range(5)})
    exp_object = Any
    exp_name = Any
    dst_name = Any
    exp_old = Any
    exp_new = Any
    dst_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any
139
140
class InstanceValueListener(BaseInstance):
    """Decorated handlers of every arity listening to ``ref.value``.

    NOTE: the expected-failure checks for enthought/traits#537 in the tests
    depend on the order in which these handlers are defined. The first
    handler that fails stops the remaining ones, so keep this order.
    """

    @on_trait_change("ref.value")
    def arg_check0(self):
        # Only counts the call.
        self.calls[0] += 1

    @on_trait_change("ref.value")
    def arg_check1(self, new):
        # Sees the destination value, even when ``ref`` itself changes.
        self.calls[1] += 1
        self.tc.assertEqual(new, self.dst_new)

    @on_trait_change("ref.value")
    def arg_check2(self, name, new):
        self.calls[2] += 1
        self.tc.assertEqual(name, self.dst_name)
        self.tc.assertEqual(new, self.dst_new)

    @on_trait_change("ref.value")
    def arg_check3(self, object, name, new):
        # Sees the object/trait that actually changed.
        self.calls[3] += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)

    @on_trait_change("ref.value")
    def arg_check4(self, object, name, old, new):
        self.calls[4] += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(old, self.exp_old)
        self.tc.assertEqual(new, self.exp_new)
172
173
class InstanceSimpleValue(InstanceValueListener):
    """InstanceValueListener where ``ref.value`` is a plain Int."""

    #: An instance with a simple value trait we want to listen to.
    ref = Instance(ArgCheckBase, ())
178
179
class InstanceListValue(InstanceValueListener):
    """InstanceValueListener where ``ref.value`` is a List of ints."""

    #: An instance with a list value trait we want to listen to.
    ref = Instance(ArgCheckList, ())
184
185
class InstanceDictValue(InstanceValueListener):
    """InstanceValueListener where ``ref.value`` is a Dict of ints."""

    #: An instance with a dict value trait we want to listen to.
    ref = Instance(ArgCheckDict, ())
189
190
class InstanceValueListListener(BaseInstance):
    """Handlers on ``ref.value[]``: whole-list and in-place list changes.

    With the trailing ``[]``, in-place mutations of the list also notify
    these handlers. The tests expect those mutations under the trait name
    "value_items".
    """

    #: An instance with a list value trait we want to listen to.
    ref = Instance(ArgCheckList, ())

    @on_trait_change("ref.value[]")
    def arg_check0(self):
        self.calls[0] += 1

    @on_trait_change("ref.value[]")
    def arg_check1(self, new):
        self.calls[1] += 1
        self.tc.assertEqual(new, self.dst_new)

    @on_trait_change("ref.value[]")
    def arg_check2(self, name, new):
        self.calls[2] += 1
        self.tc.assertEqual(name, self.dst_name)
        self.tc.assertEqual(new, self.dst_new)

    @on_trait_change("ref.value[]")
    def arg_check3(self, object, name, new):
        self.calls[3] += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)

    @on_trait_change("ref.value[]")
    def arg_check4(self, object, name, old, new):
        self.calls[4] += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(old, self.exp_old)
        self.tc.assertEqual(new, self.exp_new)
225
226
class List1(HasTraits):
    """Handlers with 0, 3 and 4 arguments listening to ``refs.value``.

    ``calls`` is keyed by handler arity. When ``type_old``/``type_new`` is
    not None, the handler only checks the argument's type. Otherwise it
    compares the argument against ``exp_old``/``exp_new``.
    """

    refs = List(ArgCheckBase)
    calls = Dict({0: 0, 3: 0, 4: 0})

    exp_object = Any
    exp_name = Any
    type_old = Any
    exp_old = Any
    type_new = Any
    exp_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    @on_trait_change("refs.value")
    def arg_check0(self):
        self.calls[0] += 1

    @on_trait_change("refs.value")
    def arg_check3(self, object, name, new):
        self.calls[3] += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        # Exact-value check unless the test asked for a type check.
        if self.type_new is None:
            self.tc.assertEqual(new, self.exp_new)
        else:
            self.tc.assertIsInstance(new, self.type_new)

    @on_trait_change("refs.value")
    def arg_check4(self, object, name, old, new):
        self.calls[4] += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        if self.type_old is None:
            self.tc.assertEqual(old, self.exp_old)
        else:
            self.tc.assertIsInstance(old, self.type_old)
        if self.type_new is None:
            self.tc.assertEqual(new, self.exp_new)
        else:
            self.tc.assertIsInstance(new, self.type_new)
267
268
class List2(HasTraits):
    """A single one-argument handler listening to ``refs.value``."""

    refs = List(ArgCheckBase)

    # Number of handler invocations.
    calls = Int(0)
    exp_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    @on_trait_change("refs.value")
    def arg_check1(self, new):
        self.calls += 1
        self.tc.assertEqual(new, self.exp_new)
281
282
class List3(HasTraits):
    """A single two-argument handler listening to ``refs.value``."""

    refs = List(ArgCheckBase)

    # Number of handler invocations.
    calls = Int(0)
    exp_name = Any
    exp_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    @on_trait_change("refs.value")
    def arg_check2(self, name, new):
        self.calls += 1
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)
297
298
class Dict1(List1):
    """The List1 handlers, but ``refs`` maps int keys to ArgCheckBase."""

    refs = Dict(Int, ArgCheckBase)
301
302
class Dict2(HasTraits):
    """A single one-argument handler on ``refs.value`` for a Dict of refs."""

    refs = Dict(Int, ArgCheckBase)

    # Number of handler invocations.
    calls = Int(0)
    exp_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    @on_trait_change("refs.value")
    def arg_check1(self, new):
        self.calls += 1
        self.tc.assertEqual(new, self.exp_new)
315
316
class Dict3(HasTraits):
    """A single two-argument handler on ``refs.value`` for a Dict of refs."""

    refs = Dict(Int, ArgCheckBase)

    # Number of handler invocations.
    calls = Int(0)
    exp_name = Any
    exp_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    @on_trait_change("refs.value")
    def arg_check2(self, name, new):
        self.calls += 1
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)
331
332
class Complex(HasTraits):
    """Fixture for listener-pattern tests: name lists, metadata, nesting.

    The handlers are not decorated; tests register them dynamically (for
    example with ``on_trait_change``). All handlers share one ``calls``
    counter.
    """

    # int*/tint* traits, some tagged with ``test`` metadata, for
    # prefix and metadata-filtered patterns such as "int+test".
    int1 = Int(0, test=True)
    int2 = Int(0)
    int3 = Int(0, test=True)
    tint1 = Int(0)
    tint2 = Int(0, test=True)
    tint3 = Int(0)
    # Nested object for extended patterns such as "ref.[int1,int2,int3]".
    ref = Instance(ArgCheckBase, ())

    # Total handler invocations across all handlers.
    calls = Int(0)
    # Values the handlers assert against. dst_name/dst_new are not read
    # by the handlers below.
    exp_object = Any
    exp_name = Any
    dst_name = Any
    exp_old = Any
    exp_new = Any
    dst_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    def arg_check0(self):
        self.calls += 1

    def arg_check1(self, new):
        self.calls += 1
        self.tc.assertEqual(new, self.exp_new)

    def arg_check2(self, name, new):
        self.calls += 1
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)

    def arg_check3(self, object, name, new):
        self.calls += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)

    def arg_check4(self, object, name, old, new):
        self.calls += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(old, self.exp_old)
        self.tc.assertEqual(new, self.exp_new)
376
377
class Link(HasTraits):
    """A node in the chain of objects used by the cycle tests.

    ``next`` and ``prev`` are untyped so that nodes can point at each
    other. The chain is built by ``build_list``, which is defined elsewhere
    in the test case. Presumably it links nodes in both directions; TODO
    confirm.
    """

    next = Any
    prev = Any
    # The trait observed through patterns such as "head.next*.value".
    value = Int(0)
383
384
class LinkTest(HasTraits):
    """Listener fixture for recursive patterns over a chain of Links.

    The handlers are not decorated; tests register them dynamically with
    patterns such as "head.next*.value". All handlers share ``calls``.
    """

    #: First node of the chain being observed.
    head = Instance(Link)

    # Total handler invocations across all handlers.
    calls = Int(0)
    # Values the handlers assert against. dst_name/dst_new are not read
    # by the handlers below.
    exp_object = Any
    exp_name = Any
    dst_name = Any
    exp_old = Any
    exp_new = Any
    dst_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    def arg_check0(self):
        self.calls += 1

    def arg_check1(self, new):
        self.calls += 1
        self.tc.assertEqual(new, self.exp_new)

    def arg_check2(self, name, new):
        self.calls += 1
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)

    def arg_check3(self, object, name, new):
        self.calls += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(new, self.exp_new)

    def arg_check4(self, object, name, old, new):
        self.calls += 1
        self.tc.assertIs(object, self.exp_object)
        self.tc.assertEqual(name, self.exp_name)
        self.tc.assertEqual(old, self.exp_old)
        self.tc.assertEqual(new, self.exp_new)
422
423
class PropertyDependsOn(HasTraits):
    """Fixture for a cached Property that depends on ``ref.int1..int3``.

    ``pcalls`` counts getter evaluations and ``calls`` counts change
    notifications for ``sum``. The test uses both counters to check
    caching and invalidation.
    """

    # ref.int1 + ref.int2 + ref.int3. The test expects it to be recomputed
    # when any of those traits changes or when ref itself is replaced.
    sum = Property(depends_on="ref.[int1,int2,int3]")
    ref = Instance(ArgCheckBase, ())

    # Number of getter evaluations.
    pcalls = Int(0)
    # Number of ``sum`` change notifications.
    calls = Int(0)
    exp_old = Any
    exp_new = Any
    # The unittest.TestCase used for assertions inside handlers.
    tc = Any

    @cached_property
    def _get_sum(self):
        # Getter for ``sum``. It counts evaluations so the test can check
        # that the cached value is reused between changes.
        self.pcalls += 1
        r = self.ref
        return r.int1 + r.int2 + r.int3

    def _sum_changed(self, old, new):
        # Static change handler for ``sum`` (the "_<name>_changed" naming
        # convention). It checks the old/new values the test expects.
        self.calls += 1
        self.tc.assertEqual(old, self.exp_old)
        self.tc.assertEqual(new, self.exp_new)
445
446
447class OnTraitChangeTest(unittest.TestCase):
448    def setUp(self):
449        def ignore(*args):
450            pass
451
452        push_exception_handler(handler=ignore, reraise_exceptions=True)
453
454    def tearDown(self):
455        pop_exception_handler()
456
457    def test_arg_check_simple(self):
458        ac = ArgCheckSimple(tc=self)
459        ac.on_trait_change(ac.arg_check0, "value")
460        ac.on_trait_change(ac.arg_check1, "value")
461        ac.on_trait_change(ac.arg_check2, "value")
462        ac.on_trait_change(ac.arg_check3, "value")
463        ac.on_trait_change(ac.arg_check4, "value")
464        for i in range(3):
465            ac.value += 1
466        self.assertEqual(ac.calls, (3 * 5))
467        ac.on_trait_change(ac.arg_check0, "value", remove=True)
468        ac.on_trait_change(ac.arg_check1, "value", remove=True)
469        ac.on_trait_change(ac.arg_check2, "value", remove=True)
470        ac.on_trait_change(ac.arg_check3, "value", remove=True)
471        ac.on_trait_change(ac.arg_check4, "value", remove=True)
472        for i in range(3):
473            ac.value += 1
474        self.assertEqual(ac.calls, (3 * 5))
475        self.assertEqual(ac.value, (2 * 3))
476
477    def test_arg_check_trailing_comma(self):
478        ac = ArgCheckSimple(tc=self)
479
480        with self.assertRaises(TraitError):
481            ac.on_trait_change(ac.arg_check0, "int1, int2,")
482
483    def test_arg_check_decorator(self):
484        ac = ArgCheckDecorator(tc=self)
485        for i in range(3):
486            ac.value += 1
487        self.assertEqual(ac.calls, (3 * 5))
488        self.assertEqual(ac.value, 3)
489
490    def test_arg_check_decorator_trailing_comma(self):
491        with self.assertRaises(TraitError):
492            ArgCheckDecoratorTrailingComma(tc=self)
493
494    def test_instance_simple_value(self):
495        inst = InstanceSimpleValue(tc=self)
496        for i in range(3):
497            inst.trait_set(
498                exp_object=inst.ref,
499                exp_name="value",
500                dst_name="value",
501                exp_old=i,
502                exp_new=(i + 1),
503                dst_new=(i + 1),
504            )
505            inst.ref.value = i + 1
506        self.assertEqual(inst.calls, {x: 3 for x in range(5)})
507        self.assertEqual(inst.ref.value, 3)
508
509        inst.reset_traits(['calls'])
510        ref = ArgCheckBase()
511        inst.trait_set(
512            exp_object=inst,
513            exp_name="ref",
514            dst_name="value",
515            exp_old=inst.ref,
516            exp_new=ref,
517            dst_new=0,
518        )
519        inst.ref = ref
520        self.assertEqual(inst.calls, {x: 1 for x in range(5)})
521        self.assertEqual(inst.ref.value, 0)
522
523        inst.reset_traits(['calls'])
524        for i in range(3):
525            inst.trait_set(
526                exp_object=inst.ref,
527                exp_name="value",
528                dst_name="value",
529                exp_old=i,
530                exp_new=(i + 1),
531                dst_new=(i + 1),
532            )
533            inst.ref.value = i + 1
534        self.assertEqual(inst.calls, {x: 3 for x in range(5)})
535        self.assertEqual(inst.ref.value, 3)
536
    def test_instance_list_value(self):
        """Handlers on "ref.value" where ``value`` is a List trait.

        Assigning a whole list and replacing ``ref`` each notify every
        handler once. The final part pins the current, buggy behavior for
        in-place mutation (enthought/traits#537).
        """
        inst = InstanceListValue(tc=self)

        # Assigning a new list to ref.value.
        inst.trait_set(
            exp_object=inst.ref,
            exp_name="value",
            dst_name="value",
            exp_old=[0, 1, 2],
            exp_new=[0, 1, 2, 3],
            dst_new=[0, 1, 2, 3],
        )
        inst.ref.value = [0, 1, 2, 3]
        self.assertEqual(inst.calls, {x: 1 for x in range(5)})
        self.assertEqual(inst.ref.value, [0, 1, 2, 3])

        # Replacing ref itself.
        inst.reset_traits(['calls'])
        ref = ArgCheckList()
        inst.trait_set(
            exp_object=inst,
            exp_name="ref",
            dst_name="value",
            exp_old=inst.ref,
            exp_new=ref,
            dst_new=[0, 1, 2],
        )
        inst.ref = ref
        self.assertEqual(inst.calls, {x: 1 for x in range(5)})
        self.assertEqual(inst.ref.value, [0, 1, 2])

        # In-place mutation of the list (bug #537).
        inst.reset_traits(['calls'])
        inst.trait_set(
            exp_object=inst.ref,
            exp_name="value",
            dst_name="value",
            exp_old=[0, 1, 2],
            exp_new=[0, 1, 2, 3],
            dst_new=[0, 1, 2, 3],
        )
        with self.assertRaises(
                AssertionError,
                msg="Behavior of a bug (#537) is not reproduced."):
            # Expected failure, see enthought/traits#537
            # InstanceValueListener.arg_check1 receives a TraitListEvent
            # as `new` instead of the expected `[0, 1, 2, 3]`
            inst.ref.value.append(3)

        # Expected failure
        # See enthought/traits#537
        with self.assertRaises(
                AssertionError,
                msg="Behavior of a bug (#537) is not reproduced."):
            # Handlers with arguments are unexpectedly called, but one of the
            # handlers fails, leading to the rest of the handlers
            # not to be called. Actual behavior depends on dictionary ordering
            # (Python <3.6) or the order of handlers defined in
            # InstanceValueListener (Python >= 3.6)
            self.assertEqual(inst.calls, {0: 1, 1: 0, 2: 0, 3: 0, 4: 0})

        # The mutation itself still took effect.
        self.assertEqual(inst.ref.value, [0, 1, 2, 3])
596
    def test_instance_dict_value(self):
        """Handlers on "ref.value" where ``value`` is a Dict trait.

        This is the dict counterpart of test_instance_list_value. The final
        part pins the current, buggy behavior for in-place item assignment
        (enthought/traits#537).
        """
        inst = InstanceDictValue(tc=self)

        # Assigning a new dict to ref.value.
        inst.trait_set(
            exp_object=inst.ref,
            exp_name="value",
            dst_name="value",
            exp_old={0: 0, 1: 1, 2: 2},
            exp_new={0: 0, 1: 1, 2: 2, 3: 3},
            dst_new={0: 0, 1: 1, 2: 2, 3: 3},
        )
        inst.ref.value = {0: 0, 1: 1, 2: 2, 3: 3}
        self.assertEqual(inst.calls, {x: 1 for x in range(5)})
        self.assertEqual(inst.ref.value, {0: 0, 1: 1, 2: 2, 3: 3})

        # Replacing ref itself.
        inst.reset_traits(['calls'])
        ref = ArgCheckDict()
        inst.trait_set(
            exp_object=inst,
            exp_name="ref",
            dst_name="value",
            exp_old=inst.ref,
            exp_new=ref,
            dst_new={0: 0, 1: 1, 2: 2},
        )
        inst.ref = ref
        self.assertEqual(inst.calls, {x: 1 for x in range(5)})
        self.assertEqual(inst.ref.value, {0: 0, 1: 1, 2: 2})

        # In-place item assignment (bug #537).
        inst.reset_traits(['calls'])
        inst.trait_set(
            exp_object=inst.ref,
            exp_name="value",
            dst_name="value",
            exp_old={0: 0, 1: 1, 2: 2},
            exp_new={0: 0, 1: 1, 2: 2, 3: 3},
            dst_new={0: 0, 1: 1, 2: 2, 3: 3},
        )
        with self.assertRaises(
                AssertionError,
                msg="Behavior of a bug (#537) is not reproduced."):
            # Expected failure, see enthought/traits#537
            # InstanceValueListener.arg_check1 receives a TraitDictEvent
            # as `new` instead of the expected `{0: 0, 1: 1, 2: 2, 3: 3}`
            inst.ref.value[3] = 3

        # Expected failure
        # See enthought/traits#537
        with self.assertRaises(
                AssertionError,
                msg="Behavior of a bug (#537) is not reproduced."):
            # Handlers with arguments are unexpectedly called, but one of the
            # handlers fails, leading to the rest of the handlers
            # not to be called. Actual behavior depends on dictionary ordering
            # (Python <3.6) or the order of handlers defined in
            # InstanceValueListener (Python >= 3.6)
            self.assertEqual(inst.calls, {0: 1, 1: 0, 2: 0, 3: 0, 4: 0})

        # The mutation itself still took effect.
        self.assertEqual(inst.ref.value, {0: 0, 1: 1, 2: 2, 3: 3})
656
657    def test_instance_value_list_listener(self):
658        inst = InstanceValueListListener(tc=self)
659
660        inst.trait_set(
661            exp_object=inst.ref,
662            exp_name="value",
663            dst_name="value",
664            exp_old=[0, 1, 2],
665            exp_new=[0, 1, 2, 3],
666            dst_new=[0, 1, 2, 3],
667        )
668        inst.ref.value = [0, 1, 2, 3]
669        self.assertEqual(inst.calls, {x: 1 for x in range(5)})
670        self.assertEqual(inst.ref.value, [0, 1, 2, 3])
671
672        inst.reset_traits(['calls'])
673        ref = ArgCheckList()
674        inst.trait_set(
675            exp_object=inst,
676            exp_name="ref",
677            dst_name="value",
678            exp_old=inst.ref,
679            exp_new=ref,
680            dst_new=[0, 1, 2],
681        )
682        inst.ref = ref
683        self.assertEqual(inst.calls, {x: 1 for x in range(5)})
684        self.assertEqual(inst.ref.value, [0, 1, 2])
685
686        inst.reset_traits(['calls'])
687        inst.trait_set(
688            exp_object=inst.ref,
689            exp_name="value_items",
690            dst_name="value_items",
691            exp_old=[],
692            exp_new=[3],
693            dst_new=[3],
694        )
695        inst.ref.value.append(3)
696        self.assertEqual(
697            inst.calls, {x: 1 for x in range(5)}
698        )
699        self.assertEqual(inst.ref.value, [0, 1, 2, 3])
700
701        inst.reset_traits(['calls'])
702        inst.trait_set(
703            exp_object=inst.ref,
704            exp_name="value_items",
705            dst_name="value_items",
706            exp_old=[2],
707            exp_new=[],
708            dst_new=[],
709        )
710        inst.ref.value.pop(2)
711        self.assertEqual(
712            inst.calls, {x: 1 for x in range(5)}
713        )
714        self.assertEqual(inst.ref.value, [0, 1, 3])
715
716        inst.reset_traits(['calls'])
717        inst.trait_set(
718            exp_object=inst.ref,
719            exp_name="value_items",
720            dst_name="value_items",
721            exp_old=[1],
722            exp_new=[1, 2],
723            dst_new=[1, 2],
724        )
725        inst.ref.value[1:2] = [1, 2]
726        self.assertEqual(
727            inst.calls, {x: 1 for x in range(5)}
728        )
729        self.assertEqual(inst.ref.value, [0, 1, 2, 3])
730
    def test_list1(self):
        """0/3/4-argument handlers on "refs.value" where refs is a List.

        Covers appending elements (only the 0-argument handler currently
        fires, see enthought/traits#538), replacing the whole list, and
        changing ``value`` on each element.
        """
        l1 = List1(tc=self)
        # Appending elements produces a "refs_items" change.
        for i in range(3):
            ac = ArgCheckBase()
            l1.trait_set(
                exp_object=l1,
                exp_name="refs_items",
                type_old=None,
                exp_old=Undefined,
                type_new=TraitListEvent,
            )
            l1.refs.append(ac)

        # Behavior of an existing bug.
        # The expected value should be {0: 3, 3: 3, 4: 3}
        # See enthought/traits#538
        self.assertEqual(
            l1.calls, {0: 3, 3: 0, 4: 0},
            "Behavior of a bug (#538) is not reproduced."
        )

        for i in range(3):
            self.assertEqual(l1.refs[i].value, 0)

        # Replacing the whole list; only the type of ``new`` is checked.
        l1.reset_traits(['calls'])
        refs = [ArgCheckBase(), ArgCheckBase(), ArgCheckBase()]
        l1.trait_set(
            exp_object=l1,
            exp_name="refs",
            type_old=None,
            exp_old=l1.refs,
            type_new=TraitListObject,
        )
        l1.refs = refs
        self.assertEqual(l1.calls, {0: 1, 3: 1, 4: 1})
        for i in range(3):
            self.assertEqual(l1.refs[i].value, 0)

        # Changing ``value`` on each element: 3 rounds x 3 elements.
        l1.reset_traits(['calls'])
        for i in range(3):
            for j in range(3):
                l1.trait_set(
                    exp_object=l1.refs[j],
                    exp_name="value",
                    type_old=None,
                    exp_old=i,
                    type_new=None,
                    exp_new=(i + 1),
                )
                l1.refs[j].value = i + 1

        self.assertEqual(l1.calls, {0: 9, 3: 9, 4: 9})
        for i in range(3):
            self.assertEqual(l1.refs[i].value, 3)
785
786    def test_list2(self):
787        self.check_list(List2(tc=self))
788
789    def test_list3(self):
790        self.check_list(List3(tc=self))
791
    def test_dict1(self):
        """0/3/4-argument handlers on "refs.value" where refs is a Dict.

        This is the dict counterpart of test_list1. Adding items also shows
        bug enthought/traits#538.
        """
        d1 = Dict1(tc=self)
        # Adding items produces a "refs_items" change.
        for i in range(3):
            ac = ArgCheckBase()
            d1.trait_set(
                exp_object=d1,
                exp_name="refs_items",
                type_old=None,
                exp_old=Undefined,
                type_new=TraitDictEvent,
            )
            d1.refs[i] = ac

        # Behavior of an existing bug.
        # The expected value should be {0: 3, 3: 3, 4: 3}
        # See enthought/traits#538
        self.assertEqual(
            d1.calls, {0: 3, 3: 0, 4: 0},
            "Behavior of a bug (#538) is not reproduced."
        )

        for i in range(3):
            self.assertEqual(d1.refs[i].value, 0)

        # Replacing the whole dict; only the type of ``new`` is checked.
        d1.reset_traits(['calls'])
        refs = {0: ArgCheckBase(), 1: ArgCheckBase(), 2: ArgCheckBase()}
        d1.trait_set(
            exp_object=d1,
            exp_name="refs",
            type_old=None,
            exp_old=d1.refs,
            type_new=TraitDictObject,
        )
        d1.refs = refs
        self.assertEqual(d1.calls, {0: 1, 3: 1, 4: 1})
        for i in range(3):
            self.assertEqual(d1.refs[i].value, 0)

        # Changing ``value`` on each element: 3 rounds x 3 elements.
        d1.reset_traits(['calls'])
        for i in range(3):
            for j in range(3):
                d1.trait_set(
                    exp_object=d1.refs[j],
                    exp_name="value",
                    type_old=None,
                    exp_old=i,
                    type_new=None,
                    exp_new=(i + 1),
                )
                d1.refs[j].value = i + 1
        self.assertEqual(d1.calls, {0: 9, 3: 9, 4: 9})
        for i in range(3):
            self.assertEqual(d1.refs[i].value, 3)
845
846    def test_dict2(self):
847        self.check_dict(Dict2(tc=self))
848
849    def test_dict3(self):
850        self.check_dict(Dict3(tc=self))
851
852    def test_pattern_list1(self):
853        c = Complex(tc=self)
854        self.check_complex(
855            c,
856            c,
857            "int1, int2, int3",
858            ["int1", "int2", "int3"],
859            ["tint1", "tint2", "tint3"],
860        )
861
862    def test_pattern_list2(self):
863        c = Complex(tc=self)
864        self.check_complex(
865            c,
866            c,
867            ["int1", "int2", "int3"],
868            ["int1", "int2", "int3"],
869            ["tint1", "tint2", "tint3"],
870        )
871
872    def test_pattern_list3(self):
873        c = Complex(tc=self)
874        self.check_complex(
875            c,
876            c.ref,
877            "ref.[int1, int2, int3]",
878            ["int1", "int2", "int3"],
879            ["tint1", "tint2", "tint3"],
880        )
881
882    def test_pattern_list4(self):
883        c = Complex(tc=self)
884        handlers = [c.arg_check0, c.arg_check3, c.arg_check4]
885        n = len(handlers)
886        pattern = "ref.[int1,int2,int3]"
887        self.multi_register(c, handlers, pattern)
888        r0 = c.ref
889        r1 = ArgCheckBase()
890        c.trait_set(exp_object=c, exp_name="ref", exp_old=r0, exp_new=r1)
891        c.ref = r1
892        c.trait_set(exp_old=r1, exp_new=r0)
893        c.ref = r0
894        self.assertEqual(c.calls, 2 * n)
895        self.multi_register(c, handlers, pattern, remove=True)
896        c.ref = r1
897        c.ref = r0
898        self.assertEqual(c.calls, 2 * n)
899
900    def test_pattern_list5(self):
901        c = Complex(tc=self)
902        c.on_trait_change(c.arg_check1, "ref.[int1,int2,int3]")
903        self.assertRaises(TraitError, c.trait_set, ref=ArgCheckBase())
904
905    def test_pattern_list6(self):
906        c = Complex(tc=self)
907        c.on_trait_change(c.arg_check2, "ref.[int1,int2,int3]")
908        self.assertRaises(TraitError, c.trait_set, ref=ArgCheckBase())
909
910    def test_pattern_list7(self):
911        c = Complex(tc=self)
912        self.check_complex(
913            c,
914            c,
915            "+test",
916            ["int1", "int3", "tint2"],
917            ["int2", "tint1", "tint3"],
918        )
919
920    def test_pattern_list8(self):
921        c = Complex(tc=self)
922        self.check_complex(
923            c,
924            c,
925            "int+test",
926            ["int1", "int3"],
927            ["int2", "tint1", "tint2", "tint3"],
928        )
929
930    def test_pattern_list9(self):
931        c = Complex(tc=self)
932        self.check_complex(
933            c,
934            c,
935            "int-test",
936            ["int2"],
937            ["int1", "int3", "tint4", "tint5", "tint6"],
938        )
939
940    def test_pattern_list10(self):
941        c = Complex(tc=self)
942        self.check_complex(
943            c, c, "int+", ["int1", "int2", "int3"], ["tint1", "tint2", "tint3"]
944        )
945
946    def test_pattern_list11(self):
947        c = Complex(tc=self)
948        self.check_complex(
949            c, c, "int-", ["int1", "int2", "int3"], ["tint1", "tint2", "tint3"]
950        )
951
952    def test_pattern_list12(self):
953        c = Complex(tc=self)
954        self.check_complex(
955            c,
956            c,
957            "int+test,tint-test",
958            ["int1", "int3", "tint1", "tint3"],
959            ["int2", "tint2"],
960        )
961
962    def test_pattern_list13(self):
963        c = Complex(tc=self)
964        self.check_complex(
965            c,
966            c.ref,
967            "ref.[int+test,tint-test]",
968            ["int1", "int3", "tint1", "tint3"],
969            ["int2", "tint2"],
970        )
971
972    def test_cycle1(self):
973        lt = LinkTest(tc=self, head=self.build_list())
974        handlers = [
975            lt.arg_check0,
976            lt.arg_check1,
977            lt.arg_check2,
978            lt.arg_check3,
979            lt.arg_check4,
980        ]
981        nh = len(handlers)
982        self.multi_register(lt, handlers, "head.next*.value")
983        cur = lt.head
984        for i in range(4):
985            lt.trait_set(
986                exp_object=cur,
987                exp_name="value",
988                exp_old=10 * i,
989                exp_new=(10 * i) + 1,
990            )
991            cur.value = (10 * i) + 1
992            cur = cur.next
993        self.assertEqual(lt.calls, 4 * nh)
994        self.multi_register(lt, handlers, "head.next*.value", remove=True)
995        cur = lt.head
996        for i in range(4):
997            cur.value = (10 * i) + 2
998            cur = cur.next
999        self.assertEqual(lt.calls, 4 * nh)
1000
1001    def test_cycle2(self):
1002        lt = LinkTest(tc=self, head=self.build_list())
1003        handlers = [
1004            lt.arg_check0,
1005            lt.arg_check1,
1006            lt.arg_check2,
1007            lt.arg_check3,
1008            lt.arg_check4,
1009        ]
1010        nh = len(handlers)
1011        self.multi_register(lt, handlers, "head.[next,prev]*.value")
1012        cur = lt.head
1013        for i in range(4):
1014            lt.trait_set(
1015                exp_object=cur,
1016                exp_name="value",
1017                exp_old=10 * i,
1018                exp_new=(10 * i) + 1,
1019            )
1020            cur.value = (10 * i) + 1
1021            cur = cur.next
1022        self.assertEqual(lt.calls, 4 * nh)
1023        self.multi_register(
1024            lt, handlers, "head.[next,prev]*.value", remove=True
1025        )
1026        cur = lt.head
1027        for i in range(4):
1028            cur.value = (10 * i) + 2
1029            cur = cur.next
1030        self.assertEqual(lt.calls, 4 * nh)
1031
1032    def test_cycle3(self):
1033        lt = LinkTest(tc=self, head=self.build_list())
1034        handlers = [lt.arg_check0, lt.arg_check3, lt.arg_check4]
1035        nh = len(handlers)
1036        self.multi_register(lt, handlers, "head.next*.value")
1037        link = self.new_link(lt, lt.head, 1)
1038        self.assertEqual(lt.calls, nh)
1039        link = self.new_link(lt, link, 2)
1040        self.assertEqual(lt.calls, 2 * nh)
1041        self.multi_register(lt, handlers, "head.next*.value", remove=True)
1042        link = self.new_link(lt, link, 3)
1043        self.assertEqual(lt.calls, 2 * nh)
1044
1045    def test_property(self):
1046        pdo = PropertyDependsOn(tc=self)
1047        sum = pdo.sum
1048        self.assertEqual(sum, 0)
1049        for n in ["int1", "int2", "int3"]:
1050            for i in range(3):
1051                pdo.trait_set(exp_old=sum, exp_new=sum + 1)
1052                setattr(pdo.ref, n, i + 1)
1053                sum += 1
1054        self.assertEqual(pdo.pcalls, (3 * 3) + 1)
1055        self.assertEqual(pdo.calls, 3 * 3)
1056        for i in range(10):
1057            pdo.sum
1058        self.assertEqual(pdo.pcalls, (3 * 3) + 1)
1059        pdo.trait_set(exp_old=sum, exp_new=60)
1060        old_ref = pdo.ref
1061        pdo.ref = ArgCheckBase(int1=10, int2=20, int3=30)
1062        self.assertEqual(pdo.pcalls, (3 * 3) + 2)
1063        self.assertEqual(pdo.calls, (3 * 3) + 1)
1064        sum = 60
1065        for n in ["int1", "int2", "int3"]:
1066            for i in range(3):
1067                pdo.trait_set(exp_old=sum, exp_new=sum + 1)
1068                setattr(pdo.ref, n, getattr(pdo.ref, n) + 1)
1069                sum += 1
1070        self.assertEqual(pdo.pcalls, (2 * 3 * 3) + 2)
1071        self.assertEqual(pdo.calls, (2 * 3 * 3) + 1)
1072        for n in ["int1", "int2", "int3"]:
1073            for i in range(3):
1074                setattr(old_ref, n, getattr(old_ref, n) + 1)
1075        self.assertEqual(pdo.pcalls, (2 * 3 * 3) + 2)
1076        self.assertEqual(pdo.calls, (2 * 3 * 3) + 1)
1077        self.assertEqual(pdo.sum, sum)
1078        self.assertEqual(pdo.pcalls, (2 * 3 * 3) + 2)
1079
1080    def check_list(self, l):
1081        for i in range(3):
1082            ac = ArgCheckBase()
1083            self.assertRaises(TraitError, l.refs.append, ac)
1084        self.assertEqual(l.calls, 0)
1085        for i in range(3):
1086            self.assertEqual(l.refs[i].value, 0)
1087        refs = [ArgCheckBase(), ArgCheckBase(), ArgCheckBase()]
1088        self.assertRaises(TraitError, l.trait_set, refs=refs)
1089        self.assertEqual(l.calls, 0)
1090        for i in range(3):
1091            self.assertEqual(l.refs[i].value, 0)
1092        for i in range(3):
1093            for j in range(3):
1094                l.exp_new = i + 1
1095                l.refs[j].value = i + 1
1096        self.assertEqual(l.calls, 0)
1097        for i in range(3):
1098            self.assertEqual(l.refs[i].value, 3)
1099
1100    def check_dict(self, d):
1101        for i in range(3):
1102            ac = ArgCheckBase()
1103            self.assertRaises(TraitError, d.refs.setdefault, i, ac)
1104        self.assertEqual(d.calls, 0)
1105        for i in range(3):
1106            self.assertEqual(d.refs[i].value, 0)
1107        refs = {0: ArgCheckBase(), 1: ArgCheckBase(), 2: ArgCheckBase()}
1108        self.assertRaises(TraitError, d.trait_set, refs=refs)
1109        self.assertEqual(d.calls, 0)
1110        for i in range(3):
1111            self.assertEqual(d.refs[i].value, 0)
1112        for i in range(3):
1113            for j in range(3):
1114                d.exp_new = i + 1
1115                d.refs[j].value = i + 1
1116        self.assertEqual(d.calls, 0)
1117        for i in range(3):
1118            self.assertEqual(d.refs[i].value, 3)
1119
1120    def check_complex(self, c, r, pattern, names, other=[]):
1121        handlers = [
1122            c.arg_check0,
1123            c.arg_check1,
1124            c.arg_check2,
1125            c.arg_check3,
1126            c.arg_check4,
1127        ]
1128        nh = len(handlers)
1129        nn = len(names)
1130        self.multi_register(c, handlers, pattern)
1131        for i in range(3):
1132            for n in names:
1133                c.trait_set(
1134                    exp_object=r, exp_name=n, exp_old=i, exp_new=(i + 1)
1135                )
1136                setattr(r, n, i + 1)
1137            for n in other:
1138                c.trait_set(
1139                    exp_object=r, exp_name=n, exp_old=i, exp_new=(i + 1)
1140                )
1141                setattr(r, n, i + 1)
1142        self.assertEqual(c.calls, 3 * nn * nh)
1143        self.multi_register(c, handlers, pattern, remove=True)
1144        for i in range(3):
1145            for n in names:
1146                setattr(r, n, i + 1)
1147            for n in other:
1148                setattr(r, n, i + 1)
1149        self.assertEqual(c.calls, 3 * nn * nh)
1150
1151    def multi_register(self, object, handlers, pattern, remove=False):
1152        for handler in handlers:
1153            object.on_trait_change(handler, pattern, remove=remove)
1154
1155    def build_list(self):
1156        l1 = Link(value=00)
1157        l2 = Link(value=10)
1158        l3 = Link(value=20)
1159        l4 = Link(value=30)
1160        l1.trait_set(next=l2, prev=l4)
1161        l2.trait_set(next=l3, prev=l1)
1162        l3.trait_set(next=l4, prev=l2)
1163        l4.trait_set(next=l1, prev=l3)
1164        return l1
1165
1166    def new_link(self, lt, cur, value):
1167        link = Link(value=value, next=cur.next, prev=cur)
1168        cur.next.prev = link
1169        lt.trait_set(
1170            exp_object=cur, exp_name="next", exp_old=cur.next, exp_new=link
1171        )
1172        cur.next = link
1173        return link
1174