1# -*- coding: utf-8 -*-
2#
3# Copyright 2008 by Armin Ronacher.
4# License: Python License.
5#
6
7import _ast
8
9from _ast import *
10
11
def fix_missing_locations(node):
    """
    Fill in absent ``lineno`` / ``col_offset`` attributes throughout the tree
    rooted at *node*, which compile() needs on every node that supports them.

    The tree is walked depth-first.  A node that lists one of these names in
    its ``_attributes`` but has no value for it receives the value handed down
    from its parent; at the root that value is line 1, column 0.  A node that
    already carries a value keeps it and hands it down to its own children.
    Nodes are modified in place and *node* is returned.
    """
    location_names = ('lineno', 'col_offset')

    def _descend(current, inherited):
        supported = current._attributes
        handed_down = []
        for attr_name, fallback in zip(location_names, inherited):
            if attr_name in supported:
                if hasattr(current, attr_name):
                    # An explicit value wins and becomes the new default below.
                    fallback = getattr(current, attr_name)
                else:
                    setattr(current, attr_name, fallback)
            handed_down.append(fallback)
        handed_down = tuple(handed_down)
        for child in iter_child_nodes(current):
            _descend(child, handed_down)

    _descend(node, (1, 0))
    return node
35
36
def iter_child_nodes(node):
    """
    Generate the immediate children of *node*, in field order.

    A field value that is itself a node is produced directly.  For a
    list-valued field, each node element is produced.  Other values, such as
    strings, numbers and None, are skipped.
    """
    node_types = (AST, _ast.AST)
    for _fieldname, value in iter_fields(node):
        if isinstance(value, node_types):
            candidates = (value,)
        elif isinstance(value, list):
            candidates = value
        else:
            continue
        for candidate in candidates:
            if isinstance(candidate, node_types):
                yield candidate
49
50
def iter_fields(node):
    """
    Generate ``(fieldname, value)`` pairs for the names listed in
    ``node._fields``.

    A false ``_fields`` value, such as None, is treated as no fields at all.
    Fields that cannot currently be read from *node* are skipped silently.
    """
    missing = object()
    for field_name in node._fields or ():
        # getattr's default absorbs AttributeError just like try/except would.
        value = getattr(node, field_name, missing)
        if value is not missing:
            yield field_name, value
62
63
def walk(node):
    """
    Yield *node* itself, followed by every node beneath it.

    Pending nodes sit in a FIFO queue, so the traversal is breadth-first.
    Callers should still not rely on the order.  Each node's children are
    queued before the node is yielded, so the walk visits the children present
    when the node was reached, even if the caller then replaces that node's
    child fields.  This is useful when you only want to modify nodes in place
    and don't care about their context.
    """
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        # Queue children first; see the docstring for why this ordering matters.
        todo.extend(iter_child_nodes(node))
        yield node
76
77
class NodeVisitor(object):
    """
    Base class for read-only traversals of an abstract syntax tree.

    Subclass it and add methods named ``'visit_'`` + the node's class name.
    For example, ``visit_TryFinally`` handles a `TryFinally` node.

    `visit` dispatches each node to the method matching its class name and
    returns whatever that method returns.  Nodes without such a method go to
    `generic_visit`, which visits every child node and returns None.  Override
    `visit` to change the dispatch rule.

    Do not use this class to change nodes while traversing; a transforming
    visitor (`NodeTransformer`) is meant for that.
    """

    def visit(self, node):
        """
        Dispatch *node* to its ``visit_<ClassName>`` method, or to
        `generic_visit` when none exists, and return the result.
        """
        handler_name = 'visit_%s' % node.__class__.__name__
        fallback = self.generic_visit
        handler = getattr(self, handler_name, fallback)
        return handler(node)

    def generic_visit(self, node):
        """
        Visit every child node of *node*.

        Children are field values that are nodes, plus node elements of
        list-valued fields.  Used when no specific visitor method exists.
        """
        node_types = (AST, _ast.AST)
        for _fieldname, value in iter_fields(node):
            if isinstance(value, list):
                members = value
            elif isinstance(value, node_types):
                members = (value,)
            else:
                continue
            for member in members:
                if isinstance(member, node_types):
                    self.visit(member)
113
114
class AST(object):
    """
    Pure-Python node base class.

    This module prepends it to the ``__bases__`` of the ``_ast`` node classes
    where the interpreter allows that.  Node classes defined in this module
    derive from it directly.

    Subclasses list their field names in ``_fields``.  ``_attributes`` names
    the location attributes consulted by fix_missing_locations().
    """
    _fields = ()
    _attributes = 'lineno', 'col_offset'

    def __init__(self, *args, **kwargs):
        """
        Set keyword arguments as attributes verbatim, then assign positional
        arguments to the names in ``_fields``, in order.

        :raises TypeError: if more positional arguments are given than there
            are fields, or if a field is supplied both positionally and by
            keyword.
        """
        self.__dict__.update(kwargs)
        # Normalise a false/None ``_fields`` to an empty tuple on the instance
        # so later iteration over ``self._fields`` is always safe.
        fields = self._fields or ()
        self._fields = fields
        positional = list(zip(fields, args))
        if len(positional) < len(args):
            # Surplus positional arguments used to be dropped silently.
            raise TypeError('%s constructor takes at most %d positional '
                            'arguments' % (self.__class__.__name__,
                                           len(positional)))
        for name, value in positional:
            if name in kwargs:
                # Previously the positional value silently overwrote the keyword.
                raise TypeError('%s got multiple values for argument %r'
                                % (self.__class__.__name__, name))
        for name, value in positional:
            setattr(self, name, value)
124
125
def _graft_ast_base():
    """
    Prepend this module's pure-Python `AST` class to ``__bases__`` of every
    class in ``_ast`` that derives from ``_ast.AST``.

    The loop lives in a function so that its loop variables stay local.  At
    module level they would become public globals, which
    ``from <module> import *`` would re-export.
    """
    for candidate in _ast.__dict__.values():
        if not (isinstance(candidate, type) and issubclass(candidate, _ast.AST)):
            continue
        try:
            candidate.__bases__ = (AST, ) + candidate.__bases__
        except TypeError:
            # CPython reports an impossible __bases__ assignment as TypeError,
            # for example on a built-in/immutable type or on an instance-layout
            # or MRO conflict.  Such classes are deliberately left unchanged.
            pass
    # NOTE(review): with AST listed first in __bases__, its class-level
    # ``_fields`` / ``_attributes`` can shadow values a class would otherwise
    # inherit from its original bases.  C3 ordering decides in each case.
    # Confirm this is intended.


_graft_ast_base()
132
133
if not hasattr(_ast, 'ExceptHandler'):
    # Fallback only for interpreters whose ``_ast`` lacks ExceptHandler.
    # Elsewhere ``from _ast import *`` has already bound the interpreter's own
    # class.  Shadowing it would make ``isinstance`` checks against
    # interpreter-built handler nodes fail.
    class ExceptHandler(AST):
        """Exception-handler node with fields ``type``, ``name`` and ``body``."""
        _fields = "type", "name", "body"
136