1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import dataclasses
7from contextlib import ExitStack
8from dataclasses import dataclass
9from typing import Any, Callable, Iterable, List, Optional, Sequence, Type
10from unittest.mock import patch
11
12import libcst as cst
13from libcst._nodes.internal import CodegenState, visit_required
14from libcst._types import CSTNodeT
15from libcst._visitors import CSTTransformer, CSTVisitorT
16from libcst.metadata import CodeRange, PositionProvider
17from libcst.metadata.position_provider import PositionProvidingCodegenState
18from libcst.testing.utils import UnitTest
19
20
@dataclasses.dataclass(frozen=True)
class _CSTCodegenPatchTarget:
    """
    Immutable record describing one `CSTNode` subclass whose `_codegen` method gets
    patched while a test tracks which nodes are rendered.
    """

    # The node class whose `_codegen` attribute is patched.
    type: Type[cst.CSTNode]
    # The key under which that class was found in `cst.__dict__`.
    name: str
    # The class's `_codegen` as it was before patching, so the override can delegate.
    old_codegen: Callable[..., None]
26
27
class _NOOPVisitor(CSTTransformer):
    """
    A transformer that overrides nothing, used to check that visiting a node with a
    do-nothing visitor yields an equal node.
    """
30
31
def _cst_node_equality_func(
    a: cst.CSTNode, b: cst.CSTNode, msg: Optional[str] = None
) -> None:
    """
    Type-specific equality check meant to be registered via addTypeEqualityFunc.

    Returns silently when `a.deep_equals(b)` is true; otherwise raises an
    AssertionError showing both reprs. If `msg` is given, it is appended to the
    failure text on its own line.
    """
    if a.deep_equals(b):
        return

    failure = f"\n{a!r}\nis not deeply equal to \n{b!r}"
    if msg is not None:
        failure = f"{failure}\n{msg}"
    raise AssertionError(failure)
41
42
def parse_expression_as(**config: Any) -> Callable[[str], cst.BaseExpression]:
    """
    Build an expression parser bound to the given parser configuration.

    The keyword arguments are forwarded to `cst.PartialParserConfig` each time the
    returned callable is invoked, and the resulting config is passed to
    `cst.parse_expression` along with the code string.
    """

    def _parse_expression(code: str) -> cst.BaseExpression:
        # Build the config per call, so any error it raises surfaces when parsing.
        parser_config = cst.PartialParserConfig(**config)
        return cst.parse_expression(code, config=parser_config)

    return _parse_expression
48
49
def parse_statement_as(**config: Any) -> Callable[[str], cst.BaseStatement]:
    """
    Build a statement parser bound to the given parser configuration.

    The keyword arguments are forwarded to `cst.PartialParserConfig` each time the
    returned callable is invoked, and the resulting config is passed to
    `cst.parse_statement` along with the code string.
    """

    def _parse_statement(code: str) -> cst.BaseStatement:
        # Build the config per call, so any error it raises surfaces when parsing.
        parser_config = cst.PartialParserConfig(**config)
        return cst.parse_statement(code, config=parser_config)

    return _parse_statement
