1from __future__ import absolute_import
2import ast
3from ast import *
4
5from rope.base import fscommands
6
# Python 2/3 compatibility: ``unicode`` is the builtin text type on Python 2.
# On Python 3 the name does not exist (NameError), so alias it to ``str``.
try:
    unicode
except NameError:
    unicode = str
11
12
def parse(source, filename='<string>'):
    """Parse module source code into an ``ast.Module``.

    Text input is first converted with ``fscommands.unicode_to_file_data``
    so that the parser receives file data (bytes) rather than decoded text.
    CRLF and bare CR line endings are normalized to LF, and a trailing
    newline is appended when missing.

    :param source: the module source, as text or as bytes.
    :param filename: the name reported in any ``SyntaxError`` raised.
    :returns: the parsed ``ast.Module``.
    :raises SyntaxError: if the source cannot be parsed.  ``TypeError`` and
        ``ValueError`` raised by the parser are converted to ``SyntaxError``
        (with ``lineno`` 1) so callers only need to handle one type.
    """
    # NOTE: the raw string should be given to `compile` function
    if isinstance(source, unicode):
        source = fscommands.unicode_to_file_data(source)
    if b'\r' in source:
        # Normalize Windows (CRLF) and old-Mac (CR) line endings to LF.
        source = source.replace(b'\r\n', b'\n').replace(b'\r', b'\n')
    if not source.endswith(b'\n'):
        source += b'\n'
    try:
        # Pass the caller's filename through so genuine syntax errors
        # report it instead of a fixed placeholder.
        return ast.parse(source, filename=filename)
    except (TypeError, ValueError) as e:
        # Parser input errors (for example, some Python versions reject
        # source containing null bytes with ValueError) are re-raised as
        # SyntaxError.  No precise position is available, so use line 1.
        error = SyntaxError()
        error.lineno = 1
        error.filename = filename
        error.msg = str(e)
        raise error
29
30
def walk(node, walker):
    """Walk the syntax tree rooted at *node*, dispatching to *walker*.

    The walker is asked for a method named ``_`` plus the node's class name
    (e.g. ``_FunctionDef``).  When such a method exists it is called with
    the node and its result is returned; the node's children are then NOT
    visited automatically.  When it does not exist, the walk descends into
    the node's children and returns ``None``.
    """
    handler = getattr(walker, '_%s' % node.__class__.__name__, None)
    if handler is None:
        # No handler for this node type: recurse into its children instead.
        for child_node in get_child_nodes(node):
            walk(child_node, walker)
        return None
    if isinstance(node, ast.ImportFrom) and node.module is None:
        # Relative imports have ``module == ''`` on Python < 2.7 but
        # ``None`` from 2.7 on; handlers always see the '' form.
        node.module = ''
    return handler(node)
43
44
def get_child_nodes(node):
    """Return the direct child AST nodes of *node*, in field order.

    For an ``ast.Module`` the module's own ``body`` list is returned (not a
    copy).  For any other node each name in ``node._fields`` is inspected:
    AST values are collected, list values contribute their AST elements,
    and everything else (strings, numbers, ``None``) is ignored.  Fields
    that are declared but not set on the node, as can happen with
    hand-built nodes, are skipped instead of raising ``AttributeError``.

    :param node: an ``ast.AST`` instance.
    :returns: a list of child ``ast.AST`` nodes.
    """
    if isinstance(node, ast.Module):
        return node.body
    result = []
    if node._fields is not None:
        for name in node._fields:
            # Default to None so partially constructed nodes don't raise.
            child = getattr(node, name, None)
            if isinstance(child, list):
                result.extend(entry for entry in child
                              if isinstance(entry, ast.AST))
            elif isinstance(child, ast.AST):
                result.append(child)
    return result
59
60
def call_for_nodes(node, callback, recursive=False):
    """Invoke *callback* on *node* and, optionally, on its descendants.

    ``callback(node)`` is always called first.  When *recursive* is true
    and the callback's result is falsy, the same treatment is applied to
    each child node in turn.  A truthy callback result means the node's
    children are skipped.
    """
    skip_children = callback(node)
    if not recursive or skip_children:
        return
    for child_node in get_child_nodes(node):
        call_for_nodes(child_node, callback, recursive)
67
68
def get_children(node):
    """Return the raw value of every field of *node*, in ``_fields`` order.

    Unlike ``get_child_nodes`` the values are returned as-is (lists,
    strings, ``None`` and nodes alike).  Any field named ``lineno`` or
    ``col_offset`` is left out.  A node whose ``_fields`` is ``None`` has
    no children.
    """
    field_names = node._fields
    if field_names is None:
        return []
    return [getattr(node, field_name) for field_name in field_names
            if field_name not in ('lineno', 'col_offset')]
78