1"""Utility functions, node construction macros, etc."""
2# Author: Collin Winter
3
4# Local imports
5from .pgen2 import token
6from .pytree import Leaf, Node
7from .pygram import python_symbols as syms
8from . import patcomp
9
10
11###########################################################
12### Common node-construction "macros"
13###########################################################
14
def KeywordArg(keyword, value):
    """Build an ``argument`` node spelling ``keyword=value``.

    *keyword* and *value* are inserted as-is (not cloned); only the ``=``
    leaf is created here.
    """
    equals = Leaf(token.EQUAL, "=")
    return Node(syms.argument, [keyword, equals, value])
18
def LParen():
    """Return a fresh ``(`` leaf."""
    paren = Leaf(token.LPAR, "(")
    return paren

def RParen():
    """Return a fresh ``)`` leaf."""
    paren = Leaf(token.RPAR, ")")
    return paren
24
def Assign(target, source):
    """Build an assignment statement ``target = source``.

    Either argument may be a single node or a list of nodes. A single
    *source* node has its prefix overwritten with one space (the node is
    mutated) so the output reads ``target = source``; a list *source* is
    used untouched. The children are wrapped in an ``atom`` node.
    """
    targets = target if isinstance(target, list) else [target]
    if isinstance(source, list):
        sources = source
    else:
        source.prefix = " "
        sources = [source]
    equals = Leaf(token.EQUAL, "=", prefix=" ")
    return Node(syms.atom, targets + [equals] + sources)
35
def Name(name, prefix=None):
    """Return a NAME leaf whose value is *name*, with an optional *prefix*."""
    leaf_type = token.NAME
    return Leaf(leaf_type, name, prefix=prefix)
39
def Attr(obj, attr):
    """Return the node list ``[obj, trailer('.', attr)]`` spelling obj.attr.

    The result is a plain Python list (not a node), suitable for splicing
    into the children of a ``power`` node.
    """
    dotted = Node(syms.trailer, [Dot(), attr])
    return [obj, dotted]
43
def Comma():
    """Return a fresh ``,`` leaf."""
    comma = Leaf(token.COMMA, ",")
    return comma

def Dot():
    """Return a fresh ``.`` leaf."""
    dot = Leaf(token.DOT, ".")
    return dot
51
def ArgList(args, lparen=None, rparen=None):
    """A parenthesised argument list (a ``trailer`` node), used by Call().

    Args:
        args: list of nodes to wrap in an ``arglist``; when empty or None
            the trailer is just ``()``.
        lparen, rparen: optional parenthesis leaves used as templates.
            They are always cloned, so the passed leaves themselves are not
            inserted into the new tree. Default to fresh ``(`` / ``)``.

    Returns:
        The new ``trailer`` node.
    """
    # Create the default leaves per call rather than evaluating LParen() /
    # RParen() once at definition time and sharing them across calls.
    if lparen is None:
        lparen = LParen()
    if rparen is None:
        rparen = RParen()
    node = Node(syms.trailer, [lparen.clone(), rparen.clone()])
    if args:
        node.insert_child(1, Node(syms.arglist, args))
    return node
58
def Call(func_name, args=None, prefix=None):
    """Build a ``power`` node for the call ``func_name(args)``.

    *func_name* is inserted as-is and *args* is handed to ArgList(). When
    *prefix* is given it is set on the resulting node.
    """
    call = Node(syms.power, [func_name, ArgList(args)])
    if prefix is None:
        return call
    call.prefix = prefix
    return call
65
def Newline():
    """Return a NEWLINE leaf whose text is a single line break."""
    newline = Leaf(token.NEWLINE, "\n")
    return newline

def BlankLine():
    """Return a NEWLINE leaf with empty text."""
    blank = Leaf(token.NEWLINE, "")
    return blank
73
def Number(n, prefix=None):
    """Return a NUMBER leaf whose value is *n*, with an optional *prefix*.

    *n* is stored verbatim as the leaf value (presumably the literal's
    source text, e.g. "42" -- confirm at call sites).
    """
    number = Leaf(token.NUMBER, n, prefix=prefix)
    return number