55
56
57# We can't use an ABCMeta here, because of metaclass conflicts
class CSTNodeTest(UnitTest):
    """
    Shared base class for tests that exercise a single `CSTNode` type.

    `validate_node` is the main entry point. The `assert_invalid*` helpers check that
    constructing or type-validating a node fails, and `assert_parses` checks whether
    a parser accepts or rejects a piece of code.
    """

    def setUp(self) -> None:
        """
        Register value-based equality for every `CSTNode` type used in assertions.
        """
        # Run any setup defined further up the class hierarchy (or by sibling
        # mixins) before registering our equality functions.
        super().setUp()

        # Fix `self.assertEqual` for CSTNode subclasses. We should compare equality by
        # value instead of identity (what `CSTNode.__eq__` does) for tests.
        #
        # The time complexity of CSTNode.deep_equals doesn't matter much inside tests.
        for v in cst.__dict__.values():
            if isinstance(v, type) and issubclass(v, cst.CSTNode):
                self.addTypeEqualityFunc(v, _cst_node_equality_func)
        # DummyIndentedBlock is defined in this module, so the loop over
        # `cst.__dict__` does not pick it up; register it explicitly.
        self.addTypeEqualityFunc(DummyIndentedBlock, _cst_node_equality_func)

    def validate_node(
        self,
        node: CSTNodeT,
        code: str,
        parser: Optional[Callable[[str], CSTNodeT]] = None,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Run the standard set of checks against `node`.

        The node's types are validated deeply and its generated code is compared to
        `code`. If `expected_position` is given, the position computed for the node
        during codegen must match it. If `parser` is given, parsing `code` must yield
        a node equal to `node`. Finally, the node's children are cross-checked
        against codegen and against the node's fields, and a no-op visit must return
        an equal node.
        """
        node.validate_types_deep()
        self.__assert_codegen(node, code, expected_position)

        if parser is not None:
            parsed_node = parser(code)
            self.assertEqual(parsed_node, node)

        # Tests of children should unwrap DummyIndentedBlock first, because we don't
        # want to test DummyIndentedBlock's behavior.
        unwrapped_node = node
        while isinstance(unwrapped_node, DummyIndentedBlock):
            unwrapped_node = unwrapped_node.child
        self.__assert_children_match_codegen(unwrapped_node)
        self.__assert_children_match_fields(unwrapped_node)
        self.__assert_visit_returns_identity(unwrapped_node)

    def assert_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """
        Assert that calling `get_node` raises a `cst.CSTValidationError` whose message
        matches the regular expression `expected_re`.
        """
        with self.assertRaisesRegex(cst.CSTValidationError, expected_re):
            get_node()

    def assert_invalid_types(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """
        Assert that building a node with `get_node` and calling
        `validate_types_shallow` on it raises a `TypeError` whose message matches the
        regular expression `expected_re`.
        """
        with self.assertRaisesRegex(TypeError, expected_re):
            get_node().validate_types_shallow()

    def __assert_codegen(
        self,
        node: cst.CSTNode,
        expected: str,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Verifies that the given node's `_codegen` method is correct.

        The code generated for `node` must equal `expected`. When `expected_position`
        is provided, the position recorded for `node` during codegen must equal it.
        """
        module = cst.Module([])
        self.assertEqual(module.code_for_node(node), expected)

        if expected_position is not None:
            # This is using some internal APIs, because we only want to compute
            # position for the node being tested, not a whole module.
            #
            # Normally, this is a nonsense operation (how can a node have a position if
            # it's not in a module?), which is why it's not supported, but it makes
            # sense in the context of these node tests.
            provider = PositionProvider()
            state = PositionProvidingCodegenState(
                default_indent=module.default_indent,
                default_newline=module.default_newline,
                provider=provider,
            )
            node._codegen(state)
            self.assertEqual(provider._computed[node], expected_position)

    def __assert_children_match_codegen(self, node: cst.CSTNode) -> None:
        """
        Assert that `node.children` is the same sequence, in the same order, as the
        direct children whose `_codegen` ran while rendering `node`.
        """
        children = node.children
        codegen_children = self.__derive_children_from_codegen(node)
        self.assertSequenceEqual(
            children,
            codegen_children,
            msg=(
                "The list of children we got from `node.children` differs from the "
                + "children that were visited by `node._codegen`."
            ),
        )

    def __derive_children_from_codegen(
        self, node: cst.CSTNode
    ) -> Sequence[cst.CSTNode]:
        """
        Patches all subclasses of `CSTNode` exported by the `cst` module to track which
        `_codegen` methods get called, generating a list of children.

        Because all children must be rendered out into lexical order, this should be
        equivalent to `node.children`.

        `node.children` uses `_visit_and_replace_children` under the hood, not
        `_codegen`, so this helps us verify that both of those two method's behaviors
        are in sync.
        """

        patch_targets: Iterable[_CSTCodegenPatchTarget] = [
            _CSTCodegenPatchTarget(type=v, name=k, old_codegen=v._codegen)
            for (k, v) in cst.__dict__.items()
            if isinstance(v, type)
            and issubclass(v, cst.CSTNode)
            and hasattr(v, "_codegen")
        ]

        # Direct children of `node`, in the order their `_codegen` was entered.
        children: List[cst.CSTNode] = []
        # Nodes whose `_codegen` call is currently in progress, outermost first.
        codegen_stack: List[cst.CSTNode] = []

        def _get_codegen_override(
            target: _CSTCodegenPatchTarget,
        ) -> Callable[..., None]:
            def _codegen_impl(self: CSTNodeT, *args: Any, **kwargs: Any) -> None:
                # Don't stick duplicates in the stack. This is needed so that we don't
                # track calls to `super()._codegen()`.
                if codegen_stack and codegen_stack[-1] is self:
                    target.old_codegen(self, *args, **kwargs)
                    return

                # Check the stack to see that we're a direct child, not the root or
                # a transitive child.
                if len(codegen_stack) == 1:
                    children.append(self)
                codegen_stack.append(self)
                try:
                    target.old_codegen(self, *args, **kwargs)
                finally:
                    # Always undo our push, even if the wrapped codegen raises, so
                    # the stack depth stays consistent for enclosing calls.
                    codegen_stack.pop()

            return _codegen_impl

        with ExitStack() as patch_stack:
            for t in patch_targets:
                patch_stack.enter_context(
                    patch(f"libcst.{t.name}._codegen", _get_codegen_override(t))
                )
            # Execute `node._codegen()`
            cst.Module([]).code_for_node(node)

        return children

    def __assert_children_match_fields(self, node: cst.CSTNode) -> None:
        """
        We expect `node.children` to match everything we can extract from the node's
        fields, but maybe in a different order. This asserts that those things match.

        Children are compared by identity. Only fields that are a `CSTNode`, or an
        iterable whose elements are `CSTNode`s, contribute; ordering is verified
        separately against `_codegen` by `__assert_children_match_codegen`.
        """
        node_children_ids = {id(child) for child in node.children}
        fields = dataclasses.fields(node)
        field_child_ids = set()
        for f in fields:
            value = getattr(node, f.name)
            if isinstance(value, cst.CSTNode):
                field_child_ids.add(id(value))
            elif isinstance(value, Iterable):
                # Only one level of nesting is inspected; non-node elements (for
                # example the characters of a string field) are ignored.
                field_child_ids.update(
                    id(el) for el in value if isinstance(el, cst.CSTNode)
                )

        # order doesn't matter
        self.assertSetEqual(
            node_children_ids,
            field_child_ids,
            msg="`node.children` doesn't match what we found through introspection",
        )

    def __assert_visit_returns_identity(self, node: cst.CSTNode) -> None:
        """
        When visit is called with a visitor that acts as a no-op, the visit method
        should return the same node it started with.
        """
        # TODO: We're only checking equality right now, because visit currently clones
        # the node, since that was easier to implement. We should fix that behavior in a
        # later version and tighten this check.
        self.assertEqual(node, node.visit(_NOOPVisitor()))

    def assert_parses(
        self,
        code: str,
        parser: Callable[[str], cst.CSTNode],
        expect_success: bool,
    ) -> None:
        """
        Assert that `parser` accepts or rejects `code`.

        When `expect_success` is false, `parser(code)` must raise
        `cst.ParserSyntaxError`. Otherwise the parser is simply invoked, and any
        exception it raises propagates and fails the test. The parsed result itself
        is not inspected.
        """
        if not expect_success:
            with self.assertRaises(cst.ParserSyntaxError):
                parser(code)
        else:
            parser(code)
248
249
@dataclass(frozen=True)
class DummyIndentedBlock(cst.CSTNode):
    """
    A minimal stand-in for cst.IndentedBlock: it only pushes and pops the indentation
    state around its child, so that cst.IndentWhitespace can be tested in isolation.
    """

    # Indentation string pushed onto the codegen state while the child is rendered.
    value: str
    # The single wrapped node rendered inside the indented scope.
    child: cst.CSTNode

    def _codegen_impl(self, state: CodegenState) -> None:
        wrapped = self.child
        state.increase_indent(self.value)
        # The recorded position of this block spans exactly its child.
        with state.record_syntactic_position(
            self, start_node=wrapped, end_node=wrapped
        ):
            wrapped._codegen(state)
        state.decrease_indent()

    def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "DummyIndentedBlock":
        visited_child = visit_required(self, "child", self.child, visitor)
        return DummyIndentedBlock(value=self.value, child=visited_child)
272