1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from typing import List
7
8import libcst as cst
9from libcst import CSTTransformer, CSTVisitor, parse_module
10from libcst.testing.utils import UnitTest
11
12
class VisitorTest(UnitTest):
    """Check the order in which visitor and transformer hooks are invoked.

    Both tests parse the same tiny module (a single ``if True: pass``
    statement), walk it with a recording visitor/transformer, and compare the
    recorded hook names against the expected sequence: the node-level ``If``
    hooks wrap the attribute-level ``If_test`` hooks, which in turn wrap the
    ``Name`` hooks.
    """

    def test_visitor(self) -> None:
        """A read-only ``CSTVisitor`` sees the hooks in the expected order."""

        class SomeVisitor(CSTVisitor):
            """Visitor that records the name of every hook it receives."""

            def __init__(self) -> None:
                # Let the base classes initialise themselves before we add
                # our own state.
                super().__init__()
                # Hook names, appended in the order the hooks are called.
                self.visit_order: List[str] = []

            def visit_If(self, node: cst.If) -> None:
                self.visit_order.append("visit_If")

            def leave_If(self, original_node: cst.If) -> None:
                self.visit_order.append("leave_If")

            def visit_If_test(self, node: cst.If) -> None:
                self.visit_order.append("visit_If_test")

            def leave_If_test(self, node: cst.If) -> None:
                self.visit_order.append("leave_If_test")

            def visit_Name(self, node: cst.Name) -> None:
                self.visit_order.append("visit_Name")

            def leave_Name(self, original_node: cst.Name) -> None:
                self.visit_order.append("leave_Name")

        # Create and visit a simple module.
        module = parse_module("if True:\n    pass")
        visitor = SomeVisitor()
        module.visit(visitor)

        # Check that visits worked.
        self.assertEqual(
            visitor.visit_order,
            [
                "visit_If",
                "visit_If_test",
                "visit_Name",
                "leave_Name",
                "leave_If_test",
                "leave_If",
            ],
        )

    def test_transformer(self) -> None:
        """A ``CSTTransformer`` sees the hooks in the same order as a visitor."""

        class SomeTransformer(CSTTransformer):
            """Transformer that records each hook and leaves the tree as-is."""

            def __init__(self) -> None:
                # Let the base classes initialise themselves before we add
                # our own state.
                super().__init__()
                # Hook names, appended in the order the hooks are called.
                self.visit_order: List[str] = []

            def visit_If(self, node: cst.If) -> None:
                self.visit_order.append("visit_If")

            def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
                self.visit_order.append("leave_If")
                # Hand back the node we were given; this test does not
                # rewrite anything.
                return updated_node

            def visit_If_test(self, node: cst.If) -> None:
                self.visit_order.append("visit_If_test")

            def leave_If_test(self, node: cst.If) -> None:
                self.visit_order.append("leave_If_test")

            def visit_Name(self, node: cst.Name) -> None:
                self.visit_order.append("visit_Name")

            def leave_Name(
                self, original_node: cst.Name, updated_node: cst.Name
            ) -> cst.Name:
                self.visit_order.append("leave_Name")
                # Hand back the node we were given; this test does not
                # rewrite anything.
                return updated_node

        # Create and visit a simple module.
        module = parse_module("if True:\n    pass")
        transformer = SomeTransformer()
        module.visit(transformer)

        # Check that visits worked.
        self.assertEqual(
            transformer.visit_order,
            [
                "visit_If",
                "visit_If_test",
                "visit_Name",
                "leave_Name",
                "leave_If_test",
                "leave_If",
            ],
        )
99