# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Sequence, Type
from unittest.mock import patch

import libcst as cst
from libcst._nodes.internal import CodegenState, visit_required
from libcst._types import CSTNodeT
from libcst._visitors import CSTTransformer, CSTVisitorT
from libcst.metadata import CodeRange, PositionProvider
from libcst.metadata.position_provider import PositionProvidingCodegenState
from libcst.testing.utils import UnitTest


@dataclass(frozen=True)
class _CSTCodegenPatchTarget:
    """
    One `CSTNode` subclass whose `_codegen` method is temporarily patched while
    deriving children from codegen in `CSTNodeTest`.
    """

    # The node class exported by the `cst` module.
    type: Type[cst.CSTNode]
    # The name the class is exported under; used to build the `patch` target path.
    name: str
    # The original (unpatched) `_codegen` callable, invoked by the patched wrapper.
    old_codegen: Callable[..., None]


class _NOOPVisitor(CSTTransformer):
    """
    A transformer that overrides nothing. Used to check that visiting a node
    without making changes yields a node equal to the original.
    """

    pass


def _cst_node_equality_func(
    a: cst.CSTNode, b: cst.CSTNode, msg: Optional[str] = None
) -> None:
    """
    For use with addTypeEqualityFunc.

    Compares two nodes by value via `CSTNode.deep_equals` and raises an
    `AssertionError` (with `msg` appended, when given) if they differ.
    """
    if not a.deep_equals(b):
        suffix = "" if msg is None else f"\n{msg}"
        raise AssertionError(f"\n{a!r}\nis not deeply equal to \n{b!r}{suffix}")


def parse_expression_as(**config: Any) -> Callable[[str], cst.BaseExpression]:
    """
    Return a function that parses a string with `cst.parse_expression`, using a
    `cst.PartialParserConfig` built from the given keyword arguments.

    Suitable as the `parser` argument of `CSTNodeTest.validate_node`.
    """

    def inner(code: str) -> cst.BaseExpression:
        return cst.parse_expression(code, config=cst.PartialParserConfig(**config))

    return inner


def parse_statement_as(**config: Any) -> Callable[[str], cst.BaseStatement]:
    """
    Return a function that parses a string with `cst.parse_statement`, using a
    `cst.PartialParserConfig` built from the given keyword arguments.

    Suitable as the `parser` argument of `CSTNodeTest.validate_node`.
    """

    def inner(code: str) -> cst.BaseStatement:
        return cst.parse_statement(code, config=cst.PartialParserConfig(**config))

    return inner


