# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field, fields, replace
from typing import Any, Dict, List, Mapping, Sequence, TypeVar, Union, cast

from libcst._flatten_sentinel import FlattenSentinel
from libcst._nodes.internal import CodegenState
from libcst._removal_sentinel import RemovalSentinel
from libcst._type_enforce import is_value_of_type
from libcst._types import CSTNodeT
from libcst._visitors import CSTTransformer, CSTVisitor, CSTVisitorT

_CSTNodeSelfT = TypeVar("_CSTNodeSelfT", bound="CSTNode")
_EMPTY_SEQUENCE: Sequence["CSTNode"] = ()


class CSTValidationError(SyntaxError):
    pass


class CSTCodegenError(SyntaxError):
    pass


class _ChildrenCollectionVisitor(CSTVisitor):
    """
    Records every node it is asked to visit, without descending into it.

    Used by :attr:`CSTNode.children` to collect a node's immediate children.
    """

    def __init__(self) -> None:
        self.children: List[CSTNode] = []

    def on_visit(self, node: "CSTNode") -> bool:
        self.children.append(node)
        return False  # Don't include transitive children


class _ChildReplacementTransformer(CSTTransformer):
    """
    Replaces every occurrence (by identity) of ``old_node`` with ``new_node``.

    ``new_node`` may be ``RemovalSentinel``, in which case that sentinel is returned
    from ``on_leave`` in place of the matched node.
    """

    def __init__(
        self, old_node: "CSTNode", new_node: Union["CSTNode", RemovalSentinel]
    ) -> None:
        self.old_node = old_node
        self.new_node = new_node

    def on_visit(self, node: "CSTNode") -> bool:
        # If the node is one we are about to replace, we shouldn't
        # recurse down it, that would be a waste of time.
        return node is not self.old_node

    def on_leave(
        self, original_node: "CSTNode", updated_node: "CSTNode"
    ) -> Union["CSTNode", RemovalSentinel]:
        if original_node is self.old_node:
            return self.new_node
        return updated_node


class _ChildWithChangesTransformer(CSTTransformer):
    """
    Applies ``with_changes(**changes)`` to every occurrence (by identity) of
    ``old_node`` and leaves all other nodes untouched.
    """

    def __init__(self, old_node: "CSTNode", changes: Mapping[str, Any]) -> None:
        self.old_node = old_node
        self.changes = changes

    def on_visit(self, node: "CSTNode") -> bool:
        # If the node is one we are about to replace, we shouldn't
        # recurse down it, that would be a waste of time.
        return node is not self.old_node

    def on_leave(self, original_node: "CSTNode", updated_node: "CSTNode") -> "CSTNode":
        if original_node is self.old_node:
            return updated_node.with_changes(**self.changes)
        return updated_node


class _NOOPVisitor(CSTTransformer):
    pass


def _pretty_repr(value: object) -> str:
    """
    Return ``repr(value)``, except that non-string sequences are rendered one
    element per line via :func:`_pretty_repr_sequence`.
    """
    if not isinstance(value, str) and isinstance(value, Sequence):
        return _pretty_repr_sequence(value)
    else:
        return repr(value)


def _pretty_repr_sequence(seq: Sequence[object]) -> str:
    """
    Render a sequence as a bracketed, multi-line listing with one indented
    element (followed by a comma) per line. An empty sequence renders as ``[]``.
    """
    if len(seq) == 0:
        return "[]"
    else:
        return "\n".join(["[", *[f"{_indent(repr(el))}," for el in seq], "]"])


def _indent(value: str) -> str:
    """Prefix every line of ``value`` with the indentation string."""
    # NOTE(review): the indent literal is reproduced exactly as it appears in the
    # reviewed source, whose whitespace had been collapsed -- confirm the intended
    # indentation width against the canonical file.
    return "\n".join(f" {line}" for line in value.split("\n"))


