1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from ast import literal_eval
7from textwrap import dedent
8from typing import List, Set
9from unittest.mock import Mock
10
11import libcst as cst
12import libcst.matchers as m
13from libcst.matchers import (
14    MatcherDecoratableTransformer,
15    MatcherDecoratableVisitor,
16    call_if_inside,
17    call_if_not_inside,
18    leave,
19    visit,
20)
21from libcst.testing.utils import UnitTest
22
23
def fixture(code: str) -> cst.Module:
    """Parse ``code`` into a LibCST module after removing common indentation.

    This lets tests write their input source as an indented triple-quoted
    string inside a method body.
    """
    source = dedent(code)
    return cst.parse_module(source)
26
27
class MatchersGatingDecoratorsTest(UnitTest):
    # Tests for the ``call_if_inside`` / ``call_if_not_inside`` gating decorators.
    # Each scenario is run against both MatcherDecoratableTransformer and
    # MatcherDecoratableVisitor so both base classes are checked for the same
    # behavior.

    # Source shared by the single-decorator tests: two module-level strings,
    # the string "baz" returned from ``foo`` and "foobar" returned from ``bar``.
    _TWO_FUNCTIONS_CODE = """
        a = "foo"
        b = "bar"

        def foo() -> None:
            return "baz"

        def bar() -> None:
            return "foobar"
    """

    # Source shared by the stacked-decorator tests: a top-level ``foo`` and a
    # method ``A.foo``. Only "baz" sits inside both ``class A`` and a ``foo``.
    _NESTED_FOO_CODE = """
        def foo() -> None:
            return "foo"

        class A:
            def foo(self) -> None:
                return "baz"
    """

    def test_call_if_inside_transform_simple(self) -> None:
        # Set up a simple transformer with call_if_inside decorators.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_inside(m.FunctionDef())
            def leave_SimpleString(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                self.leaves.append(updated_node.value)
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # Visits are gated on being inside ``foo``; leaves on being inside any
        # function at all.
        self.assertEqual(visitor.visits, ['"baz"'])
        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])

    def test_call_if_inside_verify_original_transform(self) -> None:
        # Set up a transformer that gates visit_SimpleString with call_if_inside
        # and also defines an undecorated visit_FunctionDef.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # The undecorated visit_FunctionDef must still fire for every function,
        # while the gated string visitor only fires inside ``foo``.
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"baz"'])

    def test_call_if_inside_collect_simple(self) -> None:
        # Set up a simple read-only visitor with call_if_inside decorators.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_inside(m.FunctionDef())
            def leave_SimpleString(self, original_node: cst.SimpleString) -> None:
                self.leaves.append(original_node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # Visits are gated on being inside ``foo``; leaves on being inside any
        # function at all.
        self.assertEqual(visitor.visits, ['"baz"'])
        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])

    def test_call_if_inside_verify_original_collect(self) -> None:
        # Set up a read-only visitor that gates visit_SimpleString with
        # call_if_inside and also defines an undecorated visit_FunctionDef.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # The undecorated visit_FunctionDef must still fire for every function,
        # while the gated string visitor only fires inside ``foo``.
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"baz"'])

    def test_multiple_visitors_collect(self) -> None:
        # Set up a read-only visitor with two stacked call_if_inside decorators;
        # both conditions must hold for the method to be called.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []

            @call_if_inside(m.ClassDef(m.Name("A")))
            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._NESTED_FOO_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # Only the string inside both ``class A`` and a ``foo`` is recorded.
        self.assertEqual(visitor.visits, ['"baz"'])

    def test_multiple_visitors_transform(self) -> None:
        # Set up a transformer with two stacked call_if_inside decorators; both
        # conditions must hold for the method to be called.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []

            @call_if_inside(m.ClassDef(m.Name("A")))
            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._NESTED_FOO_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # Only the string inside both ``class A`` and a ``foo`` is recorded.
        self.assertEqual(visitor.visits, ['"baz"'])

    def test_call_if_not_inside_transform_simple(self) -> None:
        # Set up a simple transformer with call_if_not_inside decorators.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_not_inside(m.FunctionDef())
            def leave_SimpleString(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                self.leaves.append(updated_node.value)
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # Visits skip strings inside ``foo``; leaves skip strings inside any
        # function, so only the module-level strings remain.
        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])

    def test_call_if_not_inside_verify_original_transform(self) -> None:
        # Set up a transformer that gates visit_SimpleString with
        # call_if_not_inside and also defines an undecorated visit_FunctionDef.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # The undecorated visit_FunctionDef must still fire for every function,
        # while the gated string visitor skips only strings inside ``foo``.
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"foo"', '"bar"', '"foobar"'])

    def test_call_if_not_inside_collect_simple(self) -> None:
        # Set up a simple read-only visitor with call_if_not_inside decorators.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_not_inside(m.FunctionDef())
            def leave_SimpleString(self, original_node: cst.SimpleString) -> None:
                self.leaves.append(original_node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # Visits skip strings inside ``foo``; leaves skip strings inside any
        # function, so only the module-level strings remain.
        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])

    def test_call_if_not_inside_verify_original_collect(self) -> None:
        # Set up a read-only visitor that gates visit_SimpleString with
        # call_if_not_inside and also defines an undecorated visit_FunctionDef.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(self._TWO_FUNCTIONS_CODE)
        visitor = TestVisitor()
        module.visit(visitor)

        # The undecorated visit_FunctionDef must still fire for every function,
        # while the gated string visitor skips only strings inside ``foo``.
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"foo"', '"bar"', '"foobar"'])
376
377
378class MatchersVisitLeaveDecoratorsTest(UnitTest):
379    def test_visit_transform(self) -> None:
380        # Set up a simple visitor with a visit and leave decorator.
381        class TestVisitor(MatcherDecoratableTransformer):
382            def __init__(self) -> None:
383                super().__init__()
384                self.visits: List[str] = []
385                self.leaves: List[str] = []
386
387            @visit(m.FunctionDef(m.Name("foo") | m.Name("bar")))
388            def visit_function(self, node: cst.FunctionDef) -> None:
389                self.visits.append(node.name.value)
390
391            @leave(m.FunctionDef(m.Name("bar") | m.Name("baz")))
392            def leave_function(
393                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
394            ) -> cst.FunctionDef:
395                self.leaves.append(updated_node.name.value)
396                return updated_node
397
398        # Parse a module and verify we visited correctly.
399        module = fixture(
400            """
401            a = "foo"
402            b = "bar"
403
404            def foo() -> None:
405                return "baz"
406
407            def bar() -> None:
408                return "foobar"
409
410            def baz() -> None:
411                return "foobar"
412        """
413        )
414        visitor = TestVisitor()
415        module.visit(visitor)
416
417        # We should have only visited a select number of nodes.
418        self.assertEqual(visitor.visits, ["foo", "bar"])
419        self.assertEqual(visitor.leaves, ["bar", "baz"])
420
421    def test_visit_collector(self) -> None:
422        # Set up a simple visitor with a visit and leave decorator.
423        class TestVisitor(MatcherDecoratableVisitor):
424            def __init__(self) -> None:
425                super().__init__()
426                self.visits: List[str] = []
427                self.leaves: List[str] = []
428
429            @visit(m.FunctionDef(m.Name("foo") | m.Name("bar")))
430            def visit_function(self, node: cst.FunctionDef) -> None:
431                self.visits.append(node.name.value)
432
433            @leave(m.FunctionDef(m.Name("bar") | m.Name("baz")))
434            def leave_function(self, original_node: cst.FunctionDef) -> None:
435                self.leaves.append(original_node.name.value)
436
437        # Parse a module and verify we visited correctly.
438        module = fixture(
439            """
440            a = "foo"
441            b = "bar"
442
443            def foo() -> None:
444                return "baz"
445
446            def bar() -> None:
447                return "foobar"
448
449            def baz() -> None:
450                return "foobar"
451        """
452        )
453        visitor = TestVisitor()
454        module.visit(visitor)
455
456        # We should have only visited a select number of nodes.
457        self.assertEqual(visitor.visits, ["foo", "bar"])
458        self.assertEqual(visitor.leaves, ["bar", "baz"])
459
460    def test_stacked_visit_transform(self) -> None:
461        # Set up a simple visitor with a visit and leave decorator.
462        class TestVisitor(MatcherDecoratableTransformer):
463            def __init__(self) -> None:
464                super().__init__()
465                self.visits: List[str] = []
466                self.leaves: List[str] = []
467
468            @visit(m.FunctionDef(m.Name("foo")))
469            @visit(m.FunctionDef(m.Name("bar")))
470            def visit_function(self, node: cst.FunctionDef) -> None:
471                self.visits.append(node.name.value)
472
473            @leave(m.FunctionDef(m.Name("bar")))
474            @leave(m.FunctionDef(m.Name("baz")))
475            def leave_function(
476                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
477            ) -> cst.FunctionDef:
478                self.leaves.append(updated_node.name.value)
479                return updated_node
480
481        # Parse a module and verify we visited correctly.
482        module = fixture(
483            """
484            a = "foo"
485            b = "bar"
486
487            def foo() -> None:
488                return "baz"
489
490            def bar() -> None:
491                return "foobar"
492
493            def baz() -> None:
494                return "foobar"
495        """
496        )
497        visitor = TestVisitor()
498        module.visit(visitor)
499
500        # We should have only visited a select number of nodes.
501        self.assertEqual(visitor.visits, ["foo", "bar"])
502        self.assertEqual(visitor.leaves, ["bar", "baz"])
503
504    def test_stacked_visit_collector(self) -> None:
505        # Set up a simple visitor with a visit and leave decorator.
506        class TestVisitor(MatcherDecoratableVisitor):
507            def __init__(self) -> None:
508                super().__init__()
509                self.visits: List[str] = []
510                self.leaves: List[str] = []
511
512            @visit(m.FunctionDef(m.Name("foo")))
513            @visit(m.FunctionDef(m.Name("bar")))
514            def visit_function(self, node: cst.FunctionDef) -> None:
515                self.visits.append(node.name.value)
516
517            @leave(m.FunctionDef(m.Name("bar")))
518            @leave(m.FunctionDef(m.Name("baz")))
519            def leave_function(self, original_node: cst.FunctionDef) -> None:
520                self.leaves.append(original_node.name.value)
521
522        # Parse a module and verify we visited correctly.
523        module = fixture(
524            """
525            a = "foo"
526            b = "bar"
527
528            def foo() -> None:
529                return "baz"
530
531            def bar() -> None:
532                return "foobar"
533
534            def baz() -> None:
535                return "foobar"
536        """
537        )
538        visitor = TestVisitor()
539        module.visit(visitor)
540
541        # We should have only visited a select number of nodes.
542        self.assertEqual(visitor.visits, ["foo", "bar"])
543        self.assertEqual(visitor.leaves, ["bar", "baz"])
544        self.assertEqual(visitor.leaves, ["bar", "baz"])
545
546    def test_duplicate_visit_transform(self) -> None:
547        # Set up a simple visitor with a visit and leave decorator.
548        class TestVisitor(MatcherDecoratableTransformer):
549            def __init__(self) -> None:
550                super().__init__()
551                self.visits: Set[str] = set()
552                self.leaves: Set[str] = set()
553
554            @visit(m.FunctionDef(m.Name("foo")))
555            def visit_function1(self, node: cst.FunctionDef) -> None:
556                self.visits.add(node.name.value + "1")
557
558            @visit(m.FunctionDef(m.Name("foo")))
559            def visit_function2(self, node: cst.FunctionDef) -> None:
560                self.visits.add(node.name.value + "2")
561
562            @leave(m.FunctionDef(m.Name("bar")))
563            def leave_function1(
564                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
565            ) -> cst.FunctionDef:
566                self.leaves.add(updated_node.name.value + "1")
567                return updated_node
568
569            @leave(m.FunctionDef(m.Name("bar")))
570            def leave_function2(
571                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
572            ) -> cst.FunctionDef:
573                self.leaves.add(updated_node.name.value + "2")
574                return updated_node
575
576        # Parse a module and verify we visited correctly.
577        module = fixture(
578            """
579            a = "foo"
580            b = "bar"
581
582            def foo() -> None:
583                return "baz"
584
585            def bar() -> None:
586                return "foobar"
587
588            def baz() -> None:
589                return "foobar"
590        """
591        )
592        visitor = TestVisitor()
593        module.visit(visitor)
594
595        # We should have only visited a select number of nodes.
596        self.assertEqual(visitor.visits, {"foo1", "foo2"})
597        self.assertEqual(visitor.leaves, {"bar1", "bar2"})
598
599    def test_duplicate_visit_collector(self) -> None:
600        # Set up a simple visitor with a visit and leave decorator.
601        class TestVisitor(MatcherDecoratableVisitor):
602            def __init__(self) -> None:
603                super().__init__()
604                self.visits: Set[str] = set()
605                self.leaves: Set[str] = set()
606
607            @visit(m.FunctionDef(m.Name("foo")))
608            def visit_function1(self, node: cst.FunctionDef) -> None:
609                self.visits.add(node.name.value + "1")
610
611            @visit(m.FunctionDef(m.Name("foo")))
612            def visit_function2(self, node: cst.FunctionDef) -> None:
613                self.visits.add(node.name.value + "2")
614
615            @leave(m.FunctionDef(m.Name("bar")))
616            def leave_function1(self, original_node: cst.FunctionDef) -> None:
617                self.leaves.add(original_node.name.value + "1")
618
619            @leave(m.FunctionDef(m.Name("bar")))
620            def leave_function2(self, original_node: cst.FunctionDef) -> None:
621                self.leaves.add(original_node.name.value + "2")
622
623        # Parse a module and verify we visited correctly.
624        module = fixture(
625            """
626            a = "foo"
627            b = "bar"
628
629            def foo() -> None:
630                return "baz"
631
632            def bar() -> None:
633                return "foobar"
634
635            def baz() -> None:
636                return "foobar"
637        """
638        )
639        visitor = TestVisitor()
640        module.visit(visitor)
641
642        # We should have only visited a select number of nodes.
643        self.assertEqual(visitor.visits, {"foo1", "foo2"})
644        self.assertEqual(visitor.leaves, {"bar1", "bar2"})
645
646    def test_gated_visit_transform(self) -> None:
647        # Set up a simple visitor with a visit and leave decorator.
648        class TestVisitor(MatcherDecoratableTransformer):
649            def __init__(self) -> None:
650                super().__init__()
651                self.visits: Set[str] = set()
652                self.leaves: Set[str] = set()
653
654            @call_if_inside(m.FunctionDef(m.Name("foo")))
655            @visit(m.SimpleString())
656            def visit_string1(self, node: cst.SimpleString) -> None:
657                self.visits.add(literal_eval(node.value) + "1")
658
659            @call_if_not_inside(m.FunctionDef(m.Name("bar")))
660            @visit(m.SimpleString())
661            def visit_string2(self, node: cst.SimpleString) -> None:
662                self.visits.add(literal_eval(node.value) + "2")
663
664            @call_if_inside(m.FunctionDef(m.Name("baz")))
665            @leave(m.SimpleString())
666            def leave_string1(
667                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
668            ) -> cst.SimpleString:
669                self.leaves.add(literal_eval(updated_node.value) + "1")
670                return updated_node
671
672            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
673            @leave(m.SimpleString())
674            def leave_string2(
675                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
676            ) -> cst.SimpleString:
677                self.leaves.add(literal_eval(updated_node.value) + "2")
678                return updated_node
679
680        # Parse a module and verify we visited correctly.
681        module = fixture(
682            """
683            a = "foo"
684            b = "bar"
685
686            def foo() -> None:
687                return "baz"
688
689            def bar() -> None:
690                return "foobar"
691
692            def baz() -> None:
693                return "foobarbaz"
694        """
695        )
696        visitor = TestVisitor()
697        module.visit(visitor)
698
699        # We should have only visited a select number of nodes.
700        self.assertEqual(visitor.visits, {"baz1", "foo2", "bar2", "baz2", "foobarbaz2"})
701        self.assertEqual(
702            visitor.leaves, {"foobarbaz1", "foo2", "bar2", "foobar2", "foobarbaz2"}
703        )
704
705    def test_gated_visit_collect(self) -> None:
706        # Set up a simple visitor with a visit and leave decorator.
707        class TestVisitor(MatcherDecoratableVisitor):
708            def __init__(self) -> None:
709                super().__init__()
710                self.visits: Set[str] = set()
711                self.leaves: Set[str] = set()
712
713            @call_if_inside(m.FunctionDef(m.Name("foo")))
714            @visit(m.SimpleString())
715            def visit_string1(self, node: cst.SimpleString) -> None:
716                self.visits.add(literal_eval(node.value) + "1")
717
718            @call_if_not_inside(m.FunctionDef(m.Name("bar")))
719            @visit(m.SimpleString())
720            def visit_string2(self, node: cst.SimpleString) -> None:
721                self.visits.add(literal_eval(node.value) + "2")
722
723            @call_if_inside(m.FunctionDef(m.Name("baz")))
724            @leave(m.SimpleString())
725            def leave_string1(self, original_node: cst.SimpleString) -> None:
726                self.leaves.add(literal_eval(original_node.value) + "1")
727
728            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
729            @leave(m.SimpleString())
730            def leave_string2(self, original_node: cst.SimpleString) -> None:
731                self.leaves.add(literal_eval(original_node.value) + "2")
732
733        # Parse a module and verify we visited correctly.
734        module = fixture(
735            """
736            a = "foo"
737            b = "bar"
738
739            def foo() -> None:
740                return "baz"
741
742            def bar() -> None:
743                return "foobar"
744
745            def baz() -> None:
746                return "foobarbaz"
747        """
748        )
749        visitor = TestVisitor()
750        module.visit(visitor)
751
752        # We should have only visited a select number of nodes.
753        self.assertEqual(visitor.visits, {"baz1", "foo2", "bar2", "baz2", "foobarbaz2"})
754        self.assertEqual(
755            visitor.leaves, {"foobarbaz1", "foo2", "bar2", "foobar2", "foobarbaz2"}
756        )
757
758    def test_transform_order(self) -> None:
759        # Set up a simple visitor with a visit and leave decorator.
760        class TestVisitor(MatcherDecoratableTransformer):
761            @call_if_inside(m.FunctionDef(m.Name("bar")))
762            @leave(m.SimpleString())
763            def leave_string1(
764                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
765            ) -> cst.SimpleString:
766                return updated_node.with_changes(
767                    value=f'"prefix{literal_eval(updated_node.value)}"'
768                )
769
770            @call_if_inside(m.FunctionDef(m.Name("bar")))
771            @leave(m.SimpleString())
772            def leave_string2(
773                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
774            ) -> cst.SimpleString:
775                return updated_node.with_changes(
776                    value=f'"{literal_eval(updated_node.value)}suffix"'
777                )
778
779            @call_if_inside(m.FunctionDef(m.Name("bar")))
780            def leave_SimpleString(
781                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
782            ) -> cst.SimpleString:
783                return updated_node.with_changes(
784                    value=f'"{"".join(reversed(literal_eval(updated_node.value)))}"'
785                )
786
787        # Parse a module and verify we visited correctly.
788        module = fixture(
789            """
790            a = "foo"
791            b = "bar"
792
793            def foo() -> None:
794                return "baz"
795
796            def bar() -> None:
797                return "foobar"
798
799            def baz() -> None:
800                return "foobarbaz"
801            """
802        )
803        visitor = TestVisitor()
804        actual = module.visit(visitor)
805        expected = fixture(
806            """
807            a = "foo"
808            b = "bar"
809
810            def foo() -> None:
811                return "baz"
812
813            def bar() -> None:
814                return "prefixraboofsuffix"
815
816            def baz() -> None:
817                return "foobarbaz"
818            """
819        )
820        self.assertTrue(expected.deep_equals(actual))
821
822    def test_call_if_inside_visitor_attribute(self) -> None:
823        # Set up a simple visitor with a call_if_inside decorator.
824        class TestVisitor(MatcherDecoratableVisitor):
825            def __init__(self) -> None:
826                super().__init__()
827                self.visits: List[str] = []
828                self.leaves: List[str] = []
829
830            @call_if_inside(m.FunctionDef(m.Name("foo")))
831            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
832                self.visits.append(node.value)
833
834            @call_if_inside(m.FunctionDef())
835            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
836                self.leaves.append(node.value)
837
838        # Parse a module and verify we visited correctly.
839        module = fixture(
840            """
841            a = "foo"
842            b = "bar"
843
844            def foo() -> None:
845                return "baz"
846
847            def bar() -> None:
848                return "foobar"
849        """
850        )
851        visitor = TestVisitor()
852        module.visit(visitor)
853
854        # We should have only visited a select number of nodes.
855        self.assertEqual(visitor.visits, ['"baz"'])
856        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])
857
858    def test_call_if_inside_transform_attribute(self) -> None:
859        # Set up a simple visitor with a call_if_inside decorator.
860        class TestVisitor(MatcherDecoratableTransformer):
861            def __init__(self) -> None:
862                super().__init__()
863                self.visits: List[str] = []
864                self.leaves: List[str] = []
865
866            @call_if_inside(m.FunctionDef(m.Name("foo")))
867            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
868                self.visits.append(node.value)
869
870            @call_if_inside(m.FunctionDef())
871            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
872                self.leaves.append(node.value)
873
874        # Parse a module and verify we visited correctly.
875        module = fixture(
876            """
877            a = "foo"
878            b = "bar"
879
880            def foo() -> None:
881                return "baz"
882
883            def bar() -> None:
884                return "foobar"
885        """
886        )
887        visitor = TestVisitor()
888        module.visit(visitor)
889
890        # We should have only visited a select number of nodes.
891        self.assertEqual(visitor.visits, ['"baz"'])
892        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])
893
894    def test_call_if_not_inside_visitor_attribute(self) -> None:
895        # Set up a simple visitor with a call_if_inside decorator.
896        class TestVisitor(MatcherDecoratableVisitor):
897            def __init__(self) -> None:
898                super().__init__()
899                self.visits: List[str] = []
900                self.leaves: List[str] = []
901
902            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
903            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
904                self.visits.append(node.value)
905
906            @call_if_not_inside(m.FunctionDef())
907            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
908                self.leaves.append(node.value)
909
910        # Parse a module and verify we visited correctly.
911        module = fixture(
912            """
913            a = "foo"
914            b = "bar"
915
916            def foo() -> None:
917                return "baz"
918
919            def bar() -> None:
920                return "foobar"
921        """
922        )
923        visitor = TestVisitor()
924        module.visit(visitor)
925
926        # We should have only visited a select number of nodes.
927        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
928        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])
929
930    def test_call_if_not_inside_transform_attribute(self) -> None:
931        # Set up a simple visitor with a call_if_inside decorator.
932        class TestVisitor(MatcherDecoratableTransformer):
933            def __init__(self) -> None:
934                super().__init__()
935                self.visits: List[str] = []
936                self.leaves: List[str] = []
937
938            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
939            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
940                self.visits.append(node.value)
941
942            @call_if_not_inside(m.FunctionDef())
943            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
944                self.leaves.append(node.value)
945
946        # Parse a module and verify we visited correctly.
947        module = fixture(
948            """
949            a = "foo"
950            b = "bar"
951
952            def foo() -> None:
953                return "baz"
954
955            def bar() -> None:
956                return "foobar"
957        """
958        )
959        visitor = TestVisitor()
960        module.visit(visitor)
961
962        # We should have only visited a select number of nodes.
963        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
964        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])
965
966    def test_init_with_unhashable_types(self) -> None:
967        # Set up a simple visitor with a call_if_inside decorator.
968        class TestVisitor(MatcherDecoratableTransformer):
969            def __init__(self) -> None:
970                super().__init__()
971                self.visits: List[str] = []
972
973            @call_if_inside(
974                m.FunctionDef(m.Name("foo"), params=m.Parameters([m.ZeroOrMore()]))
975            )
976            def visit_SimpleString(self, node: cst.SimpleString) -> None:
977                self.visits.append(node.value)
978
979        # Parse a module and verify we visited correctly.
980        module = fixture(
981            """
982            a = "foo"
983            b = "bar"
984
985            def foo() -> None:
986                return "baz"
987
988            def bar() -> None:
989                return "foobar"
990        """
991        )
992        visitor = TestVisitor()
993        module.visit(visitor)
994
995        # We should have only visited a select number of nodes.
996        self.assertEqual(visitor.visits, ['"baz"'])
997
998
# This is meant to simulate `cst.ImportFrom | cst.RemovalSentinel` in py3.10,
# i.e. a PEP 604 union used as a return annotation, without requiring the
# `X | Y` syntax (NOTE(review): presumably so the test also runs on older
# interpreters -- confirm).
#
# `FakeUnion` exposes a class whose `__name__` is "Union" and `__module__` is
# "types", and carries the union members in `__args__`. These appear to be the
# attributes the matcher return-annotation check (`_check_types`) inspects;
# verify against libcst.matchers if that check changes.
FakeUnionClass: Mock = Mock()
# setattr() and plain attribute assignment behave the same at runtime here.
setattr(FakeUnionClass, "__name__", "Union")
setattr(FakeUnionClass, "__module__", "types")
FakeUnion: Mock = Mock()
# Mock allows `__class__` to be assigned; afterwards `FakeUnion.__class__`
# reports FakeUnionClass instead of Mock.
FakeUnion.__class__ = FakeUnionClass
# The union's member types, mirroring what a real union object carries.
FakeUnion.__args__ = [cst.ImportFrom, cst.RemovalSentinel]
1006
1007
class MatchersUnionDecoratorsTest(UnitTest):
    def test_init_with_new_union_annotation(self) -> None:
        # A transformer whose decorated leave function is annotated with a
        # (simulated) PEP 604 union return type.
        class UnionReturningTransformer(MatcherDecoratableTransformer):
            @leave(m.ImportFrom(module=m.Name(value="typing")))
            def test(
                self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
            ) -> FakeUnion:
                """Never invoked by this test; only its annotations matter."""

        # Constructing the transformer runs the decorator type checks
        # (specifically _check_types on the return annotation); the test
        # passes as long as construction does not raise.
        UnionReturningTransformer()
1019