1"""
2Replace variable that can be lazy evaluated and used only once by their full
3computation code.
4"""
5
6from pythran.analyses import LazynessAnalysis, UseDefChains, DefUseChains
7from pythran.analyses import Literals, Ancestors, Identifiers, CFG, IsAssigned
8from pythran.passmanager import Transformation
9import pythran.graph as graph
10
11from collections import defaultdict
12import gast as ast
13
14try:
15    from math import isfinite
16except ImportError:
17    from math import isinf, isnan
18
19    def isfinite(x):
20        return not isinf(x) and not isnan(x)
21
22
23class Remover(ast.NodeTransformer):
24
25    def __init__(self, nodes):
26        self.nodes = nodes
27
28    def visit_Assign(self, node):
29        if node in self.nodes:
30            to_prune = self.nodes[node]
31            node.targets = [tgt for tgt in node.targets if tgt not in to_prune]
32            if node.targets:
33                return node
34            else:
35                return ast.Pass()
36        return node
37
38
39class ForwardSubstitution(Transformation):
40
41    """
42    Replace variable that can be computed later.
43
44    >>> import gast as ast
45    >>> from pythran import passmanager, backend
46    >>> pm = passmanager.PassManager("test")
47
48    >>> node = ast.parse("def foo(): a = [2, 3]; builtins.print(a)")
49    >>> _, node = pm.apply(ForwardSubstitution, node)
50    >>> print(pm.dump(backend.Python, node))
51    def foo():
52        pass
53        builtins.print([2, 3])
54
55    >>> node = ast.parse("def foo(): a = 2; builtins.print(a + a)")
56    >>> _, node = pm.apply(ForwardSubstitution, node)
57    >>> print(pm.dump(backend.Python, node))
58    def foo():
59        a = 2
60        builtins.print((2 + 2))
61
62    >>> node = ast.parse("def foo():\\n a=b=2\\n while a: a -= 1\\n return b")
63    >>> _, node = pm.apply(ForwardSubstitution, node)
64    >>> print(pm.dump(backend.Python, node))
65    def foo():
66        a = 2
67        while a:
68            a -= 1
69        return 2
70    """
71
72    def __init__(self):
73        """ Satisfy dependencies on others analyses. """
74        super(ForwardSubstitution, self).__init__(LazynessAnalysis,
75                                                  UseDefChains,
76                                                  DefUseChains,
77                                                  Ancestors,
78                                                  CFG,
79                                                  Literals)
80        self.to_remove = None
81
82    def visit_FunctionDef(self, node):
83        self.to_remove = defaultdict(list)
84        self.locals = self.def_use_chains.locals[node]
85
86        # prune some assignment as a second phase, as an assignment could be
87        # forward-substituted several times (in the case of constants)
88        self.generic_visit(node)
89        Remover(self.to_remove).visit(node)
90        return node
91
92    def visit_Name(self, node):
93        if not isinstance(node.ctx, ast.Load):
94            return node
95
96        # OpenMP metdata are not handled by beniget, which is fine in our case
97        if node not in self.use_def_chains:
98            if __debug__:
99                from pythran.openmp import OMPDirective
100                assert any(isinstance(p, OMPDirective)
101                           for p in self.ancestors[node])
102            return node
103        defuses = self.use_def_chains[node]
104
105        if len(defuses) != 1:
106            return node
107
108        defuse = defuses[0]
109
110        dnode = defuse.node
111        if not isinstance(dnode, ast.Name):
112            return node
113
114        # multiple definition, which one should we forward?
115        if sum(1 for d in self.locals if d.name() == dnode.id) > 1:
116            return node
117
118        # either a constant or a value
119        fwd = (dnode in self.literals and
120               isfinite(self.lazyness_analysis[dnode.id]))
121        fwd |= self.lazyness_analysis[dnode.id] == 1
122
123        if not fwd:
124            return node
125
126        parent = self.ancestors[dnode][-1]
127        if isinstance(parent, ast.Assign):
128            value = parent.value
129            if dnode in self.literals:
130                self.update = True
131                if len(defuse.users()) == 1:
132                    self.to_remove[parent].append(dnode)
133                    return value
134                else:
135                    # FIXME: deepcopy here creates an unknown node
136                    # for alias computations
137                    return value
138            elif len(parent.targets) == 1:
139                ids = self.gather(Identifiers, value)
140                node_stmt = next(reversed([s for s in self.ancestors[node]
141                                 if isinstance(s, ast.stmt)]))
142                all_paths = graph.all_simple_paths(self.cfg, parent, node_stmt)
143                for path in all_paths:
144                    for stmt in path[1:-1]:
145                        assigned_ids = {n.id
146                                        for n in self.gather(IsAssigned, stmt)}
147                        if not ids.isdisjoint(assigned_ids):
148                            break
149                    else:
150                        continue
151                    break
152                else:
153                    self.update = True
154                    self.to_remove[parent].append(dnode)
155                    return value
156
157        return node
158