from copy import copy
from enum import Enum
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

from ..pyutils import inspect, snake_to_camel
from . import ast

from .ast import Node

__all__ = [
    "Visitor",
    "ParallelVisitor",
    "VisitorAction",
    "visit",
    "BREAK",
    "SKIP",
    "REMOVE",
    "IDLE",
    "QUERY_DOCUMENT_KEYS",
]


class VisitorActionEnum(Enum):
    """Special return values for the visitor methods.

    You can also use the values of this enum directly.
    """

    BREAK = True
    SKIP = False
    REMOVE = Ellipsis


VisitorAction = Optional[VisitorActionEnum]

# Note that in GraphQL.js these are defined differently:
# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined

BREAK = VisitorActionEnum.BREAK
SKIP = VisitorActionEnum.SKIP
REMOVE = VisitorActionEnum.REMOVE
IDLE = None

# Default map from visitor kinds to their traversable node attributes:
QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
    "name": (),
    "document": ("definitions",),
    "operation_definition": (
        "name",
        "variable_definitions",
        "directives",
        "selection_set",
    ),
    "variable_definition": ("variable", "type", "default_value", "directives"),
    "variable": ("name",),
    "selection_set": ("selections",),
    "field": ("alias", "name", "arguments", "directives", "selection_set"),
    "argument": ("name", "value"),
    "fragment_spread": ("name", "directives"),
    "inline_fragment": ("type_condition", "directives", "selection_set"),
    "fragment_definition": (
        # Note: fragment variable definitions are experimental and may be changed or
        # removed in the future.
        "name",
        "variable_definitions",
        "type_condition",
        "directives",
        "selection_set",
    ),
    "int_value": (),
    "float_value": (),
    "string_value": (),
    "boolean_value": (),
    "enum_value": (),
    "list_value": ("values",),
    "object_value": ("fields",),
    "object_field": ("name", "value"),
    "directive": ("name", "arguments"),
    "named_type": ("name",),
    "list_type": ("type",),
    "non_null_type": ("type",),
    "schema_definition": ("description", "directives", "operation_types"),
    "operation_type_definition": ("type",),
    "scalar_type_definition": ("description", "name", "directives"),
    "object_type_definition": (
        "description",
        "name",
        "interfaces",
        "directives",
        "fields",
    ),
    "field_definition": ("description", "name", "arguments", "type", "directives"),
    "input_value_definition": (
        "description",
        "name",
        "type",
        "default_value",
        "directives",
    ),
    "interface_type_definition": (
        "description",
        "name",
        "interfaces",
        "directives",
        "fields",
    ),
    "union_type_definition": ("description", "name", "directives", "types"),
    "enum_type_definition": ("description", "name", "directives", "values"),
    "enum_value_definition": ("description", "name", "directives"),
    "input_object_type_definition": ("description", "name", "directives", "fields"),
    "directive_definition": ("description", "name", "arguments", "locations"),
    "schema_extension": ("directives", "operation_types"),
    "scalar_type_extension": ("name", "directives"),
    "object_type_extension": ("name", "interfaces", "directives", "fields"),
    "interface_type_extension": ("name", "interfaces", "directives", "fields"),
    "union_type_extension": ("name", "directives", "types"),
    "enum_type_extension": ("name", "directives", "values"),
    "input_object_type_extension": ("name", "directives", "fields"),
}