76
def Subscript(index_node):
    """A numeric or string subscript: the ``trailer`` node ``[index_node]``.

    *index_node* is inserted as-is (not cloned).
    """
    # Square brackets are the LSQB/RSQB tokens; LBRACE/RBRACE are the curly
    # braces, which gave these leaves a type inconsistent with their text.
    return Node(syms.trailer, [Leaf(token.LSQB, "["),
                               index_node,
                               Leaf(token.RSQB, "]")])
82
def String(string, prefix=None):
    """Return a STRING leaf whose value is *string*, with an optional *prefix*.

    The value is stored verbatim (presumably the literal as written in
    source, quotes included -- confirm at call sites).
    """
    leaf_type = token.STRING
    return Leaf(leaf_type, string, prefix=prefix)
86
def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The given nodes are inserted as-is and their prefixes are overwritten
    (xp gets "", fp/it/test get " ") so the result renders as
    ``[xp for fp in it if test]``.
    """
    xp.prefix = ""
    fp.prefix = " "
    it.prefix = " "
    for_leaf = Leaf(token.NAME, "for")
    for_leaf.prefix = " "
    in_leaf = Leaf(token.NAME, "in")
    in_leaf.prefix = " "
    inner_args = [for_leaf, fp, in_leaf, it]
    if test is not None:
        test.prefix = " "
        if_leaf = Leaf(token.NAME, "if")
        if_leaf.prefix = " "
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    # Square brackets are the LSQB/RSQB tokens; LBRACE/RBRACE are the curly
    # braces, which gave these leaves a type inconsistent with their text.
    return Node(syms.atom,
                [Leaf(token.LSQB, "["),
                 inner,
                 Leaf(token.RSQB, "]")])
110
def FromImport(package_name, name_leafs):
    """Build an ``import_from`` node: ``from package_name import name_leafs``.

    Each leaf in *name_leafs* is first detached from its current tree (via
    ``remove()``) and then becomes a child of a new ``import_as_names`` node.
    *package_name* is stored in a single NAME leaf.

    NOTE: dotted package names (e.g. ``'foo.bar'``) are placed in that one
    NAME leaf and may not be handled properly; use with care.
    """
    for name_leaf in name_leafs:
        # Detach from the old tree before re-parenting below.
        name_leaf.remove()

    from_kw = Leaf(token.NAME, "from")
    module = Leaf(token.NAME, package_name, prefix=" ")
    import_kw = Leaf(token.NAME, "import", prefix=" ")
    names = Node(syms.import_as_names, name_leafs)
    return Node(syms.import_from, [from_kw, module, import_kw, names])
129
def ImportAndCall(node, results, names):
    """Build the call expression ``module.attr(args)`` from a pattern match.

    Despite the name, only the call expression is built and returned; no
    import statement is created here (NOTE(review): presumably the caller
    adds ``import module`` itself -- confirm at call sites).

    Args:
        node: the matched node; its prefix is copied onto the result.
        results: pattern captures. Uses "obj" (the call argument(s): an
            ``arglist`` node or a single node), "lpar" / "rpar" (the call
            parentheses) and, optionally, "after" (a list of nodes appended
            after the call, e.g. further trailers).
        names: ``(module, attr)`` name strings.

    Returns:
        A new ``power`` node. Every captured node is cloned, so the matched
        tree is not modified.
    """
    obj = results["obj"]
    if obj.type == syms.arglist:
        newarglist = obj.clone()
    else:
        newarglist = Node(syms.arglist, [obj.clone()])
    # "after" may be absent or None when nothing follows the call; treat
    # that as "no trailing nodes" instead of failing on ``list + None``.
    after = [n.clone() for n in results.get("after") or ()]
    new = Node(syms.power,
               Attr(Name(names[0]), Name(names[1])) +
               [Node(syms.trailer,
                     [results["lpar"].clone(),
                      newarglist,
                      results["rpar"].clone()])] + after)
    new.prefix = node.prefix
    return new
152
153
154###########################################################
155### Determine whether a node represents a given literal
156###########################################################
157
def is_tuple(node):
    """Does the node represent a tuple literal?

    True for the empty tuple ``()`` and for any Node of exactly three
    children shaped ``( <Node> )``. The middle child is only required to be
    a Node, so a parenthesized non-tuple expression such as ``(a + b)`` is
    accepted too, while ``(a)`` (a bare leaf inside) is not.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if kids == [LParen(), RParen()]:
        return True
    if len(kids) != 3:
        return False
    opener, middle, closer = kids
    return (isinstance(opener, Leaf)
            and isinstance(middle, Node)
            and isinstance(closer, Leaf)
            and opener.value == "("
            and closer.value == ")")
