""" Expand some builtins implementations when it is profitable."""

from pythran.analyses import Aliases
from pythran.analyses.pure_expressions import PureExpressions
from pythran.passmanager import Transformation
from pythran.tables import MODULES
from pythran.intrinsic import FunctionIntr
from pythran.utils import path_to_attr, path_to_node
from pythran.syntax import PythranSyntaxError

from copy import deepcopy
import gast as ast


class InlineBuiltins(Transformation):

    """
    Replace some builtins by their bodies.

    This may trigger some extra optimizations later on!

    Two families of rewrites are performed:

    * ``builtins.map(f, seq0, seq1)`` where every sequence argument is a
      list or tuple literal and ``f`` aliases a single function becomes the
      list of calls ``[f(seq0[0], seq1[0]), f(seq0[1], seq1[1])]``;
    * elementwise binary, unary and comparison operators whose operands are
      fixed-size arrays (``numpy.array``/``numpy.asarray`` of a flat list or
      tuple literal), or one such array and a constant / flat literal, are
      unrolled into ``numpy.array((elt0 op ..., elt1 op ...))``.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> pm = passmanager.PassManager("test")
    >>> node = ast.parse('''
    ... def foo(a):
    ...     return a + 1
    ... def bar(b):
    ...     return builtins.map(bar, (1, 2))''')
    >>> _, node = pm.apply(InlineBuiltins, node)
    >>> print(pm.dump(backend.Python, node))
    def foo(a):
        return (a + 1)
    def bar(b):
        return [bar(1), bar(2)]
    """

    # Comparison operators numpy applies elementwise. ``is``/``is not`` and
    # ``in``/``not in`` act on the whole array and must not be unrolled.
    _ELEMENTWISE_CMPOPS = (ast.Eq, ast.NotEq, ast.Lt, ast.LtE,
                           ast.Gt, ast.GtE)

    # Literal elements that are not scalars. Unrolling over them would apply
    # the Python operator semantics instead of numpy's (e.g. list
    # concatenation for ``+``). Starred elements make the length unknown.
    _NON_SCALAR_ELTS = (ast.List, ast.Tuple, ast.Set, ast.Dict,
                        ast.ListComp, ast.SetComp, ast.DictComp,
                        ast.GeneratorExp, ast.Starred)

    def __init__(self):
        Transformation.__init__(self, Aliases, PureExpressions)

    def inlineBuiltinsXMap(self, node):
        """Unroll ``map(f, seq0, seq1, ...)`` into ``[f(...), f(...), ...]``.

        The caller has checked that every iterable argument is a list or
        tuple literal. Like Python's ``map``, the result stops at the
        shortest iterable.
        """
        self.update = True

        func = node.args[0]
        columns = zip(*(arg.elts for arg in node.args[1:]))
        calls = [ast.Call(func, list(column), []) for column in columns]
        return ast.List(calls, ast.Load())

    def inlineBuiltinsMap(self, node):
        """Return the unrolled form of a ``builtins.map`` call, or ``node``.

        The rewrite applies only when ``node`` is a call to
        ``builtins.map``, with a single known function and one or more
        list/tuple literal arguments.
        """
        if not isinstance(node, ast.Call):
            return node

        # map needs a function plus at least one iterable. Keyword arguments
        # (e.g. ``strict=``) change its semantics, so those calls are left to
        # the real builtin.
        if len(node.args) < 2 or node.keywords:
            return node

        func_aliases = self.aliases[node.func]
        if len(func_aliases) != 1:
            return node

        obj = next(iter(func_aliases))
        if obj is not MODULES['builtins']['map']:
            return node

        if not all(isinstance(arg, (ast.List, ast.Tuple))
                   for arg in node.args[1:]):
            return node

        mapped_func_aliases = self.aliases[node.args[0]]
        if len(mapped_func_aliases) != 1:
            return node

        obj = next(iter(mapped_func_aliases))
        if not isinstance(obj, (ast.FunctionDef, FunctionIntr)):
            return node

        # all preconditions are met, do it!
        return self.inlineBuiltinsXMap(node)

    def visit_Call(self, node):
        node = self.generic_visit(node)
        node = self.inlineBuiltinsMap(node)
        return node

    def make_array_index(self, base, size, index):
        """Return the expression standing for element ``index`` of ``base``.

        ``base`` and ``size`` come from :meth:`fixedSizeArray`. A constant,
        or a single-element sequence, is broadcast. A fresh node is produced
        for each use so that no AST node is shared between two parents.
        """
        if isinstance(base, ast.Constant):
            return ast.Constant(base.value, None)
        if size == 1:
            return deepcopy(base.elts[0])
        return base.elts[index]

    @classmethod
    def _flat_sequence(cls, seq):
        """Return ``(seq, len(seq.elts))`` if every element looks scalar.

        Otherwise return ``(None, 0)``.
        """
        for elt in seq.elts:
            if isinstance(elt, cls._NON_SCALAR_ELTS):
                return None, 0
            # Python bool arithmetic (True + True == 2, ~True == -2) differs
            # from numpy boolean-array arithmetic.
            if isinstance(elt, ast.Constant) and isinstance(elt.value, bool):
                return None, 0
        return seq, len(seq.elts)

    def fixedSizeArray(self, node):
        """Return ``(base, size)`` describing a fixed-size operand.

        ``base`` is a Constant (size 1), or the List/Tuple literal holding
        the elements. ``(None, 0)`` is returned when ``node`` cannot be
        unrolled.
        """
        if isinstance(node, ast.Constant):
            return node, 1

        if isinstance(node, (ast.List, ast.Tuple)):
            return self._flat_sequence(node)

        if not isinstance(node, ast.Call):
            return None, 0

        func_aliases = self.aliases[node.func]
        if len(func_aliases) != 1:
            return None, 0

        obj = next(iter(func_aliases))
        if obj not in (MODULES['numpy']['array'], MODULES['numpy']['asarray']):
            return None, 0

        # Keyword arguments (dtype=, ndmin=, ...) would be silently dropped
        # by the unrolled ``numpy.array((...))``.
        if len(node.args) != 1 or node.keywords:
            return None, 0

        if isinstance(node.args[0], (ast.List, ast.Tuple)):
            return self._flat_sequence(node.args[0])

        return None, 0

    def _is_duplicable(self, expr):
        """Tell whether ``expr`` may be evaluated several times safely."""
        if isinstance(expr, (ast.Constant, ast.Name)):
            return True
        # NOTE(review): relies on PureExpressions yielding the set of
        # side-effect-free expression nodes. Nodes created by this pass are
        # not in it, which only makes the check more conservative.
        return expr in self.pure_expressions

    def _broadcast_size(self, node, lbase, lsize, rbase, rsize):
        """Return the unrolled length of a two-operand elementwise ``node``.

        Return None when the operation should be left untouched.

        Raises PythranSyntaxError when the sizes cannot be broadcast
        together.
        """
        if rsize != 1 and lsize != 1 and rsize != lsize:
            raise PythranSyntaxError("Invalid numpy broadcasting", node)

        # An empty operand has no element to index (make_array_index would
        # raise IndexError on e.g. ``numpy.array([]) + 1``).
        if not lsize or not rsize:
            return None

        size = max(lsize, rsize)

        # Broadcasting a single element duplicates its expression. Only do so
        # when evaluating it several times cannot be observed.
        for base, bsize in ((lbase, lsize), (rbase, rsize)):
            if bsize == 1 < size and not isinstance(base, ast.Constant):
                if not self._is_duplicable(base.elts[0]):
                    return None
        return size

    def _build_numpy_array(self, operands):
        """Build ``numpy.array((operands...))`` and register its alias."""
        res = ast.Call(path_to_attr(('numpy', 'array')),
                       [ast.Tuple(operands, ast.Load())],
                       [])
        self.aliases[res.func] = {path_to_node(('numpy', 'array'))}
        return res

    def inlineFixedSizeArrayBinOp(self, node):
        """Unroll an elementwise binary operation on fixed-size arrays."""
        # Matrix multiplication is a reduction, not an elementwise operation.
        if isinstance(node.op, ast.MatMult):
            return node

        alike = ast.List, ast.Tuple, ast.Constant
        if isinstance(node.left, alike) and isinstance(node.right, alike):
            return node

        lbase, lsize = self.fixedSizeArray(node.left)
        rbase, rsize = self.fixedSizeArray(node.right)
        if lbase is None or rbase is None:
            return node

        size = self._broadcast_size(node, lbase, lsize, rbase, rsize)
        if size is None:
            return node

        self.update = True

        operands = [ast.BinOp(self.make_array_index(lbase, lsize, i),
                              type(node.op)(),
                              self.make_array_index(rbase, rsize, i))
                    for i in range(size)]
        return self._build_numpy_array(operands)

    def visit_BinOp(self, node):
        node = self.generic_visit(node)
        node = self.inlineFixedSizeArrayBinOp(node)
        return node

    def inlineFixedSizeArrayUnaryOp(self, node):
        """Unroll an elementwise unary operation on a fixed-size array."""
        # ``not array`` is a truth test on the whole array, not elementwise.
        if isinstance(node.op, ast.Not):
            return node

        if isinstance(node.operand, (ast.Constant, ast.List, ast.Tuple)):
            return node

        base, size = self.fixedSizeArray(node.operand)
        if base is None:
            return node

        self.update = True

        operands = [ast.UnaryOp(type(node.op)(),
                                self.make_array_index(base, size, i))
                    for i in range(size)]
        return self._build_numpy_array(operands)

    def visit_UnaryOp(self, node):
        node = self.generic_visit(node)
        node = self.inlineFixedSizeArrayUnaryOp(node)
        return node

    def inlineFixedSizeArrayCompare(self, node):
        """Unroll a single elementwise comparison on fixed-size arrays."""
        if len(node.comparators) != 1:
            return node

        if not isinstance(node.ops[0], self._ELEMENTWISE_CMPOPS):
            return node

        node_right = node.comparators[0]

        alike = ast.Constant, ast.List, ast.Tuple
        if isinstance(node.left, alike) and isinstance(node_right, alike):
            return node

        lbase, lsize = self.fixedSizeArray(node.left)
        rbase, rsize = self.fixedSizeArray(node_right)
        if lbase is None or rbase is None:
            return node

        size = self._broadcast_size(node, lbase, lsize, rbase, rsize)
        if size is None:
            return node

        self.update = True

        cmpop = type(node.ops[0])
        operands = [ast.Compare(self.make_array_index(lbase, lsize, i),
                                [cmpop()],
                                [self.make_array_index(rbase, rsize, i)])
                    for i in range(size)]
        return self._build_numpy_array(operands)

    def visit_Compare(self, node):
        node = self.generic_visit(node)
        node = self.inlineFixedSizeArrayCompare(node)
        return node