1""" Expand some builtins implementation when it is profitable."""
2
3from pythran.analyses import Aliases
4from pythran.analyses.pure_expressions import PureExpressions
5from pythran.passmanager import Transformation
6from pythran.tables import MODULES
7from pythran.intrinsic import FunctionIntr
8from pythran.utils import path_to_attr, path_to_node
9from pythran.syntax import PythranSyntaxError
10
11from copy import deepcopy
12import gast as ast
13
14
class InlineBuiltins(Transformation):

    """
    Replace some builtins by their bodies.

    This may trigger some extra optimizations later on!

    Two rewrites are performed:

    * ``builtins.map(f, s0, s1, ...)``, where ``f`` resolves to a single
      function and every ``sN`` is a list/tuple literal, becomes a list of
      direct calls to ``f``;
    * element-wise unary, binary and comparison operators whose operands are
      fixed-size ``numpy.array``/``numpy.asarray`` literals (possibly mixed
      with scalar constants or list/tuple literals) are distributed over the
      elements, producing a new ``numpy.array`` of per-element results.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> pm = passmanager.PassManager("test")
    >>> node = ast.parse('''
    ... def foo(a):
    ...     return  a + 1
    ... def bar(b):
    ...     return builtins.map(bar, (1, 2))''')
    >>> _, node = pm.apply(InlineBuiltins, node)
    >>> print(pm.dump(backend.Python, node))
    def foo(a):
        return (a + 1)
    def bar(b):
        return [bar(1), bar(2)]
    """

    # Operators numpy does *not* apply element-wise to 1-D arrays, so they
    # must not be distributed over the elements of a fixed-size literal:
    # ``a @ b`` is an inner product, ``not a`` and ``x in a`` yield a single
    # truth value (or raise), and ``is``/``is not`` test object identity.
    _NON_ELEMENTWISE_BINOPS = (ast.MatMult,)
    _NON_ELEMENTWISE_UNARYOPS = (ast.Not,)
    _NON_ELEMENTWISE_CMPOPS = (ast.Is, ast.IsNot, ast.In, ast.NotIn)

    def __init__(self):
        Transformation.__init__(self, Aliases, PureExpressions)

    @staticmethod
    def _has_starred(seq):
        """Tell whether list/tuple literal `seq` contains a ``*x`` element.

        The static element count of such a literal is not its runtime length.
        """
        return any(isinstance(elt, ast.Starred) for elt in seq.elts)

    def inlineBuiltinsXMap(self, node):
        """Expand ``map(f, s0, s1, ...)`` into ``[f(s0[0], s1[0], ...), ...]``.

        `node` must satisfy the preconditions checked by `inlineBuiltinsMap`
        (at least one iterable, every iterable a list/tuple literal).
        """
        self.update = True

        func = node.args[0]
        iterables = node.args[1:]
        # Like the builtin, stop at the shortest iterable.
        nelts = min(len(seq.elts) for seq in iterables)
        return ast.List([ast.Call(func, [seq.elts[i] for seq in iterables], [])
                         for i in range(nelts)],
                        ast.Load())

    def inlineBuiltinsMap(self, node):
        """Inline a ``builtins.map`` call when it is safe to do so.

        The rewrite only happens when the call resolves to the builtin
        ``map``, passes a function and at least one iterable, has no keyword
        or starred arguments, every iterable is a list/tuple literal without
        starred elements, and the mapped function resolves to exactly one
        function definition or intrinsic.

        Returns the expanded list, or `node` unchanged otherwise.
        """
        if not isinstance(node, ast.Call):
            return node

        func_aliases = self.aliases[node.func]
        if len(func_aliases) != 1:
            return node

        obj = next(iter(func_aliases))
        if obj is not MODULES['builtins']['map']:
            return node

        # A call without any iterable is invalid: leave it for regular error
        # reporting instead of crashing in the expansion (min() over nothing).
        # Keyword arguments (e.g. ``strict``) would be silently dropped.
        if len(node.args) < 2 or node.keywords:
            return node

        mapped_func, iterables = node.args[0], node.args[1:]
        if isinstance(mapped_func, ast.Starred):
            return node

        if not all(isinstance(arg, (ast.List, ast.Tuple))
                   and not self._has_starred(arg)
                   for arg in iterables):
            return node

        mapped_func_aliases = self.aliases[mapped_func]
        if len(mapped_func_aliases) != 1:
            return node

        obj = next(iter(mapped_func_aliases))
        if not isinstance(obj, (ast.FunctionDef, FunctionIntr)):
            return node

        # all preconditions are met, do it!
        return self.inlineBuiltinsXMap(node)

    def visit_Call(self, node):
        node = self.generic_visit(node)
        node = self.inlineBuiltinsMap(node)
        return node

    def make_array_index(self, base, size, index):
        """Return the expression for element `index` of operand `base`.

        `base` and `size` come from `fixedSizeArray`. A scalar constant is
        broadcast by building a fresh constant node for each index; a
        one-element literal is broadcast by deep-copying its only element so
        the same node is not inserted at several places; otherwise the
        element node itself is returned.
        """
        if isinstance(base, ast.Constant):
            return ast.Constant(base.value, None)
        if size == 1:
            return deepcopy(base.elts[0])
        return base.elts[index]

    @staticmethod
    def _is_flat_sequence(seq):
        """Tell whether no element of list/tuple literal `seq` is itself a
        list/tuple literal or a ``*x`` element."""
        return not any(isinstance(elt, (ast.List, ast.Tuple, ast.Starred))
                       for elt in seq.elts)

    def _numpy_array_argument(self, node):
        """Return the only argument of ``numpy.array(x)``/``numpy.asarray(x)``.

        Returns None when `node` is not such a call, or when it passes extra
        positional or keyword arguments (e.g. ``dtype=``) that the
        element-wise rewrite would otherwise silently drop.
        """
        func_aliases = self.aliases[node.func]
        if len(func_aliases) != 1:
            return None

        obj = next(iter(func_aliases))
        if obj not in (MODULES['numpy']['array'], MODULES['numpy']['asarray']):
            return None

        if len(node.args) != 1 or node.keywords:
            return None

        return node.args[0]

    def fixedSizeArray(self, node):
        """Recognize an operand whose elements are statically known.

        Accepted operands are a constant, a list/tuple literal, and
        ``numpy.array(lit)``/``numpy.asarray(lit)`` with a single literal
        argument and no keyword argument.

        Returns ``(base, size)``: a scalar ``ast.Constant`` with size 1 (to be
        broadcast) or the flat list/tuple literal holding the elements with
        its length; ``(None, 0)`` when `node` is not such an operand.
        """
        if isinstance(node, ast.Constant):
            return node, 1

        literal = node
        if isinstance(node, ast.Call):
            literal = self._numpy_array_argument(node)

        # Nested literals describe multi-dimensional data and starred elements
        # hide the real length: distributing an operator over such elements
        # would be wrong (e.g. ``[1, 2] + [3, 4]`` concatenates lists).
        if (isinstance(literal, (ast.List, ast.Tuple))
                and self._is_flat_sequence(literal)):
            return literal, len(literal.elts)

        return None, 0

    @staticmethod
    def _broadcast_size(lsize, rsize, node):
        """Return the length of the broadcast of 1-D sizes `lsize`, `rsize`.

        A size of 1 broadcasts to the other size, so ``(0, 1)`` yields 0.

        Raises PythranSyntaxError when the sizes are incompatible.
        """
        if lsize == 1:
            return rsize
        if rsize == 1 or lsize == rsize:
            return lsize
        raise PythranSyntaxError("Invalid numpy broadcasting", node)

    def _make_numpy_array(self, operands):
        """Build ``numpy.array((op0, op1, ...))`` from `operands`.

        The alias of the new ``numpy.array`` reference is registered so that
        enclosing operators visited later in this pass recognize the result
        through `fixedSizeArray`.
        """
        res = ast.Call(path_to_attr(('numpy', 'array')),
                       [ast.Tuple(operands, ast.Load())],
                       [])
        self.aliases[res.func] = {path_to_node(('numpy', 'array'))}
        return res

    def inlineFixedSizeArrayBinOp(self, node):
        """Distribute an element-wise binary operator over fixed-size operands.

        ``numpy.array([a, b]) + c`` becomes ``numpy.array((a + c, b + c))``.
        Returns `node` unchanged when the rewrite does not apply.

        Raises PythranSyntaxError when the operand sizes cannot be broadcast.
        """
        if isinstance(node.op, self._NON_ELEMENTWISE_BINOPS):
            return node

        # Pure Python literals (e.g. ``(1, 2) + (3,)``) keep Python semantics.
        alike = ast.List, ast.Tuple, ast.Constant
        if isinstance(node.left, alike) and isinstance(node.right, alike):
            return node

        lbase, lsize = self.fixedSizeArray(node.left)
        rbase, rsize = self.fixedSizeArray(node.right)
        if lbase is None or rbase is None:
            return node

        size = self._broadcast_size(lsize, rsize, node)
        if size == 0:
            # Nothing to distribute over an empty result.
            return node

        self.update = True

        operands = [ast.BinOp(self.make_array_index(lbase, lsize, i),
                              type(node.op)(),
                              self.make_array_index(rbase, rsize, i))
                    for i in range(size)]
        return self._make_numpy_array(operands)

    def visit_BinOp(self, node):
        node = self.generic_visit(node)
        node = self.inlineFixedSizeArrayBinOp(node)
        return node

    def inlineFixedSizeArrayUnaryOp(self, node):
        """Distribute an element-wise unary operator over a fixed-size operand.

        ``-numpy.array([a, b])`` becomes ``numpy.array((-a, -b))``.
        Returns `node` unchanged when the rewrite does not apply.
        """
        if isinstance(node.op, self._NON_ELEMENTWISE_UNARYOPS):
            return node

        # Pure Python operands keep Python semantics.
        if isinstance(node.operand, (ast.Constant, ast.List, ast.Tuple)):
            return node

        base, size = self.fixedSizeArray(node.operand)
        if base is None or size == 0:
            return node

        self.update = True

        operands = [ast.UnaryOp(type(node.op)(),
                                self.make_array_index(base, size, i))
                    for i in range(size)]
        return self._make_numpy_array(operands)

    def visit_UnaryOp(self, node):
        node = self.generic_visit(node)
        node = self.inlineFixedSizeArrayUnaryOp(node)
        return node

    def inlineFixedSizeArrayCompare(self, node):
        """Distribute an element-wise comparison over fixed-size operands.

        Only single comparisons are handled: ``numpy.array([a, b]) < c``
        becomes ``numpy.array((a < c, b < c))``. Membership and identity
        tests are left alone. Returns `node` unchanged when the rewrite does
        not apply.

        Raises PythranSyntaxError when the operand sizes cannot be broadcast.
        """
        if len(node.comparators) != 1:
            return node

        if isinstance(node.ops[0], self._NON_ELEMENTWISE_CMPOPS):
            return node

        node_right = node.comparators[0]

        # Pure Python literals (e.g. ``(1, 2) == 1``) keep Python semantics.
        alike = ast.Constant, ast.List, ast.Tuple
        if isinstance(node.left, alike) and isinstance(node_right, alike):
            return node

        lbase, lsize = self.fixedSizeArray(node.left)
        rbase, rsize = self.fixedSizeArray(node_right)
        if lbase is None or rbase is None:
            return node

        size = self._broadcast_size(lsize, rsize, node)
        if size == 0:
            # Nothing to distribute over an empty result.
            return node

        self.update = True

        operands = [ast.Compare(self.make_array_index(lbase, lsize, i),
                                [type(node.ops[0])()],
                                [self.make_array_index(rbase, rsize, i)])
                    for i in range(size)]
        return self._make_numpy_array(operands)

    def visit_Compare(self, node):
        node = self.generic_visit(node)
        node = self.inlineFixedSizeArrayCompare(node)
        return node
207