1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from inspect import ismethod, signature
7from typing import (
8    Any,
9    Callable,
10    Dict,
11    List,
12    Optional,
13    Sequence,
14    Set,
15    Tuple,
16    Type,
17    Union,
18    cast,
19    get_type_hints,
20)
21
22import libcst as cst
23from libcst import CSTTransformer, CSTVisitor
24from libcst._types import CSTNodeT
25from libcst.matchers._decorators import (
26    CONSTRUCTED_LEAVE_MATCHER_ATTR,
27    CONSTRUCTED_VISIT_MATCHER_ATTR,
28    VISIT_NEGATIVE_MATCHER_ATTR,
29    VISIT_POSITIVE_MATCHER_ATTR,
30)
31from libcst.matchers._matcher_base import (
32    AllOf,
33    AtLeastN,
34    AtMostN,
35    BaseMatcherNode,
36    MatchIfTrue,
37    MatchMetadata,
38    MatchMetadataIfTrue,
39    OneOf,
40    extract,
41    extractall,
42    findall,
43    matches,
44    replace,
45)
46from libcst.matchers._return_types import TYPED_FUNCTION_RETURN_MAPPING
47
# Every ``visit_<Node>`` / ``leave_<Node>`` method name, one pair per node class
# listed in TYPED_FUNCTION_RETURN_MAPPING. Used to reject @visit/@leave decorators
# placed on these concrete methods.
CONCRETE_METHODS: Set[str] = {
    f"{prefix}_{node_cls.__name__}"
    for prefix in ("visit", "leave")
    for node_cls in TYPED_FUNCTION_RETURN_MAPPING
}
52
53
# pyre-ignore We don't care about Any here, it's not exposed.
def _match_decorator_unpickler(kwargs: Any) -> "MatchDecoratorMismatch":
    """Rebuild a :class:`MatchDecoratorMismatch` from the kwargs stored by pickle."""
    restored = MatchDecoratorMismatch(**kwargs)
    return restored


class MatchDecoratorMismatch(Exception):
    """
    Raised when a matcher decorator is applied to a function it cannot decorate
    (for example a wrong parameter count, incompatible annotations, or use on a
    concrete visit/leave method).

    ``func`` is the qualified name of the offending function and ``message``
    describes the problem. Both are kept as attributes and survive pickling.
    """

    def __init__(self, func: str, message: str) -> None:
        full_message = f"Invalid function signature for {func}: {message}"
        super().__init__(full_message)
        self.func = func
        self.message = message

    def __reduce__(
        self,
    ) -> Tuple[Callable[..., "MatchDecoratorMismatch"], Tuple[object, ...]]:
        # The default exception reduction would re-call the constructor with the
        # single formatted message (self.args), which does not fit this
        # two-argument __init__. Rebuild through a module-level helper using
        # explicit keyword arguments instead.
        state = {"func": self.func, "message": self.message}
        return (_match_decorator_unpickler, (state,))
72
73
def _get_possible_match_classes(matcher: BaseMatcherNode) -> List[Type[cst.CSTNode]]:
    """
    Map a decorator's matcher to the CST node classes it can match.

    A ``OneOf``/``AllOf`` contributes one class per option. Any other matcher
    contributes a single class. Each matcher is resolved to the ``libcst``
    attribute that shares its class name.
    """
    candidates = (
        matcher.options if isinstance(matcher, (OneOf, AllOf)) else (matcher,)
    )
    return [getattr(cst, candidate.__class__.__name__) for candidate in candidates]
79
80
def _annotation_looks_like_union(annotation: object) -> bool:
    """
    Return True if ``annotation`` is a union type.

    Recognises ``typing.Union[...]`` / ``Optional[...]`` (``__origin__ is Union``)
    and PEP 604 ``X | Y`` unions (instances of ``types.UnionType`` on Python
    3.10+).
    """
    import types

    if getattr(annotation, "__origin__", None) is Union:
        return True
    # Support PEP-604 style unions introduced in Python 3.10. Their class is
    # ``types.UnionType``, whose __name__ is "UnionType", so a name check against
    # "Union" alone never matches them. The attribute is absent before 3.10.
    union_type = getattr(types, "UnionType", None)
    if union_type is not None and isinstance(annotation, union_type):
        return True
    # Keep accepting the originally checked ``types.Union`` class name.
    return (
        annotation.__class__.__name__ == "Union"
        and annotation.__class__.__module__ == "types"
    )