169
def is_list(node):
    """Does the node represent a list literal?

    True for a Node with at least two children whose first and last
    children are the leaves ``[`` and ``]``. The node type is not checked,
    so any node bracketed this way qualifies.
    """
    if not isinstance(node, Node) or len(node.children) <= 1:
        return False
    first, last = node.children[0], node.children[-1]
    return (isinstance(first, Leaf)
            and isinstance(last, Leaf)
            and first.value == "["
            and last.value == "]")
178
179
180###########################################################
181### Misc
182###########################################################
183
def parenthesize(node):
    """Wrap *node* in parentheses, returning a new ``atom`` node ``(node)``."""
    wrapped = [LParen(), node, RParen()]
    return Node(syms.atom, wrapped)
186
187
# Set of builtin call names. From the name, presumably calls that iterate
# over their argument (so an iterator is as good as a list there); it is
# not referenced in this part of the module -- NOTE(review): confirm the
# intended use against the fixers that import it.
consuming_calls = {"sorted", "list", "set", "any", "all", "tuple", "sum",
                   "min", "max", "enumerate"}
190
def attr_chain(obj, attr):
    """Follow an attribute chain.

    With a chain such as ``a.foo -> b`` and ``b.foo -> c``,
    ``attr_chain(a, "foo")`` yields ``b`` and then ``c``; *obj* itself is
    not yielded. Iteration stops at the first falsy value (normally None),
    which is not yielded.

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    current = obj
    while True:
        current = getattr(current, attr)
        if not current:
            return
        yield current
209
# Pattern sources used by in_special_context(). They start out as plain
# strings; the first call to in_special_context() compiles them with
# patcomp.compile_pattern() and rebinds p0/p1/p2 to the compiled patterns.
#
# p0: "node" is the iterable of a for statement or a comprehension's
# "for" clause (matched against the node's parent).
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# p1: "node" is whatever sits between the parentheses of a call to one of
# the listed names, or of a ``.join(...)`` method call (matched against the
# node's grandparent).
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# p2: "node" is the first element of the arglist of a sorted(...) or
# enumerate(...) call (matched against the node's great-grandparent).
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
# Set to True by in_special_context() once p0/p1/p2 have been compiled.
pats_built = False
def in_special_context(node):
    """Return True if *node* only needs to be iterable where it appears.

    In such a context it does not matter whether *node* evaluates to a list
    or to an iterator. The check tries p0 against node's parent, p1 against
    its grandparent and p2 against its great-grandparent. It succeeds when
    one of them matches with *node* itself captured as "node".
    See test_map_nochange in test_fixers.py for some examples and tests.

    The first call compiles the module-level pattern strings p0/p1/p2,
    rebinding those globals, and sets pats_built.
    """
    global p0, p1, p2, pats_built
    if not pats_built:
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    ancestors = attr_chain(node, "parent")
    for pattern, ancestor in zip((p0, p1, p2), ancestors):
        captured = {}
        if not pattern.match(ancestor, captured):
            continue
        if captured["node"] is node:
            return True
    return False
247
def is_probably_builtin(node):
    """
    Check that something isn't an attribute or function name etc.

    Returns False when *node*:
      * is preceded by a '.' (attribute lookup),
      * is a direct child of a ``def`` / ``class`` statement,
      * is the first child of an expression statement (assignment target),
      * is an argument name in a parameter list.
    Returns True otherwise, including for a node that has no parent.
    """
    prev = node.prev_sibling
    if prev is not None and prev.type == token.DOT:
        # Attribute lookup.
        return False
    parent = node.parent
    if parent is None:
        # A detached node has no surrounding context that could make it an
        # attribute, definition name, assignment target or parameter
        # (previously this crashed with AttributeError on ``None.type``).
        return True
    if parent.type in (syms.funcdef, syms.classdef):
        return False
    if parent.type == syms.expr_stmt and parent.children[0] is node:
        # Assignment.
        return False
    if parent.type == syms.parameters:
        # The name of an argument.
        return False
    if parent.type == syms.typedargslist and (
            (prev is not None and prev.type == token.COMMA) or
            parent.children[0] is node):
        # The name of an argument.
        return False
    return True
270
def find_indentation(node):
    """Find the indentation of *node*.

    Walks from *node* up through its ancestors. For the nearest ``suite``
    with more than two children whose second child is an INDENT token, that
    token's value is returned. Returns "" when no such suite exists.
    """
    current = node
    while current is not None:
        is_block = current.type == syms.suite and len(current.children) > 2
        if is_block and current.children[1].type == token.INDENT:
            return current.children[1].value
        current = current.parent
    return ""
280
281###########################################################
282### The following functions are to find bindings in a suite
283###########################################################
284
def make_suite(node):
    """Return *node* as a ``suite`` node.

    A node that already is a suite is returned unchanged (not copied).
    Anything else is cloned, and the clone is wrapped in a new, parentless
    suite, so the original tree is left untouched.
    """
    if node.type == syms.suite:
        return node
    # The old code read ``node.parent`` after cloning, in order to copy it
    # onto the suite. A clone never has a parent, so the suite always ended
    # up detached anyway; that is stated plainly here. find_binding only
    # iterates over the returned suite's children.
    return Node(syms.suite, [node.clone()])
293
def find_root(node):
    """Find the top level namespace.

    Climbs parent links from *node* until a ``file_input`` node is reached
    and returns that node (which may be *node* itself).

    Raises:
        ValueError: if the parent chain ends before a ``file_input`` node
            is found.
    """
    current = node
    while current.type != syms.file_input:
        parent = current.parent
        if not parent:
            raise ValueError("root found before file_input node was found.")
        current = parent
    return current
302
def does_tree_import(package, name, node):
    """Return True if *name* is imported from *package* at the top level of
    the tree that *node* belongs to.

    For a plain ``import foo``, pass None as *package* and 'foo' as *name*.
    """
    root = find_root(node)
    return bool(find_binding(name, root, package))
310
def is_import(node):
    """Return True if *node* is an ``import`` or ``from ... import`` node."""
    import_types = (syms.import_name, syms.import_from)
    return node.type in import_types
314
def touch_import(package, name, node):
    """ Works like `does_tree_import` but adds an import statement
        if it was not imported.

    The new statement is ``import name`` when *package* is None, otherwise
    ``from package import name``. It is inserted:

      * right after the first run of top-level import statements; or
      * if there are none, right after a module docstring (only the first
        statement can be one); or
      * otherwise at the very beginning of the module.
    """
    def is_import_stmt(stmt):
        return (stmt.type == syms.simple_stmt and stmt.children and
                is_import(stmt.children[0]))

    root = find_root(node)

    if does_tree_import(package, name, root):
        return

    children = root.children

    # Find the first import statement, then skip to the end of that run of
    # imports.
    insert_pos = 0
    for idx, child in enumerate(children):
        if is_import_stmt(child):
            insert_pos = idx + 1
            while (insert_pos < len(children) and
                   is_import_stmt(children[insert_pos])):
                insert_pos += 1
            break

    # With no imports, go after the module docstring. Only the first
    # statement can be a docstring. A string statement later in the module
    # is not one, and inserting after it could place the import below code
    # that already uses the name.
    if insert_pos == 0 and children:
        first = children[0]
        if (first.type == syms.simple_stmt and first.children and
                first.children[0].type == token.STRING):
            insert_pos = 1

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, "import"),
            Leaf(token.NAME, name, prefix=" ")
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=" ")])

    new_stmt = Node(syms.simple_stmt, [import_, Newline()])
    root.insert_child(insert_pos, new_stmt)
358
359
_def_syms = {syms.classdef, syms.funcdef}


def _find_binding_in_clauses(name, stmt, package, start=0):
    """Search the clause bodies of compound statement *stmt* for *name*.

    A clause body is the child right after a ':' leaf whose index in
    ``stmt.children`` is at least *start*. Returns the binding found in the
    last body that has one, or None.
    """
    found = None
    kids = stmt.children
    for i in range(start, len(kids) - 1):
        kid = kids[i]
        if kid.type == token.COLON and kid.value == ":":
            n = find_binding(name, make_suite(kids[i + 1]), package)
            if n:
                found = n
    return found


def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples.

    Direct children of *node* are inspected, as follows:

      * ``def`` / ``class`` statements: only the defined name is checked;
        their bodies are not searched.
      * ``for`` / ``if`` / ``while`` / ``try``: every clause body is
        searched recursively, including ``elif`` / ``else`` / ``except`` /
        ``finally``.
      * Other compound statements (e.g. ``with``) are not searched.
    """
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            # The loop target itself binds the name.
            if _find(name, child.children[1]):
                return child
            ret = _find_binding_in_clauses(name, child, package)
        elif child.type in (syms.if_stmt, syms.while_stmt):
            # Search every branch, not only the last one (else/elif bodies
            # used to hide bindings made in the preceding branches).
            ret = _find_binding_in_clauses(name, child, package)
        elif child.type == syms.try_stmt:
            # A binding in the ``try`` body wins; otherwise search the
            # except/else/finally bodies that follow it (children[3:]).
            ret = find_binding(name, make_suite(child.children[2]), package)
            if not ret:
                ret = _find_binding_in_clauses(name, child, package, start=3)
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            if _find(name, child.children[0]):
                ret = child

        if ret:
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None
402
_block_syms = {syms.funcdef, syms.classdef, syms.trailer}
def _find(name, node):
    """Return a NAME leaf with value *name* in the subtree at *node*, or None.

    The search is depth-first, visiting the last child first. It does not
    descend into function or class definitions or into trailers (``.attr``,
    ``(...)``, ``[...]``), so names inside those are not reported.
    """
    nodes = [node]
    while nodes:
        node = nodes.pop()
        # Nonterminal (symbol) numbers start at 256 and token numbers are
        # below it; ``> 256`` wrongly treated symbol 256 as a token.
        if node.type >= 256 and node.type not in _block_syms:
            nodes.extend(node.children)
        elif node.type == token.NAME and node.value == name:
            return node
    return None