class Visitor:
    """Visitor that walks through an AST.

    Visitors can define two generic methods "enter" and "leave". The former will be
    called when a node is entered in the traversal, the latter is called after visiting
    the node and its child nodes. These methods have the following signature::

        def enter(self, node, key, parent, path, ancestors):
            # The return value has the following meaning:
            # IDLE (None): no action
            # SKIP: skip visiting this node
            # BREAK: stop visiting altogether
            # REMOVE: delete this node
            # any other value: replace this node with the returned value
            return

        def leave(self, node, key, parent, path, ancestors):
            # The return value has the following meaning:
            # IDLE (None) or SKIP: no action
            # BREAK: stop visiting altogether
            # REMOVE: delete this node
            # any other value: replace this node with the returned value
            return

    Removing an item of a list drops it from the list; removing a single child node
    sets the corresponding attribute of the (copied) parent node to None.

    The parameters have the following meaning:

    :arg node: The current node being visited.
    :arg key: The index or key to this node from the parent node or list.
    :arg parent: the parent immediately above this node, which may be a list.
    :arg path: The key path to get to this node from the root node.
    :arg ancestors: All nodes and lists visited before reaching parent
        of this node. These correspond to list indices in ``path``.
        Note: ancestors includes lists which contain the parent of visited node.

    You can also define node kind specific methods by suffixing them with an underscore
    followed by the kind of the node to be visited. For instance, to visit ``field``
    nodes, you would define the methods ``enter_field()`` and/or ``leave_field()``,
    with the same signature as above. If no kind specific method has been defined
    for a given node, the generic method is called.
    """

    # Provide special return values as attributes
    BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE

    def __init_subclass__(cls) -> None:
        """Verify that all defined handlers are valid.

        Every public attribute named ``enter_<kind>`` or ``leave_<kind>`` must name
        an existing AST node class ``<Kind>Node`` of the ``ast`` module.

        :raises TypeError: if a handler refers to an unknown node kind.
        """
        super().__init_subclass__()
        for attr in cls.__dict__:
            if attr.startswith("_"):
                continue
            method, _sep, kind = attr.partition("_")
            # plain "enter"/"leave" (no kind suffix) need no further check
            if method in ("enter", "leave") and kind:
                name = snake_to_camel(kind) + "Node"
                node_cls = getattr(ast, name, None)
                if (
                    not node_cls
                    or not isinstance(node_cls, type)
                    or not issubclass(node_cls, Node)
                ):
                    raise TypeError(f"Invalid AST node kind: {kind}.")

    def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Optional[Callable]:
        """Get the visit function for the given node kind and direction.

        Returns the kind specific method (``enter_<kind>`` or ``leave_<kind>``) if
        defined, otherwise the generic ``enter`` or ``leave`` method, or None if the
        visitor defines neither.
        """
        method = "leave" if is_leaving else "enter"
        visit_fn = getattr(self, f"{method}_{kind}", None)
        if not visit_fn:
            visit_fn = getattr(self, method, None)
        return visit_fn


class Stack(NamedTuple):
    """A stack for the visit function."""

    in_array: bool  # whether the saved level iterates over a list
    idx: int  # position within ``keys`` at the saved level
    keys: Union[Tuple[Any, ...], List[Any]]  # attribute names or the list itself
    edits: List[Tuple[Union[int, str], Node]]  # (key, new value) pairs
    prev: Any  # 'Stack' (python/mypy/issues/731)


def _apply_edits(node: Any, edits: List[Any], in_array: bool) -> Any:
    """Return a copy of ``node`` (a node or a list of nodes) with ``edits`` applied.

    The original ``node`` is never modified. An edit whose value is REMOVE (or
    Ellipsis) deletes a list item, or clears a node attribute to None.
    """
    if in_array:
        node = node[:]
        edit_offset = 0
        for edit_key, edit_value in edits:
            # earlier removals shift the indices of all following items
            edit_key -= edit_offset
            if edit_value is REMOVE or edit_value is Ellipsis:
                node.pop(edit_key)
                edit_offset += 1
            else:
                node[edit_key] = edit_value
    else:
        node = copy(node)
        for edit_key, edit_value in edits:
            if edit_value is REMOVE or edit_value is Ellipsis:
                # a removed single child leaves the attribute empty
                # (GraphQL.js sets it to null), never the enum marker itself
                edit_value = None
            setattr(node, edit_key, edit_value)
    return node