89
90
def _get_possible_annotated_classes(annotation: object) -> List[Type[object]]:
    """
    Flatten a type annotation into the list of types it admits.

    If ``_annotation_looks_like_union`` accepts ``annotation``, its ``__args__``
    are returned. Otherwise the annotation itself is returned as a one-element
    list.
    """
    if _annotation_looks_like_union(annotation):
        # ``__args__`` is a tuple; materialise a list so the declared return type
        # holds for every caller.
        return list(getattr(annotation, "__args__", []))
    return [cast(Type[object], annotation)]
96
97
def _get_valid_leave_annotations_for_classes(
    classes: Sequence[Type[cst.CSTNode]],
) -> Set[Type[object]]:
    """
    Collect every return type allowed for a ``leave_<Node>`` of any of ``classes``.

    Each class's entry in TYPED_FUNCTION_RETURN_MAPPING is flattened (unions are
    split into their members), and all results are merged into one set.
    """
    # With OneOf/AllOf matchers the exact node type is unknown, so neither we nor
    # pyre can enforce a precise return type. The best we can do is accept the
    # union of every candidate's valid return annotations.
    return {
        allowed
        for node_cls in classes
        for allowed in _get_possible_annotated_classes(
            TYPED_FUNCTION_RETURN_MAPPING[node_cls]
        )
    }
113
114
def _verify_return_annotation(
    possible_match_classes: Sequence[Type[cst.CSTNode]],
    # pyre-ignore We only care that meth is callable.
    meth: Callable[..., Any],
    decorator_name: str,
    *,
    expected_none: bool,
) -> None:
    """
    Check the return annotation of a matcher-decorated method.

    With ``expected_none`` the annotation, if present, must be ``None``.
    Otherwise every type in the declared return annotation must be a subclass of
    one of the return types valid for a ``leave_<Node>`` of the classes in
    ``possible_match_classes``. A missing return annotation is always accepted.

    Raises:
        MatchDecoratorMismatch: if the annotation violates the rule above.
    """
    hints = get_type_hints(meth)
    none_type = type(None)

    if expected_none:
        # An absent annotation defaults to NoneType and therefore passes.
        if hints.get("return", none_type) is not none_type:
            raise MatchDecoratorMismatch(
                meth.__qualname__,
                f"@{decorator_name} should only decorate functions that do "
                + "not return.",
            )
        return

    if "return" not in hints:
        # Nothing to check without an annotation.
        return

    declared = _get_possible_annotated_classes(hints["return"])
    allowed = _get_valid_leave_annotations_for_classes(possible_match_classes)

    # Every declared return type must fit where the original node sat in the
    # tree, i.e. be a subclass of some valid leave_<Node> return annotation.
    for ret in declared:
        if not any(issubclass(ret, candidate) for candidate in allowed):
            raise MatchDecoratorMismatch(
                meth.__qualname__,
                f"@{decorator_name} decorated function cannot return "
                + f"the type {ret.__name__}.",
            )
163
164
def _verify_parameter_annotations(
    possible_match_classes: Sequence[Type[cst.CSTNode]],
    # pyre-ignore We only care that meth is callable.
    meth: Callable[..., Any],
    decorator_name: str,
    *,
    expected_param_count: int,
) -> None:
    """
    Check the parameters of a matcher-decorated method.

    ``meth`` must take exactly ``expected_param_count`` parameters, as reported by
    :func:`inspect.signature`. Every annotated parameter must accept each class in
    ``possible_match_classes``: either the class itself, a base class of it, or a
    union containing one of those. Unannotated parameters are not checked.

    Raises:
        MatchDecoratorMismatch: if the parameter count is wrong or an annotation
            does not cover every possible matched class.
    """
    # First, verify that the number of parameters is sane.
    meth_signature = signature(meth)
    if len(meth_signature.parameters) != expected_param_count:
        raise MatchDecoratorMismatch(
            meth.__qualname__,
            f"@{decorator_name} should decorate functions which take "
            + f"{expected_param_count} parameter"
            # English plural for every count except exactly one ("0 parameters").
            + ("" if expected_param_count == 1 else "s"),
        )

    # Then, for each parameter, make sure that the annotation includes each of
    # the classes that might appear given the matcher. In the simple case this is
    # just the correct cst node type. For complex matches that use OneOf/AllOf,
    # this could be a base class that encompasses all possible matches, or a
    # union.
    params = [v for k, v in get_type_hints(meth).items() if k != "return"]
    for param in params:
        # Go through each possible matched class, and make sure that some member
        # of the annotation is a superclass of it.
        possible_annotated_classes = _get_possible_annotated_classes(param)
        for match in possible_match_classes:
            for annotation in possible_annotated_classes:
                if issubclass(match, annotation):
                    # This annotation is a superclass of the possible match,
                    # so we know that the types are correct.
                    break
            else:
                # The current match was not a subclass of any of the annotated
                # types.
                raise MatchDecoratorMismatch(
                    meth.__qualname__,
                    f"@{decorator_name} can be called with {match.__name__} "
                    + "but the decorated function parameter annotations do "
                    + "not include this type.",
                )
