1"""HandleImport transformation takes care of importing user-defined modules."""
2from pythran.passmanager import Transformation
3from pythran.tables import MODULES, pythran_ward
4from pythran.syntax import PythranSyntaxError
5
6import gast as ast
7import logging
8import os
9
# Module logger, registered under the 'pythran' name.
logger = logging.getLogger(name='pythran')
11
12
def add_filename_field(node, filename):
    """Record ``filename`` on ``node`` and on all of its descendants.

    Every node produced by ``ast.walk(node)`` (``node`` itself included)
    receives a ``filename`` attribute. Nothing is returned.
    """
    for subnode in ast.walk(node):
        setattr(subnode, 'filename', filename)
16
17
def mangle_imported_module(module_name):
    """Return the mangled identifier standing for module ``module_name``.

    The result is ``pythran_ward + "imported__"``, then the dotted name with
    each dot turned into ``$``, then a trailing ``$``.
    """
    prefix = pythran_ward + "imported__"
    return prefix + '$'.join(module_name.split('.')) + '$'
20
21
def mangle_imported_function(module_name, func_name):
    """Return the mangled identifier of name ``func_name`` from ``module_name``.

    This is the mangled module name directly followed by ``func_name``.
    """
    module_prefix = mangle_imported_module(module_name)
    return module_prefix + func_name
24
25
def demangle(name):
    """Recover a dotted module name from a mangled module identifier.

    The first ``len(pythran_ward + "imported__")`` characters and the final
    character are dropped, and the remaining ``$`` become dots. The prefix
    is not checked; ``name`` is assumed to come from
    ``mangle_imported_module``.
    """
    start = len(pythran_ward) + len("imported__")
    body = name[start:len(name) - 1]
    return '.'.join(body.split('$'))
28
29
def is_builtin_function(func_name):
    """Test if a function is a builtin (like len(), map(), ...).

    True when ``func_name`` is found in ``MODULES["builtins"]``.
    """
    builtins_table = MODULES["builtins"]
    return func_name in builtins_table
33
34
def is_builtin_module(module_name):
    """Test if a module is a builtin module (numpy, math, ...).

    Only the part before the first dot is checked against ``MODULES``, so
    ``"numpy.linalg"`` is builtin whenever ``"numpy"`` is.
    """
    top_level, _, _ = module_name.partition(".")
    return top_level in MODULES
39
40
def is_mangled_module(name):
    """Tell whether ``name`` looks like a mangled *module* identifier.

    Mangled module names end with ``$``, while mangled imported names add
    an identifier after that ``$``, and plain identifiers contain no ``$``.
    """
    return name[-1:] == '$'
43
44
def getsource(name, module_dir, level):
    """Load and parse the source of the user module ``name``.

    ``name`` is a dotted module name, turned into a relative ``.py`` path.
    When ``module_dir`` is given, the path is resolved against it; ``level``
    (as in ``ImportFrom.level``, 0 for absolute imports) makes the lookup
    climb ``level - 1`` parent directories.

    Returns the parsed module node, every node of which carries a
    ``filename`` attribute set to ``name + ".py"``.

    Raises PythranSyntaxError when a relative import is requested without
    ``module_dir``, or when the module file cannot be opened.
    """
    import tokenize
    from pythran.frontend import raw_parse

    module_base = name.replace('.', os.path.sep) + '.py'
    if module_dir is None:
        # Explicit check rather than ``assert``: this is reachable from user
        # code and must not vanish under ``python -O``.
        if level > 0:
            raise PythranSyntaxError(
                "Cannot use relative path without module_dir")
        module_file = module_base
    else:
        # os.path.join (unlike a plain separator join) does not turn an
        # empty module_dir into the filesystem root.
        module_file = os.path.join(module_dir,
                                   *(['..'] * (level - 1)),
                                   module_base)
    try:
        # tokenize.open honours PEP 263 coding cookies and BOMs and falls
        # back to UTF-8, exactly like the Python interpreter does.
        with tokenize.open(module_file) as fp:
            source = fp.read()
    except OSError as err:
        raise PythranSyntaxError("Module '{}' not found."
                                 .format(name)) from err
    node = raw_parse(source)
    add_filename_field(node, name + ".py")
    return node