def visit(
    root: Node,
    visitor: Visitor,
    visitor_keys: Optional[Dict[str, Tuple[str, ...]]] = None,
) -> Any:
    """Visit each node in an AST.

    :func:`~.visit` will walk through an AST using a depth-first traversal, calling the
    visitor's enter methods at each node in the traversal, and calling the leave methods
    after visiting that node and all of its child nodes.

    By returning different values from the enter and leave methods, the behavior of the
    visitor can be altered, including skipping over a sub-tree of the AST (by returning
    :data:`~.SKIP` or False from an enter method), editing the AST by returning a
    replacement value or :data:`~.REMOVE` to remove the value, or stopping the whole
    traversal by returning :data:`~.BREAK`. Returning None means "no action".

    When using :func:`~.visit` to edit an AST, the original AST will not be modified,
    and a new version of the AST with the changes applied will be returned from the
    visit function.

    To customize the node attributes to be used for traversal, you can provide a
    dictionary visitor_keys mapping node kinds to node attributes.

    :raises TypeError: if ``root`` is not an AST node, ``visitor`` is not a
        :class:`Visitor`, or a non-node value is encountered during traversal.
    """
    if not isinstance(root, Node):
        raise TypeError(f"Not an AST Node: {inspect(root)}.")
    if not isinstance(visitor, Visitor):
        raise TypeError(f"Not an AST Visitor: {inspect(visitor)}.")
    if visitor_keys is None:
        visitor_keys = QUERY_DOCUMENT_KEYS
    stack: Any = None
    in_array = isinstance(root, list)
    keys: Any = (root,)
    idx = -1
    edits: List[Any] = []
    parent: Any = None
    path: List[Any] = []
    path_append = path.append
    path_pop = path.pop
    ancestors: List[Any] = []
    ancestors_append = ancestors.append
    ancestors_pop = ancestors.pop
    new_root = root

    while True:
        idx += 1
        is_leaving = idx == len(keys)
        if is_leaving and stack is None:
            # Only reached when the root was skipped or replaced by a non-node value
            # on enter: nothing was pushed, so there is nothing left to traverse.
            # (GraphQL.js leaves here through its do-while loop condition, which a
            # `continue` inside this `while True` loop would otherwise bypass.)
            break
        is_edited = is_leaving and edits
        if is_leaving:
            # go back up one level; ``parent`` is the node being left
            key = path[-1] if ancestors else None
            node: Any = parent
            parent = ancestors_pop() if ancestors else None
            if is_edited:
                node = _apply_edits(node, edits, in_array)

            idx = stack.idx
            keys = stack.keys
            edits = stack.edits
            in_array = stack.in_array
            stack = stack.prev
        else:
            if parent:
                if in_array:
                    key = idx
                    node = parent[key]
                else:
                    key = keys[idx]
                    node = getattr(parent, key, None)
            else:
                key = None
                node = new_root
            if node is None:
                continue
            # the root node (no parent) gets no entry in the path
            if parent:
                path_append(key)

        if isinstance(node, list):
            result = None
        else:
            if not isinstance(node, Node):
                raise TypeError(f"Invalid AST Node: {inspect(node)}.")
            visit_fn = visitor.get_visit_fn(node.kind, is_leaving)
            if visit_fn:
                result = visit_fn(node, key, parent, path, ancestors)

                if result is BREAK or result is True:
                    break

                if result is SKIP or result is False:
                    if not is_leaving:
                        # undo the path entry, which only exists below the root
                        if parent:
                            path_pop()
                        continue

                elif result is not None:
                    edits.append((key, result))
                    if not is_leaving:
                        if isinstance(result, Node):
                            node = result
                        else:
                            if parent:
                                path_pop()
                            continue
            else:
                result = None

        # "No action" on leave (None, SKIP or False) must still keep the edits that
        # were made to this node's children.
        if is_edited and (result is None or result is SKIP or result is False):
            edits.append((key, node))

        if is_leaving:
            if path:
                path_pop()
        else:
            # descend into the children of the current node
            stack = Stack(in_array, idx, keys, edits, stack)
            in_array = isinstance(node, list)
            keys = node if in_array else visitor_keys.get(node.kind, ())
            idx = -1
            edits = []
            if parent:
                ancestors_append(parent)
            parent = node

        if not stack:
            break

    if edits:
        new_root = edits[-1][1]

    return new_root


class ParallelVisitor(Visitor):
    """A Visitor which delegates to many visitors to run in parallel.

    Each visitor will be visited for each node before moving on.

    If a prior visitor edits a node, no following visitors will see that node.
    """

    def __init__(self, visitors: Collection[Visitor]):
        """Create a new visitor from the given list of parallel visitors."""
        self.visitors = visitors
        # Per visitor: None while active, the node whose sub-tree it is skipping,
        # or BREAK once it has stopped for good.
        self.skipping: List[Any] = [None] * len(visitors)

    def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
        """Call the enter function of every active visitor for this node.

        A visitor returning SKIP is paused until this node is left, a visitor
        returning BREAK is paused for the rest of the traversal. The first other
        non-None result is returned immediately, so the following visitors are not
        called for this node.
        """
        skipping = self.skipping
        for i, visitor in enumerate(self.visitors):
            if not skipping[i]:
                fn = visitor.get_visit_fn(node.kind)
                if fn:
                    result = fn(node, *args)
                    if result is SKIP or result is False:
                        skipping[i] = node
                    elif result is BREAK or result is True:
                        skipping[i] = BREAK
                    elif result is not None:
                        return result
        return None

    def leave(self, node: Node, *args: Any) -> Optional[VisitorAction]:
        """Call the leave function of every active visitor for this node.

        Visitors that were skipping the sub-tree of this node are resumed (without
        calling their leave function for it). A visitor returning BREAK is paused for
        the rest of the traversal; SKIP and False are ignored. The first other
        non-None result is returned immediately.
        """
        skipping = self.skipping
        for i, visitor in enumerate(self.visitors):
            if not skipping[i]:
                fn = visitor.get_visit_fn(node.kind, is_leaving=True)
                if fn:
                    result = fn(node, *args)
                    if result is BREAK or result is True:
                        skipping[i] = BREAK
                    elif (
                        result is not None
                        and result is not SKIP
                        and result is not False
                    ):
                        # Resume the following visitors that skipped this node,
                        # otherwise they would stay paused forever, since this
                        # node will not be left again.
                        for j in range(i + 1, len(skipping)):
                            if skipping[j] is node:
                                skipping[j] = None
                        return result
            elif skipping[i] is node:
                skipping[i] = None
        return None