208
209
def _check_types(
    # pyre-ignore We don't care about the type of sequence, just that its callable.
    decoratormap: Dict[BaseMatcherNode, Sequence[Callable[..., Any]]],
    decorator_name: str,
    *,
    expected_param_count: int,
    expected_none_return: bool,
) -> None:
    """
    Validate every function registered through a ``@visit`` or ``@leave`` decorator.

    For each matcher in ``decoratormap`` and each function registered for it,
    this checks three things:

    - the function is a bound method;
    - the root matcher is not an ``AtLeastN``/``AtMostN``/``MatchIfTrue``
      (directly or as a ``OneOf``/``AllOf`` option);
    - its return and parameter annotations fit the node classes the matcher can
      match.

    Raises:
        MatchDecoratorMismatch: for the first function that fails a check.
    """
    for matcher, methods in decoratormap.items():
        # Inspect the root matcher objects themselves, not the cst node classes
        # they resolve to. A resolved class can never be an instance of these
        # matcher types. _get_possible_match_classes looks up a cst attribute
        # named after each matcher's class, so an invalid root must be detected
        # before that lookup is attempted.
        root_options: Sequence[object] = (
            matcher.options if isinstance(matcher, (OneOf, AllOf)) else (matcher,)
        )
        has_invalid_top_level = any(
            isinstance(option, (AtLeastN, AtMostN, MatchIfTrue))
            for option in root_options
        )
        # Given a valid matcher, get the list of possible cst nodes that could be
        # passed to the functions we wrap. An invalid root is reported per
        # method below.
        possible_match_classes: List[Type[cst.CSTNode]] = (
            [] if has_invalid_top_level else _get_possible_match_classes(matcher)
        )

        # Now, loop through each function we wrap and verify that the type signature
        # is valid.
        for meth in methods:
            # First thing first, make sure this isn't wrapping an inner class.
            if not ismethod(meth):
                raise MatchDecoratorMismatch(
                    meth.__qualname__,
                    "Matcher decorators should only be used on methods of "
                    + "MatcherDecoratableTransformer or "
                    + "MatcherDecoratableVisitor",
                )
            if has_invalid_top_level:
                raise MatchDecoratorMismatch(
                    # pyre-ignore This anonymous method has a qualname.
                    meth.__qualname__,
                    "The root matcher in a matcher decorator cannot be an "
                    + "AtLeastN, AtMostN or MatchIfTrue matcher",
                )

            # Now, check that the return annotation is valid.
            _verify_return_annotation(
                possible_match_classes,
                meth,
                decorator_name,
                expected_none=expected_none_return,
            )

            # Finally, check that the parameter annotations are valid.
            _verify_parameter_annotations(
                possible_match_classes,
                meth,
                decorator_name,
                expected_param_count=expected_param_count,
            )
261
262
def _gather_matchers(obj: object) -> Set[BaseMatcherNode]:
    """
    Collect every gating matcher attached to the attributes of ``obj``.

    For each name in ``dir(obj)``, matchers stored under the positive attribute
    (the matcher must be active) and the negative attribute (the matcher must be
    inactive) are added to the result set.
    """
    visit_matchers: Set[BaseMatcherNode] = set()

    for name in dir(obj):
        try:
            # Read the attribute once: it may be a computed property, and
            # evaluating it repeatedly is wasteful and can have side effects.
            attr = getattr(obj, name)
            for marker in (VISIT_POSITIVE_MATCHER_ATTR, VISIT_NEGATIVE_MATCHER_ATTR):
                for matcher in getattr(attr, marker, []):
                    visit_matchers.add(cast(BaseMatcherNode, matcher))
        except Exception:
            # This could be a calculated property, and calling getattr() evaluates it.
            # We have no control over the implementation detail, so if it raises, we
            # should not crash.
            pass

    return visit_matchers
