1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from abc import ABC, abstractmethod
7from copy import deepcopy
8from dataclasses import dataclass, field, fields, replace
9from typing import Any, Dict, List, Mapping, Sequence, TypeVar, Union, cast
10
11from libcst._flatten_sentinel import FlattenSentinel
12from libcst._nodes.internal import CodegenState
13from libcst._removal_sentinel import RemovalSentinel
14from libcst._type_enforce import is_value_of_type
15from libcst._types import CSTNodeT
16from libcst._visitors import CSTTransformer, CSTVisitor, CSTVisitorT
17
18_CSTNodeSelfT = TypeVar("_CSTNodeSelfT", bound="CSTNode")
19_EMPTY_SEQUENCE: Sequence["CSTNode"] = ()
20
21
class CSTValidationError(SyntaxError):
    """
    SyntaxError subclass for reporting that a CST node failed validation
    (presumably raised from node ``_validate`` hooks defined elsewhere).
    """
24
25
class CSTCodegenError(SyntaxError):
    """
    SyntaxError subclass for reporting failures while turning a CST back into
    source code (presumably raised during node ``_codegen``).
    """
28
29
class _ChildrenCollectionVisitor(CSTVisitor):
    """
    Visitor that records only the immediate children of the node whose
    ``_visit_and_replace_children`` it is passed to. Used by ``CSTNode.children``.
    """

    def __init__(self) -> None:
        # Nodes are appended in the order the parent visits them.
        self.children: List[CSTNode] = []

    def on_visit(self, node: "CSTNode") -> bool:
        self.children.append(node)
        # Refusing to descend keeps grandchildren (and deeper) out of the list.
        descend_further = False
        return descend_further
37
38
class _ChildReplacementTransformer(CSTTransformer):
    """
    Transformer that swaps one specific node, matched by identity, for either a
    replacement node or a RemovalSentinel. Backs ``CSTNode.deep_replace`` and
    ``CSTNode.deep_remove``.
    """

    def __init__(
        self, old_node: "CSTNode", new_node: Union["CSTNode", RemovalSentinel]
    ) -> None:
        self.new_node = new_node
        self.old_node = old_node

    def on_visit(self, node: "CSTNode") -> bool:
        # The target node is discarded wholesale, so walking into it would only
        # waste time.
        is_target = node is self.old_node
        return not is_target

    def on_leave(
        self, original_node: "CSTNode", updated_node: "CSTNode"
    ) -> Union["CSTNode", RemovalSentinel]:
        # Match on the original node's identity; updated_node may be a rebuilt copy.
        return self.new_node if original_node is self.old_node else updated_node
57
58
class _ChildWithChangesTransformer(CSTTransformer):
    """
    Transformer that applies ``with_changes(**changes)`` to one specific node,
    matched by identity. Backs ``CSTNode.with_deep_changes``.
    """

    def __init__(self, old_node: "CSTNode", changes: Mapping[str, Any]) -> None:
        self.changes = changes
        self.old_node = old_node

    def on_visit(self, node: "CSTNode") -> bool:
        # The target is rebuilt via with_changes, so there is no point in
        # descending into it first.
        is_target = node is self.old_node
        return not is_target

    def on_leave(self, original_node: "CSTNode", updated_node: "CSTNode") -> "CSTNode":
        if original_node is not self.old_node:
            return updated_node
        return updated_node.with_changes(**self.changes)
73
74
class _NOOPVisitor(CSTTransformer):
    """
    Transformer that overrides nothing and therefore inherits every default
    CSTTransformer callback (presumably no-ops, so visiting with it should leave
    the tree unchanged).
    """