# We can't use an ABCMeta here, because of metaclass conflicts
class CSTNodeTest(UnitTest):
    """
    Base class for tests of individual CST node types.

    Provides `validate_node`, which checks a node's codegen, parsing round-trip,
    `children` consistency and visitor behavior, plus helpers for asserting
    validation/type errors and parser success or failure.
    """

    def setUp(self) -> None:
        super().setUp()
        # Fix `self.assertEqual` for CSTNode subclasses. We should compare equality by
        # value instead of identity (what `CSTNode.__eq__` does) for tests.
        #
        # The time complexity of CSTNode.deep_equals doesn't matter much inside tests.
        #
        # `addTypeEqualityFunc` is only consulted when both operands have exactly
        # the registered type, so every exported subclass is registered one by one.
        for v in cst.__dict__.values():
            if isinstance(v, type) and issubclass(v, cst.CSTNode):
                self.addTypeEqualityFunc(v, _cst_node_equality_func)
        self.addTypeEqualityFunc(DummyIndentedBlock, _cst_node_equality_func)

    def validate_node(
        self,
        node: CSTNodeT,
        code: str,
        parser: Optional[Callable[[str], CSTNodeT]] = None,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Run the standard battery of checks against `node`:

        - `node.validate_types_deep()` succeeds;
        - rendering `node` produces exactly `code`, and (if `expected_position`
          is given) the node's computed position equals it;
        - if `parser` is given, `parser(code)` is deeply equal to `node`;
        - after unwrapping any `DummyIndentedBlock`, `node.children` matches both
          the nodes rendered by `_codegen` (in order) and the nodes found in the
          node's dataclass fields (in any order);
        - visiting the node with a no-op transformer yields an equal node.
        """
        node.validate_types_deep()
        self.__assert_codegen(node, code, expected_position)

        if parser is not None:
            parsed_node = parser(code)
            self.assertEqual(parsed_node, node)

        # Tests of children should unwrap DummyIndentedBlock first, because we don't
        # want to test DummyIndentedBlock's behavior.
        unwrapped_node = node
        while isinstance(unwrapped_node, DummyIndentedBlock):
            unwrapped_node = unwrapped_node.child
        self.__assert_children_match_codegen(unwrapped_node)
        self.__assert_children_match_fields(unwrapped_node)
        self.__assert_visit_returns_identity(unwrapped_node)

    def assert_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """
        Assert that calling `get_node` raises `cst.CSTValidationError` with a
        message matching the regular expression `expected_re`.
        """
        with self.assertRaisesRegex(cst.CSTValidationError, expected_re):
            get_node()

    def assert_invalid_types(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """
        Assert that building the node via `get_node` and shallow-validating its
        types raises a `TypeError` with a message matching `expected_re`.
        """
        with self.assertRaisesRegex(TypeError, expected_re):
            get_node().validate_types_shallow()

    def __assert_codegen(
        self,
        node: cst.CSTNode,
        expected: str,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Verifies that the given node's `_codegen` method is correct.

        The rendered code must equal `expected`. When `expected_position` is
        given, the position computed for `node` must equal it as well.
        """
        module = cst.Module([])
        self.assertEqual(module.code_for_node(node), expected)

        if expected_position is not None:
            # This is using some internal APIs, because we only want to compute
            # position for the node being tested, not a whole module.
            #
            # Normally, this is a nonsense operation (how can a node have a position if
            # it's not in a module?), which is why it's not supported, but it makes
            # sense in the context of these node tests.
            provider = PositionProvider()
            state = PositionProvidingCodegenState(
                default_indent=module.default_indent,
                default_newline=module.default_newline,
                provider=provider,
            )
            node._codegen(state)
            self.assertEqual(provider._computed[node], expected_position)

    def __assert_children_match_codegen(self, node: cst.CSTNode) -> None:
        """
        Assert that `node.children` equals, in order, the direct children that
        `node._codegen` actually renders.
        """
        children = node.children
        codegen_children = self.__derive_children_from_codegen(node)
        self.assertSequenceEqual(
            children,
            codegen_children,
            msg=(
                "The list of children we got from `node.children` differs from the "
                + "children that were visited by `node._codegen`."
            ),
        )

    def __derive_children_from_codegen(
        self, node: cst.CSTNode
    ) -> Sequence[cst.CSTNode]:
        """
        Patches all subclasses of `CSTNode` exported by the `cst` module to track which
        `_codegen` methods get called, generating a list of children.

        Because all children must be rendered out into lexical order, this should be
        equivalent to `node.children`.

        `node.children` uses `_visit_and_replace_children` under the hood, not
        `_codegen`, so this helps us verify that the behaviors of those two methods
        are in sync.
        """

        patch_targets: List[_CSTCodegenPatchTarget] = [
            _CSTCodegenPatchTarget(type=v, name=k, old_codegen=v._codegen)
            for (k, v) in cst.__dict__.items()
            if isinstance(v, type)
            and issubclass(v, cst.CSTNode)
            and hasattr(v, "_codegen")
        ]

        children: List[cst.CSTNode] = []
        # Nodes whose `_codegen` is currently executing, outermost (root) first.
        codegen_stack: List[cst.CSTNode] = []

        def _get_codegen_override(
            target: _CSTCodegenPatchTarget,
        ) -> Callable[..., None]:
            def _patched_codegen(
                codegen_node: CSTNodeT, *args: Any, **kwargs: Any
            ) -> None:
                # Don't stick duplicates in the stack. This is needed so that we don't
                # track calls to `super()._codegen()`.
                if codegen_stack and codegen_stack[-1] is codegen_node:
                    target.old_codegen(codegen_node, *args, **kwargs)
                    return

                # Check the stack to see that we're a direct child, not the root or
                # a transitive child: only the root is on the stack while a direct
                # child renders.
                if len(codegen_stack) == 1:
                    children.append(codegen_node)
                codegen_stack.append(codegen_node)
                try:
                    target.old_codegen(codegen_node, *args, **kwargs)
                finally:
                    # Keep the stack balanced even if codegen raises.
                    codegen_stack.pop()

            return _patched_codegen

        with ExitStack() as patch_stack:
            for t in patch_targets:
                patch_stack.enter_context(
                    patch(f"libcst.{t.name}._codegen", _get_codegen_override(t))
                )
            # Execute `node._codegen()`
            cst.Module([]).code_for_node(node)

        return children

    def __assert_children_match_fields(self, node: cst.CSTNode) -> None:
        """
        We expect `node.children` to match everything we can extract from the node's
        fields, but maybe in a different order. This asserts that those things match.

        Nodes are compared by identity. A field contributes either its value (if
        it is a `CSTNode`) or every `CSTNode` element of it (if it is iterable).

        Ordering is verified separately by `__assert_children_match_codegen`.
        """
        node_children_ids = {id(child) for child in node.children}
        fields = dataclasses.fields(node)
        field_child_ids = set()
        for f in fields:
            value = getattr(node, f.name)
            if isinstance(value, cst.CSTNode):
                field_child_ids.add(id(value))
            elif isinstance(value, Iterable):
                field_child_ids.update(
                    id(el) for el in value if isinstance(el, cst.CSTNode)
                )

        # order doesn't matter
        self.assertSetEqual(
            node_children_ids,
            field_child_ids,
            msg="`node.children` doesn't match what we found through introspection",
        )

    def __assert_visit_returns_identity(self, node: cst.CSTNode) -> None:
        """
        When visit is called with a visitor that acts as a no-op, the visit method
        should return the same node it started with.
        """
        # TODO: We're only checking equality right now, because visit currently clones
        # the node, since that was easier to implement. We should fix that behavior in a
        # later version and tighten this check.
        self.assertEqual(node, node.visit(_NOOPVisitor()))

    def assert_parses(
        self,
        code: str,
        parser: Callable[[str], cst.CSTNode],
        expect_success: bool,
    ) -> None:
        """
        Run `parser` on `code`.

        When `expect_success` is False, assert that it raises
        `cst.ParserSyntaxError`. Otherwise, any exception propagates and fails
        the test.
        """
        if not expect_success:
            with self.assertRaises(cst.ParserSyntaxError):
                parser(code)
        else:
            parser(code)


@dataclass(frozen=True)
class DummyIndentedBlock(cst.CSTNode):
    """
    A stripped-down version of cst.IndentedBlock that only sets/clears the indentation
    state for the purpose of testing cst.IndentWhitespace in isolation.
    """

    # Indentation string pushed onto the codegen state while rendering `child`.
    value: str
    # The wrapped node rendered at the increased indentation level.
    child: cst.CSTNode

    def _codegen_impl(self, state: CodegenState) -> None:
        state.increase_indent(self.value)
        # This node's recorded position spans exactly its child.
        with state.record_syntactic_position(
            self, start_node=self.child, end_node=self.child
        ):
            self.child._codegen(state)
        state.decrease_indent()

    def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "DummyIndentedBlock":
        return DummyIndentedBlock(
            value=self.value, child=visit_required(self, "child", self.child, visitor)
        )