279
280
def _assert_not_concrete(
    decorator_name: str, func: Callable[[cst.CSTNode], None]
) -> None:
    """
    Reject a matcher decorator placed on a concrete ``visit_<Node>`` or
    ``leave_<Node>`` method (any name listed in CONCRETE_METHODS).

    Raises:
        MatchDecoratorMismatch: if ``func`` has one of those names.
    """
    if func.__name__ not in CONCRETE_METHODS:
        return
    raise MatchDecoratorMismatch(
        # pyre-ignore This anonymous method has a qualname.
        func.__qualname__,
        f"@{decorator_name} should not decorate functions that are concrete "
        + "visit or leave methods.",
    )
291
292
def _gather_constructed_visit_funcs(
    obj: object,
) -> Dict[BaseMatcherNode, Sequence[Callable[[cst.CSTNode], None]]]:
    """
    Map each matcher to the bound methods of ``obj`` decorated with
    ``@visit(matcher)``.

    Methods are grouped per matcher into tuples, in ``dir(obj)`` order.

    Raises:
        MatchDecoratorMismatch: if ``@visit`` decorates a concrete
            ``visit_<Node>`` method.
    """
    grouped: Dict[BaseMatcherNode, List[Callable[[cst.CSTNode], None]]] = {}

    for attr_name in dir(obj):
        try:
            candidate = getattr(obj, attr_name)
        except Exception:
            # Reading the attribute may evaluate a calculated property whose
            # implementation we don't control; if it raises, skip it rather
            # than crash.
            continue
        if not ismethod(candidate):
            continue
        method = cast(Callable[[cst.CSTNode], None], candidate)
        visit_matchers = getattr(method, CONSTRUCTED_VISIT_MATCHER_ATTR, [])
        if visit_matchers:
            # Refuse @visit on a concrete visit_<Node> method.
            _assert_not_concrete("visit", method)
        for matcher in visit_matchers:
            grouped.setdefault(cast(BaseMatcherNode, matcher), []).append(method)

    return {matcher: tuple(methods) for matcher, methods in grouped.items()}
323
324
# pyre-ignore: There is no reasonable way to type this, so ignore the Any type. This
# is because the leave_* methods have a different signature depending on whether they
# are in a MatcherDecoratableTransformer or a MatcherDecoratableVisitor.
def _gather_constructed_leave_funcs(
    obj: object,
) -> Dict[BaseMatcherNode, Sequence[Callable[..., Any]]]:
    """
    Map each matcher to the bound methods of ``obj`` decorated with
    ``@leave(matcher)``.

    Methods are grouped per matcher into tuples, in ``dir(obj)`` order.

    Raises:
        MatchDecoratorMismatch: if ``@leave`` decorates a concrete
            ``leave_<Node>`` method.
    """
    grouped: Dict[BaseMatcherNode, List[Callable[[cst.CSTNode], None]]] = {}

    for attr_name in dir(obj):
        try:
            candidate = getattr(obj, attr_name)
        except Exception:
            # Reading the attribute may evaluate a calculated property whose
            # implementation we don't control; if it raises, skip it rather
            # than crash.
            continue
        if not ismethod(candidate):
            continue
        method = cast(Callable[[cst.CSTNode], None], candidate)
        leave_matchers = getattr(method, CONSTRUCTED_LEAVE_MATCHER_ATTR, [])
        if leave_matchers:
            # Refuse @leave on a concrete leave_<Node> method.
            _assert_not_concrete("leave", method)
        for matcher in leave_matchers:
            grouped.setdefault(cast(BaseMatcherNode, matcher), []).append(method)

    return {matcher: tuple(methods) for matcher, methods in grouped.items()}
358
359
def _visit_matchers(
    matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]],
    node: cst.CSTNode,
    metadata_resolver: cst.MetadataDependent,
) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
    """
    Return the gating-matcher state after entering ``node``, as a new dict.

    An inactive matcher (mapped to None) that matches ``node`` becomes active and
    records ``node``, so it can be deactivated when that same node is left.
    Matchers that are already active keep their recorded node and are not
    re-evaluated.
    """
    return {
        matcher: (
            node
            if active_node is None
            and matches(node, matcher, metadata_resolver=metadata_resolver)
            else active_node
        )
        for matcher, active_node in matchers.items()
    }
377
378
def _leave_matchers(
    matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]], node: cst.CSTNode
) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
    """
    Return the gating-matcher state after leaving ``node``, as a new dict.

    Any matcher whose recorded activating node is ``node`` (by identity) is reset
    to None. Every other entry is carried over unchanged.
    """
    return {
        matcher: None if active_node is node else active_node
        for matcher, active_node in matchers.items()
    }
