"""HandleImport transformation takes care of importing user-defined modules."""
from pythran.passmanager import Transformation
from pythran.tables import MODULES, pythran_ward
from pythran.syntax import PythranSyntaxError

import gast as ast
import logging
import os
import tokenize

logger = logging.getLogger('pythran')

# Common prefix of every identifier produced by mangle_imported_module.
_IMPORTED_PREFIX = pythran_ward + "imported__"


def add_filename_field(node, filename):
    """Set a ``filename`` attribute on ``node`` and all of its descendants."""
    for descendant in ast.walk(node):
        descendant.filename = filename


def mangle_imported_module(module_name):
    """Return the mangled identifier standing for module ``module_name``.

    Dots are replaced by ``$`` and a trailing ``$`` is appended, which is what
    ``is_mangled_module`` checks for.
    """
    return _IMPORTED_PREFIX + module_name.replace('.', '$') + '$'


def mangle_imported_function(module_name, func_name):
    """Return the mangled identifier of ``func_name`` from ``module_name``."""
    return mangle_imported_module(module_name) + func_name


def demangle(name):
    """Recover the dotted module name from ``mangle_imported_module`` output."""
    return name[len(_IMPORTED_PREFIX):-1].replace('$', '.')


def is_builtin_function(func_name):
    """Test if a function is a builtin (like len(), map(), ...)."""
    return func_name in MODULES["builtins"]


def is_builtin_module(module_name):
    """Test if a module is a builtin module (numpy, math, ...)."""
    module_name = module_name.split(".")[0]
    return module_name in MODULES


def is_mangled_module(name):
    """Tell whether ``name`` looks like a mangled module identifier."""
    return name.endswith('$')


def getsource(name, module_dir, level):
    """Locate, read and parse the user module ``name``.

    :param name: dotted module name; dots are mapped to path separators.
    :param module_dir: directory the lookup is relative to, or None.
    :param level: import level as found on the import node (0 = absolute).
    :returns: the parsed module, every node tagged with ``filename``.
    :raises PythranSyntaxError: if the file cannot be opened, or if a
        relative import is requested while ``module_dir`` is None.
    """
    module_base = name.replace('.', os.path.sep) + '.py'
    if module_dir is None:
        # Not an assert: it must still fire when Python runs with -O.
        if level > 0:
            raise PythranSyntaxError(
                "Cannot use relative path without module_dir")
        module_file = module_base
    else:
        # Levels 0 and 1 both resolve inside module_dir; each further level
        # climbs one directory. os.path.join (unlike sep.join) does not turn
        # an empty module_dir into an absolute path.
        module_file = os.path.join(module_dir,
                                   *(['..'] * (level - 1)),
                                   module_base)
    try:
        # tokenize.open honours PEP 263 coding cookies and defaults to
        # UTF-8, like the interpreter, rather than the locale encoding.
        with tokenize.open(module_file) as fp:
            source = fp.read()
    except IOError:
        raise PythranSyntaxError("Module '{}' not found."
                                 .format(name))
    # Imported lazily, presumably to avoid a circular import with
    # pythran.frontend -- confirm before hoisting to module level.
    from pythran.frontend import raw_parse
    node = raw_parse(source)
    add_filename_field(node, name + ".py")
    return node