413
def _is_import_binding(node, name, package=None):
    """ Will return node if node will import name, or node
        will import * from package.  None is returned otherwise.
        See test cases for examples.

    Plain ``import`` statements are only considered when *package* is
    falsy. For ``from ... import`` statements with a *package*, two extra
    rules apply:

      * the module part must equal *package*;
      * statements containing ``as`` never match.
    """

    if node.type == syms.import_name and not package:
        imp = node.children[1]
        if imp.type == syms.dotted_as_names:
            for child in imp.children:
                if child.type == syms.dotted_as_name:
                    if child.children[2].value == name:
                        return node
                elif child.type == token.NAME and child.value == name:
                    return node
        elif imp.type == syms.dotted_as_name:
            last = imp.children[-1]
            if last.type == token.NAME and last.value == name:
                return node
        elif imp.type == token.NAME and imp.value == name:
            return node
    elif node.type == syms.import_from:
        kids = node.children
        # Locate the 'import' keyword. It is not always kids[2]: relative
        # imports such as "from .. import x" or "from .a import x" put each
        # leading dot in its own leaf before the module name.
        imp_idx = None
        for i in range(1, len(kids)):
            if kids[i].type == token.NAME and kids[i].value == "import":
                imp_idx = i
                break
        if imp_idx is None or imp_idx + 1 >= len(kids):
            return None
        # str(...) is used to make life easier here, because
        # "from a.b import ..." parses to ['from', ['a', '.', 'b'], 'import', ...]
        module = "".join(str(kid).strip() for kid in kids[1:imp_idx])
        if package and module != package:
            return None
        n = kids[imp_idx + 1]
        if n.type == token.LPAR and imp_idx + 2 < len(kids):
            # "from a import (b, c)": the names sit inside the parentheses.
            n = kids[imp_idx + 2]
        if package and _find("as", n):
            # See test_from_import_as for explanation
            return None
        elif n.type == syms.import_as_names and _find(name, n):
            return node
        elif n.type == syms.import_as_name:
            child = n.children[2]
            if child.type == token.NAME and child.value == name:
                return node
        elif n.type == token.NAME and n.value == name:
            return node
        elif package and n.type == token.STAR:
            return node
    return None
454