1from copy import copy
2from enum import Enum
3from typing import (
4    Any,
5    Callable,
6    Collection,
7    Dict,
8    List,
9    NamedTuple,
10    Optional,
11    Tuple,
12    Union,
13)
14
15from ..pyutils import inspect, snake_to_camel
16from . import ast
17
18from .ast import Node
19
# Public API of this module: the names exported by a star-import.
__all__ = [
    "Visitor",
    "ParallelVisitor",
    "VisitorAction",
    "visit",
    "BREAK",
    "SKIP",
    "REMOVE",
    "IDLE",
    "QUERY_DOCUMENT_KEYS",
]
31
32
class VisitorActionEnum(Enum):
    """Enumeration of the special values a visitor method may return.

    The plain value of each member may be returned in place of the member itself.
    """

    # stop the traversal
    BREAK = True
    # do not descend into the node
    SKIP = False
    # delete the node
    REMOVE = ...
42
43
# A visitor method returns either one of the enum members or None (meaning IDLE).
VisitorAction = Optional[VisitorActionEnum]

# GraphQL.js defines these special values differently, namely as
# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined

BREAK, SKIP, REMOVE = (
    VisitorActionEnum.BREAK,
    VisitorActionEnum.SKIP,
    VisitorActionEnum.REMOVE,
)
IDLE = None
53
# Default map from node kinds to the names of their traversable node attributes.
# The visit() function falls back to this map when no visitor_keys are passed, and
# looks at the attributes of a node in the order in which they are listed here.
QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
    # names and documents
    "name": (),
    "document": ("definitions",),
    # operations, selections and fragments
    "operation_definition": (
        "name",
        "variable_definitions",
        "directives",
        "selection_set",
    ),
    "variable_definition": ("variable", "type", "default_value", "directives"),
    "variable": ("name",),
    "selection_set": ("selections",),
    "field": ("alias", "name", "arguments", "directives", "selection_set"),
    "argument": ("name", "value"),
    "fragment_spread": ("name", "directives"),
    "inline_fragment": ("type_condition", "directives", "selection_set"),
    "fragment_definition": (
        # Note: fragment variable definitions are experimental and may be changed or
        # removed in the future.
        "name",
        "variable_definitions",
        "type_condition",
        "directives",
        "selection_set",
    ),
    # values (the scalar-like kinds have no attributes to traverse)
    "int_value": (),
    "float_value": (),
    "string_value": (),
    "boolean_value": (),
    "enum_value": (),
    "list_value": ("values",),
    "object_value": ("fields",),
    "object_field": ("name", "value"),
    # directives and type references
    "directive": ("name", "arguments"),
    "named_type": ("name",),
    "list_type": ("type",),
    "non_null_type": ("type",),
    # type system definitions
    "schema_definition": ("description", "directives", "operation_types"),
    "operation_type_definition": ("type",),
    "scalar_type_definition": ("description", "name", "directives"),
    "object_type_definition": (
        "description",
        "name",
        "interfaces",
        "directives",
        "fields",
    ),
    "field_definition": ("description", "name", "arguments", "type", "directives"),
    "input_value_definition": (
        "description",
        "name",
        "type",
        "default_value",
        "directives",
    ),
    "interface_type_definition": (
        "description",
        "name",
        "interfaces",
        "directives",
        "fields",
    ),
    "union_type_definition": ("description", "name", "directives", "types"),
    "enum_type_definition": ("description", "name", "directives", "values"),
    "enum_value_definition": ("description", "name", "directives"),
    "input_object_type_definition": ("description", "name", "directives", "fields"),
    "directive_definition": ("description", "name", "arguments", "locations"),
    # type system extensions
    "schema_extension": ("directives", "operation_types"),
    "scalar_type_extension": ("name", "directives"),
    "object_type_extension": ("name", "interfaces", "directives", "fields"),
    "interface_type_extension": ("name", "interfaces", "directives", "fields"),
    "union_type_extension": ("name", "directives", "types"),
    "enum_type_extension": ("name", "directives", "values"),
    "input_object_type_extension": ("name", "directives", "fields"),
}
130
131
class Visitor:
    """Visitor that walks through an AST.

    Visitors can define two generic methods "enter" and "leave". The former will be
    called when a node is entered in the traversal, the latter is called after visiting
    the node and its child nodes. These methods have the following signature::

        def enter(self, node, key, parent, path, ancestors):
            # The return value has the following meaning:
            # IDLE (None): no action
            # SKIP: skip visiting this node
            # BREAK: stop visiting altogether
            # REMOVE: delete this node
            # any other value: replace this node with the returned value
            return

        def leave(self, node, key, parent, path, ancestors):
            # The return value has the following meaning:
            # IDLE (None) or SKIP: no action
            # BREAK: stop visiting altogether
            # REMOVE: delete this node
            # any other value: replace this node with the returned value
            return

    The parameters have the following meaning:

    :arg node: The current node being visited.
    :arg key: The index or key to this node from the parent node or list.
    :arg parent: The parent immediately above this node, which may be a list.
    :arg path: The key path to get to this node from the root node.
    :arg ancestors: All nodes and lists visited before reaching the parent
        of this node. These correspond to list indices in ``path``.
        Note: ancestors includes lists which contain the parent of the visited node.

    You can also define node kind specific methods by suffixing them with an underscore
    followed by the kind of the node to be visited. For instance, to visit ``field``
    nodes, you would define the methods ``enter_field()`` and/or ``leave_field()``,
    with the same signature as above. If no kind specific method has been defined
    for a given node, the generic method is called.
    """

    # Provide special return values as attributes
    BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE

    def __init_subclass__(cls) -> None:
        """Verify that all defined handlers are valid.

        Every public attribute of the subclass named ``enter_<kind>`` or
        ``leave_<kind>`` must refer to an existing AST node class.

        :raises TypeError: if ``<kind>`` does not correspond to a subclass of
            ``Node`` in the ``ast`` module.
        """
        super().__init_subclass__()
        for attr in cls.__dict__:
            if attr.startswith("_"):
                continue
            # Split "enter_some_kind" into the method "enter" and the kind "some_kind";
            # the kind is empty for the generic "enter" and "leave" methods.
            method, _, kind = attr.partition("_")
            if method not in ("enter", "leave") or not kind:
                continue
            node_cls = getattr(ast, snake_to_camel(kind) + "Node", None)
            if not (isinstance(node_cls, type) and issubclass(node_cls, Node)):
                raise TypeError(f"Invalid AST node kind: {kind}.")

    def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Optional[Callable]:
        """Get the visit function for the given node kind and direction.

        The kind specific method (such as ``enter_field``) takes precedence over the
        generic ``enter`` or ``leave`` method.

        :arg kind: the kind of the node, as used in the method name suffix.
        :arg is_leaving: whether to look up the leave method instead of the enter one.
        :returns: the matching method, or None if the visitor defines neither a
            kind specific nor a generic method for this direction.
        """
        method = "leave" if is_leaving else "enter"
        visit_fn = getattr(self, f"{method}_{kind}", None)
        if not visit_fn:
            visit_fn = getattr(self, method, None)
        return visit_fn