391
392
def _all_positive_matchers_true(
    all_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]], obj: object
) -> bool:
    """
    Return True if every positive gating matcher attached to ``obj`` is active.

    Active means mapped to a node in ``all_matchers``. An ``obj`` without
    positive matchers (including None) passes trivially.
    """
    required = getattr(obj, VISIT_POSITIVE_MATCHER_ATTR, [])
    return all(all_matchers[matcher] is not None for matcher in required)
403
404
def _all_negative_matchers_false(
    all_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]], obj: object
) -> bool:
    """
    Return True if every negative gating matcher attached to ``obj`` is inactive.

    Inactive means mapped to None in ``all_matchers``. An ``obj`` without
    negative matchers (including None) passes trivially.
    """
    forbidden = getattr(obj, VISIT_NEGATIVE_MATCHER_ATTR, [])
    return all(all_matchers[matcher] is None for matcher in forbidden)
414
415
def _should_allow_visit(
    all_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]], obj: object
) -> bool:
    """
    Decide whether ``obj`` (a visit/leave callable, or None) may be invoked.

    Every positive gating matcher on it must be active and every negative one
    inactive. The negative check only runs if the positive check passes.
    """
    if not _all_positive_matchers_true(all_matchers, obj):
        return False
    return _all_negative_matchers_false(all_matchers, obj)
422
423
def _visit_constructed_funcs(
    visit_funcs: Dict[BaseMatcherNode, Sequence[Callable[[cst.CSTNode], None]]],
    all_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]],
    node: cst.CSTNode,
    metadata_resolver: cst.MetadataDependent,
) -> None:
    """
    Call every ``@visit`` function whose matcher matches ``node``.

    Within each matching group, a function is skipped when its gating
    decorators do not allow it under the current ``all_matchers`` state.
    """
    # Use a distinct loop name so the ``visit_funcs`` parameter is not shadowed
    # while it is being iterated.
    for matcher, funcs in visit_funcs.items():
        if not matches(node, matcher, metadata_resolver=metadata_resolver):
            continue
        for visit_func in funcs:
            if _should_allow_visit(all_matchers, visit_func):
                visit_func(node)