77
78
79def _pretty_repr(value: object) -> str:
80    if not isinstance(value, str) and isinstance(value, Sequence):
81        return _pretty_repr_sequence(value)
82    else:
83        return repr(value)
84
85
86def _pretty_repr_sequence(seq: Sequence[object]) -> str:
87    if len(seq) == 0:
88        return "[]"
89    else:
90        return "\n".join(["[", *[f"{_indent(repr(el))}," for el in seq], "]"])
91
92
93def _indent(value: str) -> str:
94    return "\n".join(f"    {line}" for line in value.split("\n"))
95
96
97def _clone(val: object) -> object:
98    # We can't use isinstance(val, CSTNode) here due to poor performance
99    # of isinstance checks against ABC direct subclasses. What we're trying
100    # to do here is recursively call this functionality on subclasses, but
101    # if the attribute isn't a CSTNode, fall back to copy.deepcopy.
102    try:
103        # pyre-ignore We know this might not exist, that's the point of the
104        # attribute error and try block.
105        return val.deep_clone()
106    except AttributeError:
107        return deepcopy(val)
108
109
@dataclass(frozen=True)
class CSTNode(ABC):
    """
    Abstract base class for every node in the concrete syntax tree.

    Nodes are frozen dataclasses compared and hashed by identity (see ``__eq__`` and
    ``__hash__``). "Modifying" a node means building a new one, e.g. through
    :func:`with_changes`, :func:`deep_replace`, or :func:`with_deep_changes`.
    """

    def __post_init__(self) -> None:
        # PERF: It might make more sense to move validation work into the visitor, which
        # would allow us to avoid validating the tree when parsing a file.
        self._validate()

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        HACK: Add our implementation of `__repr__`, `__hash__`, and `__eq__` to the
        class's __dict__ to prevent dataclass from generating its own `__repr__`,
        `__hash__`, and `__eq__`.

        The alternative is to require each implementation of a node to remember to add
        `repr=False, eq=False`, which is more error-prone.
        """
        super().__init_subclass__(**kwargs)

        # Only fill in methods the subclass did not define itself, so explicit
        # overrides in a subclass still take precedence.
        if "__repr__" not in cls.__dict__:
            cls.__repr__ = CSTNode.__repr__
        if "__eq__" not in cls.__dict__:
            cls.__eq__ = CSTNode.__eq__
        if "__hash__" not in cls.__dict__:
            cls.__hash__ = CSTNode.__hash__

    def _validate(self) -> None:
        """
        Override this to perform runtime validation of a newly created node.

        The function is called from `__post_init__`, i.e. at the end of dataclass
        construction. It should check for possible mistakes that wouldn't be caught by
        a static type checker.

        If you can't use a static type checker, and want to perform a runtime validation
        of this node's types, use `validate_types_shallow` or `validate_types_deep`
        instead.
        """
        pass

    def validate_types_shallow(self) -> None:
        """
        Compares the type annotations on a node's fields with those fields' actual
        values at runtime. Raises a TypeError if a mismatch is found.

        Only validates the current node, not any of its children. For a recursive
        version, see :func:`validate_types_deep`.

        If you're using a static type checker (highly recommended), this is useless.
        However, if your code doesn't use a static type checker, or if you're unable to
        statically type your code for some reason, you can use this method to help
        validate your tree.

        Some (non-typing) validation is done unconditionally during the construction of
        a node. That validation does not overlap with the work that
        :func:`validate_types_deep` does.
        """
        for f in fields(self):
            value = getattr(self, f.name)
            if not is_value_of_type(value, f.type):
                raise TypeError(
                    f"Expected an instance of {f.type!r} on "
                    + f"{type(self).__name__}'s '{f.name}' field, but instead got "
                    + f"an instance of {type(value)!r}"
                )

    def validate_types_deep(self) -> None:
        """
        Like :func:`validate_types_shallow`, but recursively validates the whole tree.
        """
        self.validate_types_shallow()
        for ch in self.children:
            ch.validate_types_deep()

    @property
    def children(self) -> Sequence["CSTNode"]:
        """
        The immediate (not transitive) child CSTNodes of the current node. Various
        properties on the nodes, such as string values, will not be visited if they are
        not a subclass of CSTNode.

        Iterable properties of the node (e.g. an IndentedBlock's body) will be flattened
        into the children's sequence.

        The children will always be returned in the same order that they appear
        lexically in the code.
        """

        # We're hooking into _visit_and_replace_children, which means that our current
        # implementation is slow. We may need to rethink and/or cache this if it becomes
        # a frequently accessed property.
        #
        # This probably won't be called frequently, because most child access will
        # probably be through visit, or directly through named property access, not
        # through children.

        visitor = _ChildrenCollectionVisitor()
        self._visit_and_replace_children(visitor)
        return visitor.children

    def visit(
        self: _CSTNodeSelfT, visitor: CSTVisitorT
    ) -> Union[_CSTNodeSelfT, RemovalSentinel, FlattenSentinel[_CSTNodeSelfT]]:
        """
        Visits the current node, its children, and all transitive children using
        the given visitor's callbacks.

        A read-only ``CSTVisitor`` always gets ``self`` back. A transformer gets
        whatever its ``on_leave`` returned, which must be a CSTNode, RemovalSentinel or
        FlattenSentinel; anything else raises ``Exception``.
        """
        # visit self
        should_visit_children = visitor.on_visit(self)

        # TODO: provide traversal where children are not replaced
        # visit children (optionally)
        if should_visit_children:
            # It's not possible to define `_visit_and_replace_children` with the correct
            # return type in any sane way, so we're using this cast. See the
            # explanation above the declaration of `_visit_and_replace_children`.
            with_updated_children = cast(
                _CSTNodeSelfT, self._visit_and_replace_children(visitor)
            )
        else:
            with_updated_children = self

        if isinstance(visitor, CSTVisitor):
            visitor.on_leave(self)
            leave_result = self
        else:
            leave_result = visitor.on_leave(self, with_updated_children)

        # validate return type of the user-defined `visitor.on_leave` method
        if not isinstance(leave_result, (CSTNode, RemovalSentinel, FlattenSentinel)):
            raise Exception(
                "Expected a node of type CSTNode or a RemovalSentinel, "
                + f"but got a return value of {type(leave_result).__name__}"
            )

        # TODO: Run runtime typechecks against updated nodes

        return leave_result

    # The return type of `_visit_and_replace_children` is `CSTNode`, not
    # `_CSTNodeSelfT`. This is because pyre currently doesn't have a way to annotate
    # classes as final. https://mypy.readthedocs.io/en/latest/final_attrs.html
    #
    # The issue is that any reasonable implementation of `_visit_and_replace_children`
    # needs to refer to the class' own constructor:
    #
    #   class While(CSTNode):
    #       def _visit_and_replace_children(self, visitor: CSTVisitorT) -> While:
    #           return While(...)
    #
    # You'll notice that because this implementation needs to call the `While`
    # constructor, the return type is also `While`. This function is a valid subtype of
    # `Callable[[CSTVisitorT], CSTNode]`.
    #
    # It is not a valid subtype of `Callable[[CSTVisitorT], _CSTNodeSelfT]`. That's
    # because the return type of this function wouldn't be valid for any subclasses.
    # In practice, that's not an issue, because we don't have any subclasses of `While`,
    # but there's no way to tell pyre that without a `@final` annotation.
    #
    # Instead, we're just relying on an unchecked call to `cast()` in the `visit`
    # method.
    @abstractmethod
    def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "CSTNode":
        """
        Intended to be overridden by subclasses to provide a low-level hook for the
        visitor API.

        Don't call this directly. Instead, use `visitor.visit_and_replace_node` or
        `visitor.visit_and_replace_module`. If you need list of children, access the
        `children` property instead.

        The general expectation is that children should be visited in the order in which
        they appear lexically.
        """
        ...

    def _is_removable(self) -> bool:
        """
        Intended to be overridden by nodes that will be iterated over inside
        Module and IndentedBlock. Returning true signifies that this node is
        essentially useless and can be dropped when doing a visit across it.
        """
        return False

    @abstractmethod
    def _codegen_impl(self, state: CodegenState) -> None:
        ...

    def _codegen(self, state: CodegenState, **kwargs: Any) -> None:
        """
        Emit code for this node via `_codegen_impl`, bracketed by the state's
        `before_codegen` / `after_codegen` hooks. Extra keyword arguments are
        forwarded to `_codegen_impl` unchanged.
        """
        state.before_codegen(self)
        self._codegen_impl(state, **kwargs)
        state.after_codegen(self)

    def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT:
        """
        A convenience method for performing mutation-like operations on immutable nodes.
        Creates a new object of the same type, replacing fields with values from the
        supplied keyword arguments.

        For example, to update the test of an if conditional, you could do::

            def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
                new_node = updated_node.with_changes(test=new_conditional)
                return new_node

        ``new_node`` will have the same ``body``, ``orelse``, and whitespace fields as
        ``updated_node``, but with the updated ``test`` field.

        The accepted arguments match the arguments given to ``__init__``, however there
        are no required or positional arguments.

        TODO: This API is untyped. There's probably no sane way to type it using pyre's
        current feature-set, but we should still think about ways to type this or a
        similar API in the future.
        """
        return replace(self, **changes)

    def deep_clone(self: _CSTNodeSelfT) -> _CSTNodeSelfT:
        """
        Recursively clone the entire tree. The created tree is a new tree that has the
        same representation but a different identity.

        Sequence-valued fields are rebuilt as tuples. Fields whose names start with an
        underscore are not copied; the constructor's defaults are used for them.

        >>> tree = cst.parse_expression("1+2")

        >>> tree.deep_clone() == tree
        False

        >>> tree == tree
        True

        >>> tree.deep_equals(tree.deep_clone())
        True
        """
        cloned_fields: Dict[str, object] = {}
        # The loop variable is `f`, not `field`, so it doesn't shadow the module-level
        # `dataclasses.field` import.
        for f in fields(self):
            key = f.name
            if key[0] == "_":
                continue
            val = getattr(self, key)

            # Much like the comment on _clone itself, we are allergic to instance
            # checks against Sequence because of speed issues with ABC classes. So,
            # instead, first handle sequence types that we do not want to iterate on
            # and then just try to iterate and clone.
            if isinstance(val, (str, bytes)):
                cloned_fields[key] = _clone(val)
                continue
            try:
                items = iter(val)
            except TypeError:
                # Not iterable (a single node, None, a number, ...): clone as one value.
                cloned_fields[key] = _clone(val)
            else:
                # Only the iter() probe is guarded. A TypeError raised while cloning an
                # element propagates instead of being masked by the fallback above.
                cloned_fields[key] = tuple(_clone(v) for v in items)

        return type(self)(**cloned_fields)

    def deep_equals(self, other: "CSTNode") -> bool:
        """
        Recursively inspects the entire tree under ``self`` and ``other`` to determine if
        the two trees are equal by representation instead of identity (``==``).
        """
        from libcst._nodes.deep_equals import deep_equals as deep_equals_impl

        return deep_equals_impl(self, other)

    def deep_replace(
        self: _CSTNodeSelfT, old_node: "CSTNode", new_node: CSTNodeT
    ) -> Union[_CSTNodeSelfT, CSTNodeT]:
        """
        Recursively replaces any instance of ``old_node`` with ``new_node`` by identity.
        Use this to avoid nested ``with_changes`` blocks when you are replacing one of
        a node's deep children with a new node. Note that if you have previously
        modified the tree in a way that ``old_node`` appears more than once as a deep
        child, all instances will be replaced.

        If ``old_node`` is ``self``, ``new_node`` itself is returned.
        """
        new_tree = self.visit(_ChildReplacementTransformer(old_node, new_node))
        if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
            # The above transform never returns *Sentinel, so this isn't possible
            raise Exception("Logic error, cannot get a *Sentinel here!")
        return new_tree

    def deep_remove(
        self: _CSTNodeSelfT, old_node: "CSTNode"
    ) -> Union[_CSTNodeSelfT, RemovalSentinel]:
        """
        Recursively removes any instance of ``old_node`` by identity. Note that if you
        have previously modified the tree in a way that ``old_node`` appears more than
        once as a deep child, all instances will be removed.

        If ``old_node`` is ``self``, ``RemovalSentinel.REMOVE`` is returned.
        """
        new_tree = self.visit(
            _ChildReplacementTransformer(old_node, RemovalSentinel.REMOVE)
        )

        if isinstance(new_tree, FlattenSentinel):
            # The above transform never returns FlattenSentinel, so this isn't possible
            raise Exception("Logic error, cannot get a FlattenSentinel here!")

        return new_tree

    def with_deep_changes(
        self: _CSTNodeSelfT, old_node: "CSTNode", **changes: Any
    ) -> _CSTNodeSelfT:
        """
        A convenience method for applying :attr:`with_changes` to a child node. Use
        this to avoid chains of :attr:`with_changes` or combinations of
        :attr:`deep_replace` and :attr:`with_changes`.

        The accepted arguments match the arguments given to the child node's
        ``__init__``.

        TODO: This API is untyped. There's probably no sane way to type it using pyre's
        current feature-set, but we should still think about ways to type this or a
        similar API in the future.
        """
        new_tree = self.visit(_ChildWithChangesTransformer(old_node, changes))
        if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
            # This is impossible with the above transform.
            raise Exception("Logic error, cannot get a *Sentinel here!")
        return new_tree

    def __eq__(self: _CSTNodeSelfT, other: _CSTNodeSelfT) -> bool:
        """
        CSTNodes are only treated as equal by identity. This matches the behavior of
        CPython's AST nodes.

        If you actually want to compare the value instead of the identity of the current
        node with another, use `node.deep_equals`. Because `deep_equals` must traverse
        the entire tree, it can have an unexpectedly large time complexity.

        We're not exposing value equality as the default behavior because of
        `deep_equals`'s large time complexity.
        """
        return self is other

    def __hash__(self) -> int:
        # Equality of nodes is based on identity, so the hash should be too.
        return id(self)

    def __repr__(self) -> str:
        # Compute the field list once; it is needed for both the empty check and the
        # loop.
        node_fields = fields(self)
        if len(node_fields) == 0:
            return f"{type(self).__name__}()"

        lines = [f"{type(self).__name__}("]
        for f in node_fields:
            key = f.name
            # Underscore-prefixed fields are internal and left out of the repr.
            if key[0] != "_":
                value = getattr(self, key)
                lines.append(_indent(f"{key}={_pretty_repr(value)},"))
        lines.append(")")
        return "\n".join(lines)

    @classmethod
    # pyre-fixme[3]: Return annotation cannot be `Any`.
    def field(cls, *args: object, **kwargs: object) -> Any:
        """
        A helper that allows us to easily use CSTNodes in dataclass constructor
        defaults without accidentally aliasing nodes by identity across multiple
        instances.
        """
        # pyre-ignore Pyre is complaining about CSTNode not being instantiable,
        # but we're only going to call this from concrete subclasses.
        return field(default_factory=lambda: cls(*args, **kwargs))
468
469
class BaseLeaf(CSTNode, ABC):
    """
    Base class for nodes that hold no child CSTNodes.

    Visiting a leaf's children is a no-op, and its ``children`` sequence is always
    empty.
    """

    def _visit_and_replace_children(
        self: _CSTNodeSelfT, visitor: CSTVisitorT
    ) -> _CSTNodeSelfT:
        # Nothing to traverse, so the very same node (same identity) is handed back.
        return self

    @property
    def children(self) -> Sequence[CSTNode]:
        # Shortcut that skips the visitor-based collection CSTNode.children performs.
        return _EMPTY_SEQUENCE
480
481
class BaseValueToken(BaseLeaf, ABC):
    """
    Represents the subset of nodes whose only content is a string value. Not every
    tokenizer token becomes a BaseValueToken: when a token is always the same
    constant (e.g. a COLON token), its text is folded implicitly into the parent
    CSTNode and hard-coded into that parent's _codegen implementation.
    """

    value: str

    def _codegen_impl(self, state: CodegenState) -> None:
        # Code generation for a value token is just writing its text verbatim.
        token_text = self.value
        state.add_token(token_text)
494