205
206
class Stack(NamedTuple):
    """A stack for the visit function.

    Each instance holds the traversal state of one level that the visit function
    saves before descending into a node; the levels are linked through ``prev``.
    """

    # whether the saved level iterates over a list instead of node attributes
    in_array: bool
    # the position within ``keys`` that had been reached at the saved level
    idx: int
    # the keys iterated at the saved level
    # NOTE(review): visit() stores attribute names or a list of nodes here, so the
    # annotation looks narrower than the actual usage -- confirm before tightening.
    keys: Tuple[Node, ...]
    # the edits collected so far at the saved level, as (key, value) pairs
    edits: List[Tuple[Union[int, str], Node]]
    # the previously saved level, or None for the outermost one
    prev: Any  # 'Stack' (python/mypy/issues/731)
215
216
def visit(
    root: Node,
    visitor: Visitor,
    visitor_keys: Optional[Dict[str, Tuple[str, ...]]] = None,
) -> Any:
    """Visit each node in an AST.

    :func:`~.visit` will walk through an AST using a depth-first traversal, calling the
    visitor's enter methods at each node in the traversal, and calling the leave methods
    after visiting that node and all of its child nodes.

    By returning different values from the enter and leave methods, the behavior of the
    visitor can be altered: returning :data:`~.SKIP` (or False) from an enter method
    skips over the sub-tree of that node, returning :data:`~.REMOVE` removes the node,
    returning any other value that is not None replaces the node with that value, and
    returning :data:`~.BREAK` (or True) stops the whole traversal. Returning None
    (:data:`~.IDLE`) leaves the node unchanged.

    When using :func:`~.visit` to edit an AST, the original AST will not be modified,
    and a new version of the AST with the changes applied will be returned from the
    visit function. A removed node is dropped from the list that contains it, or, when
    it is held directly in an attribute of its parent, that attribute is set to None.

    To customize the node attributes to be used for traversal, you can provide a
    dictionary visitor_keys mapping node kinds to node attributes.

    :arg root: the AST node at which the traversal starts.
    :arg visitor: the visitor whose enter and leave methods are called.
    :arg visitor_keys: map from node kinds to the attributes to be traversed;
        defaults to ``QUERY_DOCUMENT_KEYS``.
    :returns: ``root`` itself if no edit is pending when the traversal ends,
        otherwise the value of the last recorded edit (after a complete traversal
        this is the edited copy of the root, or its replacement).
    :raises TypeError: if ``root`` is not a ``Node``, if ``visitor`` is not a
        ``Visitor``, or if something that is neither a ``Node`` nor a list is
        encountered during the traversal.
    """
    if not isinstance(root, Node):
        raise TypeError(f"Not an AST Node: {inspect(root)}.")
    if not isinstance(visitor, Visitor):
        raise TypeError(f"Not an AST Visitor: {inspect(visitor)}.")
    if visitor_keys is None:
        visitor_keys = QUERY_DOCUMENT_KEYS
    # State of the level currently being iterated; outer levels are saved on ``stack``.
    stack: Any = None
    in_array = isinstance(root, list)
    keys: Tuple[Node, ...] = (root,)
    idx = -1
    edits: List[Any] = []
    parent: Any = None
    path: List[Any] = []
    path_append = path.append
    path_pop = path.pop
    ancestors: List[Any] = []
    ancestors_append = ancestors.append
    ancestors_pop = ancestors.pop
    new_root = root

    while True:
        idx += 1
        # All keys of the current level have been handled, so its owner is left now.
        is_leaving = idx == len(keys)
        is_edited = is_leaving and edits
        if is_leaving:
            key = path[-1] if ancestors else None
            node: Any = parent
            parent = ancestors_pop() if ancestors else None
            if is_edited:
                # Apply the edits to a shallow copy so the original AST stays untouched.
                if in_array:
                    node = node[:]
                else:
                    node = copy(node)
            edit_offset = 0
            for edit_key, edit_value in edits:
                is_removal = edit_value is REMOVE or edit_value is Ellipsis
                if in_array:
                    # Earlier removals have shifted the items that follow them.
                    edit_key -= edit_offset
                if in_array and is_removal:
                    node.pop(edit_key)
                    edit_offset += 1
                elif isinstance(node, list):
                    node[edit_key] = edit_value
                else:
                    # A removed child held in an attribute is cleared to None instead
                    # of storing the REMOVE marker itself in the resulting AST.
                    setattr(node, edit_key, None if is_removal else edit_value)

            # Resume the iteration of the enclosing level.
            idx = stack.idx
            keys = stack.keys
            edits = stack.edits
            in_array = stack.in_array
            stack = stack.prev
        else:
            if parent:
                if in_array:
                    key = idx
                    node = parent[key]
                else:
                    key = keys[idx]
                    node = getattr(parent, key, None)
            else:
                key = None
                node = new_root
            if node is None:
                continue
            if parent:
                path_append(key)

        if isinstance(node, list):
            # Lists are only traversed; visitor methods are called for nodes alone.
            result = None
        else:
            if not isinstance(node, Node):
                raise TypeError(f"Invalid AST Node: {inspect(node)}.")
            visit_fn = visitor.get_visit_fn(node.kind, is_leaving)
            if visit_fn:
                result = visit_fn(node, key, parent, path, ancestors)

                if result is BREAK or result is True:
                    break

                if result is SKIP or result is False:
                    if not is_leaving:
                        # Nothing was pushed onto the path for the root node.
                        if path:
                            path_pop()
                        # Without a saved level this was the root node, so the
                        # traversal is finished; ``continue`` would bypass the
                        # termination check at the end of the loop.
                        if not stack:
                            break
                        continue

                elif result is not None:
                    edits.append((key, result))
                    if not is_leaving:
                        if isinstance(result, Node):
                            # Continue the traversal inside the replacement node.
                            node = result
                        else:
                            # Same root node handling as for skipped nodes above.
                            if path:
                                path_pop()
                            if not stack:
                                break
                            continue
            else:
                result = None

        if result is None and is_edited:
            # Hand the edited copy over to the edits of the enclosing level.
            edits.append((key, node))

        if is_leaving:
            if path:
                path_pop()
        else:
            # Save the current level and descend into the node.
            stack = Stack(in_array, idx, keys, edits, stack)
            in_array = isinstance(node, list)
            keys = node if in_array else visitor_keys.get(node.kind, ())
            idx = -1
            edits = []
            if parent:
                ancestors_append(parent)
            parent = node

        if not stack:
            break

    if edits:
        new_root = edits[-1][1]

    return new_root