435
436
class MatcherDecoratableTransformer(CSTTransformer):
    """
    This class provides all of the features of a :class:`libcst.CSTTransformer`, and
    additionally supports various decorators to control when methods get called when
    traversing a tree. Use this instead of a :class:`libcst.CSTTransformer` if you
    wish to do more powerful decorator-based visiting.
    """

    def __init__(self) -> None:
        CSTTransformer.__init__(self)
        # Gating matchers that we need to track and evaluate. We use these in
        # conjunction with the call_if_inside and call_if_not_inside decorators to
        # determine whether or not to call a visit/leave function. Each matcher maps
        # to the node that activated it, or None while it is inactive.
        self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {
            m: None for m in _gather_matchers(self)
        }
        # Mapping of matchers to functions. If in the course of visiting the tree,
        # a node matches one of these matchers, the corresponding function will be
        # called as if it was a visit_* method.
        self._extra_visit_funcs: Dict[
            BaseMatcherNode, Sequence[Callable[[cst.CSTNode], None]]
        ] = _gather_constructed_visit_funcs(self)
        # Mapping of matchers to functions. If in the course of leaving the tree,
        # a node matches one of these matchers, the corresponding function will be
        # called as if it was a leave_* method.
        self._extra_leave_funcs: Dict[
            BaseMatcherNode,
            Sequence[
                Callable[
                    [cst.CSTNode, cst.CSTNode], Union[cst.CSTNode, cst.RemovalSentinel]
                ]
            ],
        ] = _gather_constructed_leave_funcs(self)
        # Make sure visit/leave functions constructed with @visit and @leave decorators
        # have correct type annotations.
        _check_types(
            self._extra_visit_funcs,
            "visit",
            expected_param_count=1,
            expected_none_return=True,
        )
        # In a transformer, @leave functions take (original_node, updated_node) and
        # return the node to use in place of updated_node.
        _check_types(
            self._extra_leave_funcs,
            "leave",
            expected_param_count=2,
            expected_none_return=False,
        )

    def on_visit(self, node: cst.CSTNode) -> bool:
        """
        Activate gating matchers that match ``node`` and run matching ``@visit``
        functions. Then defer to :class:`libcst.CSTTransformer` unless the
        ``visit_<Node>`` method is gated off by its decorators, in which case
        children are still visited.
        """
        # First, evaluate any matchers that we have which we are not inside already.
        self._matchers = _visit_matchers(self._matchers, node, self)

        # Now, call any visitors that were hooked using a visit decorator.
        _visit_constructed_funcs(self._extra_visit_funcs, self._matchers, node, self)

        # Now, evaluate whether this current function has any matchers it requires.
        if not _should_allow_visit(
            self._matchers, getattr(self, f"visit_{type(node).__name__}", None)
        ):
            # We shouldn't visit this directly. However, we should continue
            # visiting its children.
            return True

        # Either the visit_func doesn't exist, we have no matchers, or we passed all
        # matchers. In any case, just call the superclass behavior.
        return CSTTransformer.on_visit(self, node)

    def on_leave(
        self, original_node: CSTNodeT, updated_node: CSTNodeT
    ) -> Union[CSTNodeT, cst.RemovalSentinel]:
        """
        Run the superclass leave behavior when the ``leave_<Node>`` method's gating
        decorators allow it. Then thread the result through each matching
        ``@leave`` function, and finally deactivate matchers activated by
        ``original_node``.
        """
        # First, evaluate whether this current function has a decorator on it.
        if _should_allow_visit(
            self._matchers, getattr(self, f"leave_{type(original_node).__name__}", None)
        ):
            retval = CSTTransformer.on_leave(self, original_node, updated_node)
        else:
            retval = updated_node

        # Now, call any visitors that were hooked using a leave decorator. They run
        # in the reverse of the order they were gathered, each receiving the
        # previous result. Once the result is no longer a CSTNode (e.g. a
        # RemovalSentinel), the remaining leave functions are skipped.
        for matcher, leave_funcs in reversed(list(self._extra_leave_funcs.items())):
            if not self.matches(original_node, matcher):
                continue
            for leave_func in leave_funcs:
                if _should_allow_visit(self._matchers, leave_func) and isinstance(
                    retval, cst.CSTNode
                ):
                    retval = leave_func(original_node, retval)

        # Now, see if we have any matchers we should deactivate.
        self._matchers = _leave_matchers(self._matchers, original_node)

        # pyre-ignore The return value of on_leave is subtly wrong in that we can
        # actually return any value that passes this node's parent's constructor
        # validation. Fixing this is beyond the scope of this file, and would involve
        # forcing a lot of ensure_type() checks across the codebase.
        return retval

    def on_visit_attribute(self, node: cst.CSTNode, attribute: str) -> None:
        # Evaluate whether this current function has a decorator on it.
        if _should_allow_visit(
            self._matchers,
            getattr(self, f"visit_{type(node).__name__}_{attribute}", None),
        ):
            # Either the visit_func doesn't exist, we have no matchers, or we passed all
            # matchers. In any case, just call the superclass behavior. This hook is
            # declared to return None, matching on_leave_attribute below.
            CSTTransformer.on_visit_attribute(self, node, attribute)

    def on_leave_attribute(self, original_node: cst.CSTNode, attribute: str) -> None:
        # Evaluate whether this current function has a decorator on it.
        if _should_allow_visit(
            self._matchers,
            getattr(self, f"leave_{type(original_node).__name__}_{attribute}", None),
        ):
            # Either the leave_func doesn't exist, we have no matchers, or we passed
            # all matchers. In any case, just call the superclass behavior.
            CSTTransformer.on_leave_attribute(self, original_node, attribute)

    def matches(
        self,
        node: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
        matcher: BaseMatcherNode,
    ) -> bool:
        """
        A convenience method to call :func:`~libcst.matchers.matches` without requiring
        an explicit parameter for metadata. Since our instance is an instance of
        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
        documentation for :func:`~libcst.matchers.matches` as it is identical to this
        function.
        """
        return matches(node, matcher, metadata_resolver=self)

    def findall(
        self,
        tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
        matcher: Union[
            BaseMatcherNode,
            MatchIfTrue[cst.CSTNode],
            MatchMetadata,
            MatchMetadataIfTrue,
        ],
    ) -> Sequence[cst.CSTNode]:
        """
        A convenience method to call :func:`~libcst.matchers.findall` without requiring
        an explicit parameter for metadata. Since our instance is an instance of
        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
        documentation for :func:`~libcst.matchers.findall` as it is identical to this
        function.
        """
        return findall(tree, matcher, metadata_resolver=self)

    def extract(
        self,
        node: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
        matcher: BaseMatcherNode,
    ) -> Optional[Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]]]:
        """
        A convenience method to call :func:`~libcst.matchers.extract` without requiring
        an explicit parameter for metadata. Since our instance is an instance of
        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
        documentation for :func:`~libcst.matchers.extract` as it is identical to this
        function.
        """
        return extract(node, matcher, metadata_resolver=self)

    def extractall(
        self,
        tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
        matcher: Union[
            BaseMatcherNode,
            MatchIfTrue[cst.CSTNode],
            MatchMetadata,
            MatchMetadataIfTrue,
        ],
    ) -> Sequence[Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]]]:
        """
        A convenience method to call :func:`~libcst.matchers.extractall` without requiring
        an explicit parameter for metadata. Since our instance is an instance of
        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
        documentation for :func:`~libcst.matchers.extractall` as it is identical to this
        function.
        """
        return extractall(tree, matcher, metadata_resolver=self)

    def replace(
        self,
        tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
        matcher: Union[
            BaseMatcherNode,
            MatchIfTrue[cst.CSTNode],
            MatchMetadata,
            MatchMetadataIfTrue,
        ],
        replacement: Union[
            cst.MaybeSentinel,
            cst.RemovalSentinel,
            cst.CSTNode,
            Callable[
                [cst.CSTNode, Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]]],
                Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
            ],
        ],
    ) -> Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode]:
        """
        A convenience method to call :func:`~libcst.matchers.replace` without requiring
        an explicit parameter for metadata. Since our instance is an instance of
        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
        documentation for :func:`~libcst.matchers.replace` as it is identical to this
        function.
        """
        return replace(tree, matcher, replacement, metadata_resolver=self)