63
64
class HandleImport(Transformation):

    """This pass handles user-defined imports: it includes the code of the
    imported (non-builtin) modules in the current module, mangling the
    names they define, and patches all use sites accordingly.

    Imports of builtin modules (those found in ``MODULES``) are kept.
    """

    def __init__(self):
        super(HandleImport, self).__init__()
        # Stack of scopes, innermost last. Each scope maps a source name to
        # the name it is rewritten to (possibly unchanged).
        self.identifiers = [{}]
        # Names of the user modules already inlined (or being inlined).
        self.imported = set()
        # Stack of prefixes given to the top-level names of the module being
        # processed; "" for the main module.
        self.prefixes = [""]

    def lookup(self, name):
        """Return what ``name`` is rewritten to, searching scopes from the
        innermost outwards, or None when no scope binds it."""
        for renaming in reversed(self.identifiers):
            if name in renaming:
                return renaming[name]
        return None

    def is_imported(self, name):
        """Tell whether user module ``name`` has already been inlined."""
        return name in self.imported

    def visit_Module(self, node):
        # Statements of inlined modules are gathered while visiting the body
        # and placed before the module's own statements.
        self.imported_stmts = list()
        self.generic_visit(node)
        node.body = self.imported_stmts + node.body
        return node

    def rename(self, node, attr):
        """Prefix ``getattr(node, attr)`` with the current module prefix and
        record the renaming in the innermost scope."""
        prev_name = getattr(node, attr)
        new_name = self.prefixes[-1] + prev_name
        setattr(node, attr, new_name)
        self.identifiers[-1][prev_name] = new_name

    def rename_top_level_functions(self, node):
        """Rename the top-level functions of module ``node`` and the simple
        names assigned at its top level."""
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                self.rename(stmt, 'name')
            elif isinstance(stmt, ast.Assign):
                for target in stmt.targets:
                    if isinstance(target, ast.Name):
                        self.rename(target, 'id')

    def visit_FunctionDef(self, node):
        # A function gets its own scope for its parameters and locals.
        self.identifiers.append({})
        self.generic_visit(node)
        self.identifiers.pop()
        return node

    def visit_ListComp(self, node):
        # change traversal order so that store happens before load
        for generator in node.generators:
            self.visit(generator)
        self.visit(node.elt)
        return node

    visit_SetComp = visit_ListComp
    visit_GeneratorExp = visit_ListComp

    def visit_DictComp(self, node):
        # Same as visit_ListComp: generators first, so loop targets are
        # bound before key and value are visited.
        for generator in node.generators:
            self.visit(generator)
        self.visit(node.key)
        self.visit(node.value)
        return node

    def visit_comprehension(self, node):
        # The iterable is evaluated before the loop target is bound, but the
        # conditions usually read the target, so bind it before them.
        self.visit(node.iter)
        self.visit(node.target)
        for if_ in node.ifs:
            self.visit(if_)
        return node

    def visit_assign(self, node):
        """Visit the value of an assignment before its targets, so the
        right-hand side is resolved before the targets get bound."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
        return node

    def visit_Assign(self, node):
        """Visit an assignment, special-casing the assignment of a module.

        When the value is a name bound to a module (mangled module name),
        every target must be a simple name and the statement is left
        untouched; otherwise ``visit_assign`` handles it.
        """
        if not isinstance(node.value, ast.Name):
            return self.visit_assign(node)

        renaming = self.lookup(node.value.id)
        if not renaming:
            return self.visit_assign(node)

        if not is_mangled_module(renaming):
            return self.visit_assign(node)

        if any(not isinstance(target, ast.Name) for target in node.targets):
            raise PythranSyntaxError("Invalid module assignment", node)

        return node

    def visit_Call(self, node):
        """Reject calls of a module object, then visit the call normally."""
        if isinstance(node.func, ast.Name):
            renaming = self.lookup(node.func.id)
            if renaming and is_mangled_module(renaming):
                raise PythranSyntaxError("Invalid module call", node)
        return self.generic_visit(node)

    def visit_Name(self, node):
        """Rewrite loaded names; bind stored names and parameters to
        themselves in the innermost scope."""
        if isinstance(node.ctx, ast.Load):
            renaming = self.lookup(node.id)
            if renaming:
                node.id = renaming
        elif isinstance(node.ctx, (ast.Store, ast.Param)):
            self.identifiers[-1][node.id] = node.id
        elif isinstance(node.ctx, ast.Del):
            pass
        else:
            raise NotImplementedError(node)
        return node

    def visit_Attribute(self, node):
        """Turn attribute loads on user modules into mangled names.

        A loaded chain rooted at a name bound to a user (non-builtin)
        module, e.g. ``mod.sub.f``, becomes a single ``Name`` made of the
        module's mangled name followed by the ``$``-joined attribute path
        (``sub$f``). Chains on builtin modules, and stores/deletions on
        modules, are left untouched. Any other attribute expression is
        visited normally so imported names inside it get renamed.
        """
        root = node.value
        while isinstance(root, ast.Attribute):
            root = root.value

        renaming = (self.lookup(root.id)
                    if isinstance(root, ast.Name) else None)
        if not renaming or not is_mangled_module(renaming):
            # Not a module access (``f(x).real``, ``obj.attr``,
            # ``imported_global.shape``, ...): the root expression may still
            # reference imported names that must be renamed.
            return self.generic_visit(node)

        if not isinstance(node.ctx, ast.Load):
            return node

        if is_builtin_module(demangle(renaming)):
            # Builtin modules stay real imports: keep e.g. ``np.sum``.
            return node

        # Collect attribute names from the outermost inwards, then reverse
        # them so the path reads from the module outwards.
        path = []
        current = node
        while isinstance(current, ast.Attribute):
            path.append(current.attr)
            current = current.value
        return ast.Name(renaming + '$'.join(reversed(path)),
                        node.ctx, None, None)

    def import_module(self, module_name, module_level):
        """Load, rename and visit user module ``module_name``.

        Returns the statements of the processed module body. The module is
        recorded as imported before being visited, so a module importing it
        back does not load it again.
        """
        self.imported.add(module_name)
        module_node = getsource(module_name,
                                self.passmanager.module_dir,
                                module_level)
        # The imported module has its own namespace: it must not resolve
        # names through the scopes of the code importing it.
        outer_identifiers = self.identifiers
        self.identifiers = [{}]
        self.prefixes.append(mangle_imported_module(module_name))
        try:
            self.rename_top_level_functions(module_node)
            self.generic_visit(module_node)
        finally:
            self.prefixes.pop()
            self.identifiers = outer_identifiers
        return module_node.body

    def visit_ImportFrom(self, node):
        """Handle ``from module import name [as alias]``.

        ``__future__`` imports are dropped. For builtin modules the
        statement is kept and the imported names map to themselves. For
        user modules each imported name maps to its mangled name, the
        module is inlined (once) and the statement is removed.
        """
        if node.module == '__future__':
            return None

        if node.module is None:
            # ``from . import x`` names no module to mangle or load.
            raise PythranSyntaxError(
                "Relative import without a module name is not supported",
                node)

        if is_builtin_module(node.module):
            for alias in node.names:
                name = alias.asname or alias.name
                self.identifiers[-1][name] = name
            return node

        for alias in node.names:
            name = alias.asname or alias.name
            self.identifiers[-1][name] = mangle_imported_function(
                node.module, alias.name)

        if self.is_imported(node.module):
            return None

        new_stmts = self.import_module(node.module, node.level)
        self.imported_stmts.extend(new_stmts)

        return None

    def visit_Import(self, node):
        """Handle ``import module [as alias]``.

        Every bound name maps to the mangled module name. Modules already
        inlined are skipped, builtin modules are kept in the statement, and
        other modules are inlined. The statement is removed when no alias
        remains.
        """
        new_aliases = []
        for alias in node.names:
            name = alias.asname or alias.name
            self.identifiers[-1][name] = mangle_imported_module(alias.name)
            if alias.name in self.imported:
                continue
            if is_builtin_module(alias.name):
                new_aliases.append(alias)
                continue

            new_stmts = self.import_module(alias.name, 0)
            self.imported_stmts.extend(new_stmts)

        if new_aliases:
            node.names = new_aliases
            return node
        else:
            return None
268