class HandleImport(Transformation):

    """This pass handle user-defined import, mangling name for function from
    other modules and include them in the current module, patching all call
    site accordingly.
    """

    def __init__(self):
        super().__init__()
        # Stack of scopes mapping a source identifier to its replacement;
        # the innermost scope is last.
        self.identifiers = [{}]
        # Names of the user modules already inlined.
        self.imported = set()
        # Stack of prefixes applied to the top-level names of the module
        # currently being processed ("" for the main module).
        self.prefixes = [""]
        # Statements hoisted from inlined modules (reset by visit_Module).
        self.imported_stmts = []

    def lookup(self, name):
        """Return the innermost renaming of ``name``, or None if unknown."""
        for renaming in reversed(self.identifiers):
            if name in renaming:
                return renaming[name]
        return None

    def is_imported(self, name):
        """Tell whether user module ``name`` has already been inlined."""
        return name in self.imported

    def visit_Module(self, node):
        """Rewrite the main module, then prepend inlined module bodies."""
        self.imported_stmts = list()
        self.generic_visit(node)
        node.body = self.imported_stmts + node.body
        return node

    def rename(self, node, attr):
        """Prefix ``node.<attr>`` and record the mapping in the inner scope."""
        prev_name = getattr(node, attr)
        new_name = self.prefixes[-1] + prev_name
        setattr(node, attr, new_name)
        self.identifiers[-1][prev_name] = new_name

    def rename_top_level_functions(self, node):
        """Mangle top-level function names and simple assignment targets."""
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                self.rename(stmt, 'name')
            elif isinstance(stmt, ast.Assign):
                for target in stmt.targets:
                    if isinstance(target, ast.Name):
                        self.rename(target, 'id')

    def visit_FunctionDef(self, node):
        """Visit a function in a fresh identifier scope."""
        self.identifiers.append({})
        self.generic_visit(node)
        self.identifiers.pop()
        return node

    def visit_ListComp(self, node):
        # change traversal order so that store happens before load; keep
        # the visitors' results, visit_Attribute may return a new node
        node.generators = [self.visit(generator)
                           for generator in node.generators]
        node.elt = self.visit(node.elt)
        return node

    visit_SetComp = visit_ListComp
    visit_GeneratorExp = visit_ListComp

    def visit_DictComp(self, node):
        """Same as visit_ListComp, for the key/value pair."""
        node.generators = [self.visit(generator)
                           for generator in node.generators]
        node.key = self.visit(node.key)
        node.value = self.visit(node.value)
        return node

    def visit_comprehension(self, node):
        """Visit iter, then bind the target, then rewrite the filters.

        The filters may read the target, so it must be registered first.
        """
        node.iter = self.visit(node.iter)
        node.target = self.visit(node.target)
        node.ifs = [self.visit(if_) for if_ in node.ifs]
        return node

    def visit_assign(self, node):
        """Visit the value before the targets, storing the results back."""
        node.value = self.visit(node.value)
        node.targets = [self.visit(target) for target in node.targets]
        return node

    def visit_Assign(self, node):
        """Handle assignments, leaving module-alias assignments untouched.

        :raises PythranSyntaxError: if a module is assigned to a non-Name.
        """
        if not isinstance(node.value, ast.Name):
            return self.visit_assign(node)

        renaming = self.lookup(node.value.id)
        if not renaming:
            return self.visit_assign(node)

        if not is_mangled_module(renaming):
            return self.visit_assign(node)

        if any(not isinstance(target, ast.Name) for target in node.targets):
            raise PythranSyntaxError("Invalid module assignment", node)

        return node

    def visit_Call(self, node):
        """Reject calling a module alias, then visit the call normally.

        :raises PythranSyntaxError: if the callee is a module alias.
        """
        if isinstance(node.func, ast.Name):
            renaming = self.lookup(node.func.id)
            if renaming and is_mangled_module(renaming):
                raise PythranSyntaxError("Invalid module call", node)
        return self.generic_visit(node)

    def visit_Name(self, node):
        """Rename loads; record stores and parameters as local names.

        :raises NotImplementedError: for an unexpected expression context.
        """
        if isinstance(node.ctx, ast.Load):
            renaming = self.lookup(node.id)
            if renaming:
                node.id = renaming
        elif isinstance(node.ctx, (ast.Store, ast.Param)):
            # Shadow any outer renaming of this name in the current scope.
            self.identifiers[-1][node.id] = node.id
        elif isinstance(node.ctx, ast.Del):
            pass
        else:
            raise NotImplementedError(node)
        return node

    def visit_Attribute(self, node):
        """Replace ``user_module.a.b`` loads by a single mangled Name.

        Attribute loads that are not rooted at a user-module alias are
        visited normally so that the names nested inside are renamed too.
        """
        if not isinstance(node.ctx, ast.Load):
            return node

        # is that a module attribute load?
        root = node.value
        while isinstance(root, ast.Attribute):
            root = root.value
        if not isinstance(root, ast.Name):
            # e.g. ``f(x).real``: f may still need renaming.
            return self.generic_visit(node)

        renaming = self.lookup(root.id)
        if not renaming or not is_mangled_module(renaming):
            # Attribute of a plain value, whose root may still be a renamed
            # top-level identifier of an imported module.
            return self.generic_visit(node)

        if is_builtin_module(demangle(renaming)):
            # Builtin-module imports are kept (see visit_Import): no rewrite.
            return node

        # Collect the attribute path innermost-first, '$'-separated, so that
        # ``mod.sub.f`` becomes mangle_imported_function("mod.sub", "f").
        suffix = ""
        current = node
        while isinstance(current, ast.Attribute):
            suffix = '$' + current.attr + suffix
            current = current.value
        return ast.Name(renaming + suffix[1:], node.ctx, None, None)

    def import_module(self, module_name, module_level):
        """Inline user module ``module_name`` and return its rewritten body.

        generic_visit is used on the parsed module instead of visit so that
        visit_Module does not reset ``self.imported_stmts``.
        """
        self.imported.add(module_name)
        module_node = getsource(module_name,
                                self.passmanager.module_dir,
                                module_level)
        self.prefixes.append(mangle_imported_module(module_name))
        self.identifiers.append({})
        self.rename_top_level_functions(module_node)
        self.generic_visit(module_node)
        self.prefixes.pop()
        self.identifiers.pop()
        return module_node.body

    def visit_ImportFrom(self, node):
        """Bind ``from m import x`` names and inline user module ``m``.

        Returns the node for builtin modules; otherwise None, which removes
        the statement.

        :raises PythranSyntaxError: for ``from . import x`` (no module name).
        """
        if node.module == '__future__':
            return None

        if node.module is None:
            # e.g. ``from . import x``: nothing to map x to a module file.
            raise PythranSyntaxError(
                "Relative import without a module name is not supported",
                node)

        if is_builtin_module(node.module):
            for alias in node.names:
                name = alias.asname or alias.name
                self.identifiers[-1][name] = name
            return node
        else:
            for alias in node.names:
                name = alias.asname or alias.name
                self.identifiers[-1][name] = mangle_imported_function(
                    node.module, alias.name)

            if self.is_imported(node.module):
                return None

            new_stmts = self.import_module(node.module, node.level)
            self.imported_stmts.extend(new_stmts)

            return None

    def visit_Import(self, node):
        """Bind ``import m`` aliases and inline user modules.

        Only the builtin-module aliases are kept; the statement is dropped
        (None is returned) when none remain.
        """
        new_aliases = []
        for alias in node.names:
            name = alias.asname or alias.name
            self.identifiers[-1][name] = mangle_imported_module(alias.name)
            if self.is_imported(alias.name):
                continue
            if is_builtin_module(alias.name):
                new_aliases.append(alias)
                continue

            new_stmts = self.import_module(alias.name, 0)
            self.imported_stmts.extend(new_stmts)

        if new_aliases:
            node.names = new_aliases
            return node
        else:
            return None