647
648
649class MatcherDecoratableVisitor(CSTVisitor):
650    """
651    This class provides all of the features of a :class:`libcst.CSTVisitor`, and
652    additionally supports various decorators to control when methods get called
653    when traversing a tree. Use this instead of a :class:`libcst.CSTVisitor` if
654    you wish to do more powerful decorator-based visiting.
655    """
656
657    def __init__(self) -> None:
658        CSTVisitor.__init__(self)
659        # List of gating matchers that we need to track and evaluate. We use these
660        # in conjuction with the call_if_inside and call_if_not_inside decorators
661        # to determine whether or not to call a visit/leave function.
662        self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {
663            m: None for m in _gather_matchers(self)
664        }
665        # Mapping of matchers to functions. If in the course of visiting the tree,
666        # a node matches one of these matchers, the corresponding function will be
667        # called as if it was a visit_* method.
668        self._extra_visit_funcs: Dict[
669            BaseMatcherNode, Sequence[Callable[[cst.CSTNode], None]]
670        ] = _gather_constructed_visit_funcs(self)
671        # Mapping of matchers to functions. If in the course of leaving the tree,
672        # a node matches one of these matchers, the corresponding function will be
673        # called as if it was a leave_* method.
674        self._extra_leave_funcs: Dict[
675            BaseMatcherNode, Sequence[Callable[[cst.CSTNode], None]]
676        ] = _gather_constructed_leave_funcs(self)
677        # Make sure visit/leave functions constructed with @visit and @leave decorators
678        # have correct type annotations.
679        _check_types(
680            self._extra_visit_funcs,
681            "visit",
682            expected_param_count=1,
683            expected_none_return=True,
684        )
685        _check_types(
686            self._extra_leave_funcs,
687            "leave",
688            expected_param_count=1,
689            expected_none_return=True,
690        )
691
692    def on_visit(self, node: cst.CSTNode) -> bool:
693        # First, evaluate any matchers that we have which we are not inside already.
694        self._matchers = _visit_matchers(self._matchers, node, self)
695
696        # Now, call any visitors that were hooked using a visit decorator.
697        _visit_constructed_funcs(self._extra_visit_funcs, self._matchers, node, self)
698
699        # Now, evaluate whether this current function has a decorator on it.
700        if not _should_allow_visit(
701            self._matchers, getattr(self, f"visit_{type(node).__name__}", None)
702        ):
703            # We shouldn't visit this directly. However, we should continue
704            # visiting its children.
705            return True
706
707        # Either the visit_func doesn't exist, we have no matchers, or we passed all
708        # matchers. In either case, just call the superclass behavior.
709        return CSTVisitor.on_visit(self, node)
710
711    def on_leave(self, original_node: cst.CSTNode) -> None:
712        # First, evaluate whether this current function has a decorator on it.
713        if _should_allow_visit(
714            self._matchers, getattr(self, f"leave_{type(original_node).__name__}", None)
715        ):
716            CSTVisitor.on_leave(self, original_node)
717
718        # Now, call any visitors that were hooked using a leave decorator.
719        for matcher, leave_funcs in reversed(list(self._extra_leave_funcs.items())):
720            if not self.matches(original_node, matcher):
721                continue
722            for leave_func in leave_funcs:
723                if _should_allow_visit(self._matchers, leave_func):
724                    leave_func(original_node)
725
726        # Now, see if we have any matchers we should deactivate.
727        self._matchers = _leave_matchers(self._matchers, original_node)
728
729    def on_visit_attribute(self, node: cst.CSTNode, attribute: str) -> None:
730        # Evaluate whether this current function has a decorator on it.
731        if _should_allow_visit(
732            self._matchers,
733            getattr(self, f"visit_{type(node).__name__}_{attribute}", None),
734        ):
735            # Either the visit_func doesn't exist, we have no matchers, or we passed all
736            # matchers. In either case, just call the superclass behavior.
737            return CSTVisitor.on_visit_attribute(self, node, attribute)
738
739    def on_leave_attribute(self, original_node: cst.CSTNode, attribute: str) -> None:
740        # Evaluate whether this current function has a decorator on it.
741        if _should_allow_visit(
742            self._matchers,
743            getattr(self, f"leave_{type(original_node).__name__}_{attribute}", None),
744        ):
745            # Either the visit_func doesn't exist, we have no matchers, or we passed all
746            # matchers. In either case, just call the superclass behavior.
747            CSTVisitor.on_leave_attribute(self, original_node, attribute)
748
749    def matches(
750        self,
751        node: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
752        matcher: BaseMatcherNode,
753    ) -> bool:
754        """
755        A convenience method to call :func:`~libcst.matchers.matches` without requiring
756        an explicit parameter for metadata. Since our instance is an instance of
757        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
758        documentation for :func:`~libcst.matchers.matches` as it is identical to this
759        function.
760        """
761        return matches(node, matcher, metadata_resolver=self)
762
763    def findall(
764        self,
765        tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
766        matcher: Union[
767            BaseMatcherNode,
768            MatchIfTrue[cst.CSTNode],
769            MatchMetadata,
770            MatchMetadataIfTrue,
771        ],
772    ) -> Sequence[cst.CSTNode]:
773        """
774        A convenience method to call :func:`~libcst.matchers.findall` without requiring
775        an explicit parameter for metadata. Since our instance is an instance of
776        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
777        documentation for :func:`~libcst.matchers.findall` as it is identical to this
778        function.
779        """
780        return findall(tree, matcher, metadata_resolver=self)
781
782    def extract(
783        self,
784        node: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
785        matcher: BaseMatcherNode,
786    ) -> Optional[Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]]]:
787        """
788        A convenience method to call :func:`~libcst.matchers.extract` without requiring
789        an explicit parameter for metadata. Since our instance is an instance of
790        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
791        documentation for :func:`~libcst.matchers.extract` as it is identical to this
792        function.
793        """
794        return extract(node, matcher, metadata_resolver=self)
795
796    def extractall(
797        self,
798        tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
799        matcher: Union[
800            BaseMatcherNode,
801            MatchIfTrue[cst.CSTNode],
802            MatchMetadata,
803            MatchMetadataIfTrue,
804        ],
805    ) -> Sequence[Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]]]:
806        """
807        A convenience method to call :func:`~libcst.matchers.extractall` without requiring
808        an explicit parameter for metadata. Since our instance is an instance of
809        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
810        documentation for :func:`~libcst.matchers.extractall` as it is identical to this
811        function.
812        """
813        return extractall(tree, matcher, metadata_resolver=self)
814
815    def replace(
816        self,
817        tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
818        matcher: Union[
819            BaseMatcherNode,
820            MatchIfTrue[cst.CSTNode],
821            MatchMetadata,
822            MatchMetadataIfTrue,
823        ],
824        replacement: Union[
825            cst.MaybeSentinel,
826            cst.RemovalSentinel,
827            cst.CSTNode,
828            Callable[
829                [cst.CSTNode, Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]]],
830                Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode],
831            ],
832        ],
833    ) -> Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode]:
834        """
835        A convenience method to call :func:`~libcst.matchers.replace` without requiring
836        an explicit parameter for metadata. Since our instance is an instance of
837        :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see
838        documentation for :func:`~libcst.matchers.replace` as it is identical to this
839        function.
840        """
841        return replace(tree, matcher, replacement, metadata_resolver=self)
842