# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import libcst as cst
from libcst import CSTTransformer, CSTVisitor, parse_module
from libcst.testing.utils import UnitTest


# Hook invocation order both tests below expect when walking the module
# "if True:\n pass": the If node is entered, then its ``test`` attribute,
# then the Name inside it; the matching leave hooks then unwind in reverse.
# Shared so the visitor and transformer tests cannot silently drift apart.
_EXPECTED_VISIT_ORDER: List[str] = [
    "visit_If",
    "visit_If_test",
    "visit_Name",
    "leave_Name",
    "leave_If_test",
    "leave_If",
]


class VisitorTest(UnitTest):
    """Check the order in which libcst calls node and attribute hooks."""

    def test_visitor(self) -> None:
        """A CSTVisitor records visit/leave hooks in nesting order."""

        class SomeVisitor(CSTVisitor):
            def __init__(self) -> None:
                # Run the base class initializer before adding our own state.
                super().__init__()
                self.visit_order: List[str] = []

            def visit_If(self, node: cst.If) -> None:
                self.visit_order.append("visit_If")

            def leave_If(self, original_node: cst.If) -> None:
                self.visit_order.append("leave_If")

            # Attribute-level hooks: called around visiting If.test.
            def visit_If_test(self, node: cst.If) -> None:
                self.visit_order.append("visit_If_test")

            def leave_If_test(self, node: cst.If) -> None:
                self.visit_order.append("leave_If_test")

            def visit_Name(self, node: cst.Name) -> None:
                self.visit_order.append("visit_Name")

            def leave_Name(self, original_node: cst.Name) -> None:
                self.visit_order.append("leave_Name")

        # Create and visit a simple module.
        module = parse_module("if True:\n pass")
        visitor = SomeVisitor()
        module.visit(visitor)

        # Check that visits worked.
        self.assertEqual(visitor.visit_order, _EXPECTED_VISIT_ORDER)

    def test_transformer(self) -> None:
        """A CSTTransformer fires hooks in the same order as a visitor."""

        class SomeTransformer(CSTTransformer):
            def __init__(self) -> None:
                # Run the base class initializer before adding our own state.
                super().__init__()
                self.visit_order: List[str] = []

            def visit_If(self, node: cst.If) -> None:
                self.visit_order.append("visit_If")

            def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
                self.visit_order.append("leave_If")
                return updated_node

            # Attribute-level hooks: called around visiting If.test.
            def visit_If_test(self, node: cst.If) -> None:
                self.visit_order.append("visit_If_test")

            def leave_If_test(self, node: cst.If) -> None:
                self.visit_order.append("leave_If_test")

            def visit_Name(self, node: cst.Name) -> None:
                self.visit_order.append("visit_Name")

            def leave_Name(
                self, original_node: cst.Name, updated_node: cst.Name
            ) -> cst.Name:
                self.visit_order.append("leave_Name")
                return updated_node

        # Create and visit a simple module.
        module = parse_module("if True:\n pass")
        transformer = SomeTransformer()
        new_module = module.visit(transformer)

        # Check that visits worked.
        self.assertEqual(transformer.visit_order, _EXPECTED_VISIT_ORDER)
        # Every leave_* hook above returns updated_node untouched, so the
        # transformed tree should render back to the exact original source.
        self.assertEqual(new_module.code, module.code)