# -*- coding: utf-8 -*-
#
# Copyright 2008 by Armin Ronacher.
# License: Python License.
#

import _ast

from _ast import *


def fix_missing_locations(node):
    """
    When you compile a node tree with compile(), the compiler expects lineno and
    col_offset attributes for every node that supports them.  This is rather
    tedious to fill in for generated nodes, so this helper adds these attributes
    where not already set, copying the values from the parent node (the root
    defaults to line 1, column 0).  It works starting at *node* and returns
    *node*.

    The traversal uses an explicit stack rather than recursion, so very deep
    generated trees (e.g. long operator chains) do not exceed the interpreter's
    recursion limit.
    """
    # Each entry is (node, lineno, col_offset) where the two values are what
    # the node inherits from its parent if it lacks its own.
    stack = [(node, 1, 0)]
    while stack:
        current, lineno, col_offset = stack.pop()
        if 'lineno' in current._attributes:
            if not hasattr(current, 'lineno'):
                current.lineno = lineno
            else:
                lineno = current.lineno
        if 'col_offset' in current._attributes:
            if not hasattr(current, 'col_offset'):
                current.col_offset = col_offset
            else:
                col_offset = current.col_offset
        children = list(iter_child_nodes(current))
        # Push children in reverse so they are popped in field order; this
        # reproduces a recursive pre-order walk exactly (which matters only
        # when one node object is shared by several parents).
        stack.extend((child, lineno, col_offset) for child in reversed(children))
    return node


def iter_child_nodes(node):
    """
    Yield all direct child nodes of *node*, that is, all fields that are nodes
    and all items of fields that are lists of nodes.  Both the pure-Python
    `AST` class and native ``_ast.AST`` instances count as nodes.
    """
    for _name, field in iter_fields(node):
        if isinstance(field, (AST, _ast.AST)):
            yield field
        elif isinstance(field, list):
            for item in field:
                if isinstance(item, (AST, _ast.AST)):
                    yield item


def iter_fields(node):
    """
    Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    that is present on *node*.  Fields that are declared but not set are
    skipped; a falsy ``_fields`` is treated as no fields.
    """
    for field in node._fields or ():
        try:
            yield field, getattr(node, field)
        except AttributeError:
            pass


def walk(node):
    """
    Recursively yield *node* and all of its descendant nodes, in no specified
    order.  This is useful if you only want to modify nodes in place and don't
    care about the context.
    """
    # Imported locally so `deque` does not become a module-level (and thus
    # star-exported) name.
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        todo.extend(iter_child_nodes(node))
        yield node


class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    By default the visitor function for a node is ``'visit_'`` + the class
    name of the node, so a `TryFinally` node's visit function would be
    `visit_TryFinally`.  This behavior can be changed by overriding the
    `visit` method.  If no visitor function exists for a node, the
    `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node and return whatever its visitor function returns."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """
        Called if no explicit visitor function exists for a node.

        Visits every direct child of *node* (as produced by
        `iter_child_nodes`) in field order; their return values are discarded.
        """
        for child in iter_child_nodes(node):
            self.visit(child)


class AST(object):
    """
    Pure-Python base class for syntax tree nodes.

    Positional constructor arguments are assigned, in order, to the names in
    `_fields` (surplus positional arguments are ignored); keyword arguments
    are stored as attributes unchanged.
    """
    _fields = ()
    _attributes = 'lineno', 'col_offset'

    def __init__(self, *args, **kwargs):
        self.__dict__.update(kwargs)
        # Normalise a falsy _fields (e.g. None) to an empty tuple; note that
        # this stores the value on the instance.
        self._fields = self._fields or ()
        for name, value in zip(self._fields, args):
            setattr(self, name, value)


def _inject_python_base():
    """
    Try to make every node class exported by `_ast` also derive from the
    pure-Python `AST` class above.

    Types whose ``__bases__`` cannot be reassigned raise TypeError, which is
    ignored; the helpers in this module test against ``(AST, _ast.AST)`` and
    therefore work whether or not the injection succeeds.  Running this in a
    function keeps the loop variables out of the module namespace.
    """
    for cls in list(vars(_ast).values()):
        if isinstance(cls, type) and issubclass(cls, _ast.AST):
            try:
                cls.__bases__ = (AST,) + cls.__bases__
            except TypeError:
                pass


_inject_python_base()


# Only provide a fallback when the interpreter's _ast lacks ExceptHandler:
# redefining it unconditionally would shadow the native class, so parsed
# nodes would fail isinstance(node, ExceptHandler) checks.
if not hasattr(_ast, 'ExceptHandler'):
    class ExceptHandler(AST):
        """Fallback ``ExceptHandler`` node with fields type, name and body."""
        _fields = "type", "name", "body"