def _clone(val: object) -> object:
    """
    Clone ``val`` via its ``deep_clone`` method when it has one, falling back to
    ``copy.deepcopy`` otherwise.
    """
    # We can't use isinstance(val, CSTNode) here due to poor performance
    # of isinstance checks against ABC direct subclasses. What we're trying
    # to do here is recursively call this functionality on subclasses, but
    # if the attribute isn't a CSTNode, fall back to copy.deepcopy.
    try:
        # pyre-ignore We know this might not exist, that's the point of the
        # attribute error and try block.
        return val.deep_clone()
    except AttributeError:
        return deepcopy(val)


@dataclass(frozen=True)
class CSTNode(ABC):
    """
    Abstract base class for all nodes in the tree.

    Nodes are frozen dataclasses: they are validated on construction (see
    :func:`_validate`), compared and hashed by identity, and "modified" by building
    new nodes through helpers such as :func:`with_changes` and :func:`visit`.
    """

    def __post_init__(self) -> None:
        # PERF: It might make more sense to move validation work into the visitor, which
        # would allow us to avoid validating the tree when parsing a file.
        self._validate()

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        HACK: Add our implementation of `__repr__`, `__hash__`, and `__eq__` to the
        class's __dict__ to prevent dataclass from generating its own `__repr__`,
        `__hash__`, and `__eq__`.

        The alternative is to require each implementation of a node to remember to add
        `repr=False, eq=False`, which is more error-prone.
        """
        super().__init_subclass__(**kwargs)

        if "__repr__" not in cls.__dict__:
            cls.__repr__ = CSTNode.__repr__
        if "__eq__" not in cls.__dict__:
            cls.__eq__ = CSTNode.__eq__
        if "__hash__" not in cls.__dict__:
            cls.__hash__ = CSTNode.__hash__

    def _validate(self) -> None:
        """
        Override this to perform runtime validation of a newly created node.

        The function is called during `__init__`. It should check for possible mistakes
        that wouldn't be caught by a static type checker.

        If you can't use a static type checker, and want to perform a runtime validation
        of this node's types, use `validate_types_shallow` or `validate_types_deep`
        instead.
        """
        pass

    def validate_types_shallow(self) -> None:
        """
        Compares the type annotations on a node's fields with those fields' actual
        values at runtime. Raises a TypeError if a mismatch is found.

        Only validates the current node, not any of its children. For a recursive
        version, see :func:`validate_types_deep`.

        If you're using a static type checker (highly recommended), this is useless.
        However, if your code doesn't use a static type checker, or if you're unable to
        statically type your code for some reason, you can use this method to help
        validate your tree.

        Some (non-typing) validation is done unconditionally during the construction of
        a node. That validation does not overlap with the work that
        :func:`validate_types_deep` does.
        """
        for f in fields(self):
            value = getattr(self, f.name)
            if not is_value_of_type(value, f.type):
                raise TypeError(
                    f"Expected an instance of {f.type!r} on "
                    + f"{type(self).__name__}'s '{f.name}' field, but instead got "
                    + f"an instance of {type(value)!r}"
                )

    def validate_types_deep(self) -> None:
        """
        Like :func:`validate_types_shallow`, but recursively validates the whole tree.
        """
        self.validate_types_shallow()
        for ch in self.children:
            ch.validate_types_deep()

    @property
    def children(self) -> Sequence["CSTNode"]:
        """
        The immediate (not transitive) child CSTNodes of the current node. Various
        properties on the nodes, such as string values, will not be visited if they are
        not a subclass of CSTNode.

        Iterable properties of the node (e.g. an IndentedBlock's body) will be flattened
        into the children's sequence.

        The children will always be returned in the same order that they appear
        lexically in the code.
        """

        # We're hooking into _visit_and_replace_children, which means that our current
        # implementation is slow. We may need to rethink and/or cache this if it becomes
        # a frequently accessed property.
        #
        # This probably won't be called frequently, because most child access will
        # probably through visit, or directly through named property access, not through
        # children.

        visitor = _ChildrenCollectionVisitor()
        self._visit_and_replace_children(visitor)
        return visitor.children

    def visit(
        self: _CSTNodeSelfT, visitor: CSTVisitorT
    ) -> Union[_CSTNodeSelfT, RemovalSentinel, FlattenSentinel[_CSTNodeSelfT]]:
        """
        Visits the current node, its children, and all transitive children using
        the given visitor's callbacks.

        Returns whatever the visitor's ``on_leave`` produces for this node. For a
        :class:`CSTVisitor` the node itself is always returned. Raises an
        ``Exception`` if ``on_leave`` returns something that is neither a CSTNode,
        a RemovalSentinel, nor a FlattenSentinel.
        """
        # visit self
        should_visit_children = visitor.on_visit(self)

        # TODO: provide traversal where children are not replaced
        # visit children (optionally)
        if should_visit_children:
            # It's not possible to define `_visit_and_replace_children` with the correct
            # return type in any sane way, so we're using this cast. See the
            # explanation above the declaration of `_visit_and_replace_children`.
            with_updated_children = cast(
                _CSTNodeSelfT, self._visit_and_replace_children(visitor)
            )
        else:
            with_updated_children = self

        if isinstance(visitor, CSTVisitor):
            # Read-only visitors can't replace nodes, so the result is always `self`.
            visitor.on_leave(self)
            leave_result = self
        else:
            leave_result = visitor.on_leave(self, with_updated_children)

        # validate return type of the user-defined `visitor.on_leave` method
        if not isinstance(leave_result, (CSTNode, RemovalSentinel, FlattenSentinel)):
            # The message lists every type accepted by the check above.
            raise Exception(
                "Expected a node of type CSTNode, a RemovalSentinel, or a "
                + "FlattenSentinel, "
                + f"but got a return value of {type(leave_result).__name__}"
            )

        # TODO: Run runtime typechecks against updated nodes

        return leave_result

    # The return type of `_visit_and_replace_children` is `CSTNode`, not
    # `_CSTNodeSelfT`. This is because pyre currently doesn't have a way to annotate
    # classes as final. https://mypy.readthedocs.io/en/latest/final_attrs.html
    #
    # The issue is that any reasonable implementation of `_visit_and_replace_children`
    # needs to refer to the class' own constructor:
    #
    #   class While(CSTNode):
    #       def _visit_and_replace_children(self, visitor: CSTVisitorT) -> While:
    #           return While(...)
    #
    # You'll notice that because this implementation needs to call the `While`
    # constructor, the return type is also `While`. This function is a valid subtype of
    # `Callable[[CSTVisitorT], CSTNode]`.
    #
    # It is not a valid subtype of `Callable[[CSTVisitorT], _CSTNodeSelfT]`. That's
    # because the return type of this function wouldn't be valid for any subclasses.
    # In practice, that's not an issue, because we don't have any subclasses of `While`,
    # but there's no way to tell pyre that without a `@final` annotation.
    #
    # Instead, we're just relying on an unchecked call to `cast()` in the `visit`
    # method.
    @abstractmethod
    def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "CSTNode":
        """
        Intended to be overridden by subclasses to provide a low-level hook for the
        visitor API.

        Don't call this directly. Instead, use `visitor.visit_and_replace_node` or
        `visitor.visit_and_replace_module`. If you need list of children, access the
        `children` property instead.

        The general expectation is that children should be visited in the order in which
        they appear lexically.
        """
        ...

    def _is_removable(self) -> bool:
        """
        Intended to be overridden by nodes that will be iterated over inside
        Module and IndentedBlock. Returning true signifies that this node is
        essentially useless and can be dropped when doing a visit across it.
        """
        return False

    @abstractmethod
    def _codegen_impl(self, state: CodegenState) -> None:
        ...

    def _codegen(self, state: CodegenState, **kwargs: Any) -> None:
        """
        Run this node's ``_codegen_impl`` bracketed by the state's
        ``before_codegen``/``after_codegen`` hooks. Extra keyword arguments are
        forwarded to ``_codegen_impl``.
        """
        state.before_codegen(self)
        self._codegen_impl(state, **kwargs)
        state.after_codegen(self)

    def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT:
        """
        A convenience method for performing mutation-like operations on immutable nodes.
        Creates a new object of the same type, replacing fields with values from the
        supplied keyword arguments.

        For example, to update the test of an if conditional, you could do::

            def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
                new_node = updated_node.with_changes(test=new_conditional)
                return new_node

        ``new_node`` will have the same ``body``, ``orelse``, and whitespace fields as
        ``updated_node``, but with the updated ``test`` field.

        The accepted arguments match the arguments given to ``__init__``, however there
        are no required or positional arguments.

        TODO: This API is untyped. There's probably no sane way to type it using pyre's
        current feature-set, but we should still think about ways to type this or a
        similar API in the future.
        """
        return replace(self, **changes)

    def deep_clone(self: _CSTNodeSelfT) -> _CSTNodeSelfT:
        """
        Recursively clone the entire tree. The created tree is a new tree that has the
        same representation but different identity.

        >>> tree = cst.parse_expression("1+2")

        >>> tree.deep_clone() == tree
        False

        >>> tree == tree
        True

        >>> tree.deep_equals(tree.deep_clone())
        True
        """
        cloned_fields: Dict[str, object] = {}
        # `f` rather than `field`, so the `dataclasses.field` import isn't shadowed.
        for f in fields(self):
            key = f.name
            if key[0] == "_":
                # Fields with a leading underscore are not passed to the constructor.
                continue
            val = getattr(self, key)

            # Much like the comment on _clone itself, we are allergic to instance
            # checks against Sequence because of speed issues with ABC classes. So,
            # instead, first handle sequence types that we do not want to iterate on
            # and then just try to iterate and clone.
            if isinstance(val, (str, bytes)):
                cloned_fields[key] = _clone(val)
            else:
                try:
                    cloned_fields[key] = tuple(_clone(v) for v in val)
                except TypeError:
                    # Not iterable: clone the value as a whole.
                    # NOTE(review): this handler also catches a TypeError raised
                    # while cloning an individual element, in which case the whole
                    # value is cloned via _clone instead -- confirm that's intended.
                    cloned_fields[key] = _clone(val)

        return type(self)(**cloned_fields)

    def deep_equals(self, other: "CSTNode") -> bool:
        """
        Recursively inspects the entire tree under ``self`` and ``other`` to determine if
        the two trees are equal by representation instead of identity (``==``).
        """
        # Imported inside the method rather than at module level.
        from libcst._nodes.deep_equals import deep_equals as deep_equals_impl

        return deep_equals_impl(self, other)

    def deep_replace(
        self: _CSTNodeSelfT, old_node: "CSTNode", new_node: CSTNodeT
    ) -> Union[_CSTNodeSelfT, CSTNodeT]:
        """
        Recursively replaces any instance of ``old_node`` with ``new_node`` by identity.
        Use this to avoid nested ``with_changes`` blocks when you are replacing one of
        a node's deep children with a new node. Note that if you have previously
        modified the tree in a way that ``old_node`` appears more than once as a deep
        child, all instances will be replaced.
        """
        new_tree = self.visit(_ChildReplacementTransformer(old_node, new_node))
        if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
            # The above transform never returns *Sentinel, so this isn't possible
            raise Exception("Logic error, cannot get a *Sentinel here!")
        return new_tree

    def deep_remove(
        self: _CSTNodeSelfT, old_node: "CSTNode"
    ) -> Union[_CSTNodeSelfT, RemovalSentinel]:
        """
        Recursively removes any instance of ``old_node`` by identity. Note that if you
        have previously modified the tree in a way that ``old_node`` appears more than
        once as a deep child, all instances will be removed.
        """
        new_tree = self.visit(
            _ChildReplacementTransformer(old_node, RemovalSentinel.REMOVE)
        )

        if isinstance(new_tree, FlattenSentinel):
            # The above transform never returns FlattenSentinel, so this isn't possible
            raise Exception("Logic error, cannot get a FlattenSentinel here!")

        return new_tree

    def with_deep_changes(
        self: _CSTNodeSelfT, old_node: "CSTNode", **changes: Any
    ) -> _CSTNodeSelfT:
        """
        A convenience method for applying :func:`with_changes` to a child node. Use
        this to avoid chains of :func:`with_changes` or combinations of
        :func:`deep_replace` and :func:`with_changes`.

        The accepted arguments match the arguments given to the child node's
        ``__init__``.

        TODO: This API is untyped. There's probably no sane way to type it using pyre's
        current feature-set, but we should still think about ways to type this or a
        similar API in the future.
        """
        new_tree = self.visit(_ChildWithChangesTransformer(old_node, changes))
        if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
            # This is impossible with the above transform.
            raise Exception("Logic error, cannot get a *Sentinel here!")
        return new_tree

    def __eq__(self: _CSTNodeSelfT, other: object) -> bool:
        """
        CSTNodes are only treated as equal by identity. This matches the behavior of
        CPython's AST nodes.

        If you actually want to compare the value instead of the identity of the current
        node with another, use `node.deep_equals`. Because `deep_equals` must traverse
        the entire tree, it can have an unexpectedly large time complexity.

        We're not exposing value equality as the default behavior because of
        `deep_equals`'s large time complexity.
        """
        return self is other

    def __hash__(self) -> int:
        # Equality of nodes is based on identity, so the hash should be too.
        return id(self)

    def __repr__(self) -> str:
        """
        Multi-line representation listing each field whose name doesn't start with an
        underscore as ``name=value,`` on its own indented line. A node type with no
        fields renders as ``TypeName()``.
        """
        if len(fields(self)) == 0:
            return f"{type(self).__name__}()"

        lines = [f"{type(self).__name__}("]
        for f in fields(self):
            key = f.name
            if key[0] != "_":
                value = getattr(self, key)
                lines.append(_indent(f"{key}={_pretty_repr(value)},"))
        lines.append(")")
        return "\n".join(lines)

    @classmethod
    # pyre-fixme[3]: Return annotation cannot be `Any`.
    def field(cls, *args: object, **kwargs: object) -> Any:
        """
        A helper that allows us to easily use CSTNodes in dataclass constructor
        defaults without accidentally aliasing nodes by identity across multiple
        instances.
        """
        # pyre-ignore Pyre is complaining about CSTNode not being instantiable,
        # but we're only going to call this from concrete subclasses.
        return field(default_factory=lambda: cls(*args, **kwargs))


class BaseLeaf(CSTNode, ABC):
    """
    Base class for nodes without child nodes: ``children`` is always empty and
    visiting children returns the node unchanged.
    """

    @property
    def children(self) -> Sequence[CSTNode]:
        # override this with an optimized implementation
        return _EMPTY_SEQUENCE

    def _visit_and_replace_children(
        self: _CSTNodeSelfT, visitor: CSTVisitorT
    ) -> _CSTNodeSelfT:
        return self


class BaseValueToken(BaseLeaf, ABC):
    """
    Represents the subset of nodes that only contain a value. Not all tokens from the
    tokenizer will exist as BaseValueTokens. In places where the token is always a
    constant value (e.g. a COLON token), the token's value will be implicitly folded
    into the parent CSTNode, and hard-coded into the implementation of _codegen.
    """

    value: str

    def _codegen_impl(self, state: CodegenState) -> None:
        state.add_token(self.value)