358
359
class ParallelVisitor(Visitor):
    """A Visitor that delegates to several visitors which run in parallel.

    For each node, all of the given visitors are called before moving on.

    As soon as one visitor edits a node, the visitors after it do not see that node.
    """

    def __init__(self, visitors: Collection[Visitor]):
        """Create a new visitor from the given list of parallel visitors."""
        self.visitors = visitors
        # One slot per visitor: None while the visitor is active, the node whose
        # sub-tree it is skipping, or BREAK once it has stopped for good.
        self.skipping: List[Any] = [None] * len(visitors)

    def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
        skipping = self.skipping
        for index, delegate in enumerate(self.visitors):
            if skipping[index]:
                continue
            enter_fn = delegate.get_visit_fn(node.kind)
            if not enter_fn:
                continue
            outcome = enter_fn(node, *args)
            if outcome is SKIP or outcome is False:
                # Silence this visitor until the node is left again.
                skipping[index] = node
            elif outcome is BREAK or outcome is True:
                skipping[index] = BREAK
            elif outcome is not None:
                return outcome
        return None

    def leave(self, node: Node, *args: Any) -> Optional[VisitorAction]:
        skipping = self.skipping
        for index, delegate in enumerate(self.visitors):
            marker = skipping[index]
            if marker:
                # The visitor becomes active again when leaving the node it skipped.
                if marker is node:
                    skipping[index] = None
                continue
            leave_fn = delegate.get_visit_fn(node.kind, is_leaving=True)
            if not leave_fn:
                continue
            outcome = leave_fn(node, *args)
            if outcome is BREAK or outcome is True:
                skipping[index] = BREAK
            elif not (outcome is None or outcome is SKIP or outcome is False):
                return outcome
        return None
406