1"""This module builds a LALR(1) transition-table for lalr_parser.py
2
3For now, shift/reduce conflicts are automatically resolved as shifts.
4"""

# Author: Erez Shinan (2017)
# Email : erezshin@gmail.com

from collections import defaultdict

from ..utils import classify, classify_bool, bfs, fzset, Enumerator, logger
from ..exceptions import GrammarError

from .grammar_analysis import GrammarAnalyzer, Terminal, LR0ItemSet
from ..grammar import Rule

###{standalone

class Action:
    def __init__(self, name):
        self.name = name
    def __str__(self):
        return self.name
    def __repr__(self):
        return str(self)

Shift = Action('Shift')
Reduce = Action('Reduce')


class ParseTable:
    def __init__(self, states, start_states, end_states):
        self.states = states
        self.start_states = start_states
        self.end_states = end_states

    def serialize(self, memo):
        tokens = Enumerator()
        rules = Enumerator()

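        # Each action is encoded as (0, next_state) for Shift or (1, serialized_rule) for Reduce.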
        states = {
            state: {tokens.get(token): ((1, arg.serialize(memo)) if action is Reduce else (0, arg))
                    for token, (action, arg) in actions.items()}
            for state, actions in self.states.items()
        }

        return {
            'tokens': tokens.reversed(),
            'states': states,
            'start_states': self.start_states,
            'end_states': self.end_states,
        }

    @classmethod
    def deserialize(cls, data, memo):
        tokens = data['tokens']
        states = {
            state: {tokens[token]: ((Reduce, Rule.deserialize(arg, memo)) if action==1 else (Shift, arg))
                    for token, (action, arg) in actions.items()}
            for state, actions in data['states'].items()
        }
        return cls(states, data['start_states'], data['end_states'])


class IntParseTable(ParseTable):

    @classmethod
    def from_ParseTable(cls, parse_table):
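        # Replace each state key (a set of LR(0) items) with a small integer index,
        # remapping Shift targets to the same indices; Reduce arguments are left untouched.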
        enum = list(parse_table.states)
        state_to_idx = {s:i for i,s in enumerate(enum)}
        int_states = {}

        for s, la in parse_table.states.items():
            la = {k:(v[0], state_to_idx[v[1]]) if v[0] is Shift else v
                  for k,v in la.items()}
            int_states[ state_to_idx[s] ] = la


        start_states = {start:state_to_idx[s] for start, s in parse_table.start_states.items()}
        end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()}
        return cls(int_states, start_states, end_states)

###}


# digraph and traverse, see The Theory and Practice of Compiler Writing

# computes F(x) = G(x) union (union { F(y) | x R y })
# X: nodes
# R: relation (function mapping node -> list of nodes that satisfy the relation)
# G: set valued function
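#
# A minimal illustration with made-up inputs (not taken from the grammar machinery):
#   X = ['a', 'b', 'c']
#   R = {'a': ['b'], 'b': ['c'], 'c': []}
#   G = {'a': {1}, 'b': {2}, 'c': {3}}
#   digraph(X, R, G)  ->  {'a': {1, 2, 3}, 'b': {2, 3}, 'c': {3}}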
def digraph(X, R, G):
    F = {}
    S = []
    N = {}
    for x in X:
        N[x] = 0
    for x in X:
        # this is always true for the first iteration, but N[x] may be updated in traverse below
        if N[x] == 0:
            traverse(x, S, N, X, R, G, F)
    return F

# x: single node
# S: stack
# N: weights
# X: nodes
# R: relation (see above)
# G: set valued function
# F: set valued function we are computing (map of input -> output)
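#
# Nodes that lie on a cycle of R (a strongly connected component) end up sharing
# a single F set; that is what the stack S and the weights N below are for.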
def traverse(x, S, N, X, R, G, F):
    S.append(x)
    d = len(S)
    N[x] = d
    F[x] = G[x]
    for y in R[x]:
        if N[y] == 0:
            traverse(y, S, N, X, R, G, F)
        n_x = N[x]
        assert(n_x > 0)
        n_y = N[y]
        assert(n_y != 0)
        if (n_y > 0) and (n_y < n_x):
            N[x] = n_y
        F[x].update(F[y])
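    # If x is still the lowest-numbered node it can reach, it is the root of a strongly
    # connected component: pop the whole component and give every member the same F set.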
    if N[x] == d:
        f_x = F[x]
        while True:
            z = S.pop()
            N[z] = -1
            F[z] = f_x
            if z == x:
                break


class LALR_Analyzer(GrammarAnalyzer):
    def __init__(self, parser_conf, debug=False):
        GrammarAnalyzer.__init__(self, parser_conf, debug)
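        # Relations in the style of DeRemer & Pennello's LALR(1) lookahead algorithm,
        # each keyed by a nonterminal transition (state, nonterminal):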
        self.nonterminal_transitions = []
        self.directly_reads = defaultdict(set)
        self.reads = defaultdict(set)
        self.includes = defaultdict(set)
        self.lookback = defaultdict(set)


    def compute_lr0_states(self):
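        # Build the LR(0) item-set automaton: advance the dot over each symbol and
        # explore the resulting states breadth-first from the start states.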
        self.lr0_states = set()
        # map of kernels to LR0ItemSets
        cache = {}

        def step(state):
            _, unsat = classify_bool(state.closure, lambda rp: rp.is_satisfied)

            d = classify(unsat, lambda rp: rp.next)
            for sym, rps in d.items():
                kernel = fzset({rp.advance(sym) for rp in rps})
                new_state = cache.get(kernel, None)
                if new_state is None:
                    closure = set(kernel)
                    for rp in kernel:
                        if not rp.is_satisfied and not rp.next.is_term:
                            closure |= self.expand_rule(rp.next, self.lr0_rules_by_origin)
                    new_state = LR0ItemSet(kernel, closure)
                    cache[kernel] = new_state

                state.transitions[sym] = new_state
                yield new_state

            self.lr0_states.add(state)

        for _ in bfs(self.lr0_start_states.values(), step):
            pass

    def compute_reads_relations(self):
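        # Compute directly_reads (terminals that can appear immediately after a nonterminal
        # transition) and the reads relation (continuing through nullable nonterminals).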
        # handle start state
        for root in self.lr0_start_states.values():
            assert(len(root.kernel) == 1)
            for rp in root.kernel:
                assert(rp.index == 0)
                self.directly_reads[(root, rp.next)] = set([ Terminal('$END') ])

        for state in self.lr0_states:
            seen = set()
            for rp in state.closure:
                if rp.is_satisfied:
                    continue
                s = rp.next
                # if s is not a nonterminal, skip it
                if s not in self.lr0_rules_by_origin:
                    continue
                if s in seen:
                    continue
                seen.add(s)
                nt = (state, s)
                self.nonterminal_transitions.append(nt)
                dr = self.directly_reads[nt]
                r = self.reads[nt]
                next_state = state.transitions[s]
                for rp2 in next_state.closure:
                    if rp2.is_satisfied:
                        continue
                    s2 = rp2.next
                    # if s2 is a terminal
                    if s2 not in self.lr0_rules_by_origin:
                        dr.add(s2)
                    if s2 in self.NULLABLE:
                        r.add((next_state, s2))

    def compute_includes_lookback(self):
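        # Compute the includes relation (a transition inside a rule body, followed only by
        # nullable symbols, includes the transition on the rule's origin) and the lookback
        # relation (each nonterminal transition collects the (final state, rule) pairs
        # whose reductions it can trigger).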
        for nt in self.nonterminal_transitions:
            state, nonterminal = nt
            includes = []
            lookback = self.lookback[nt]
            for rp in state.closure:
                if rp.rule.origin != nonterminal:
                    continue
                # traverse the states for rp(.rule)
                state2 = state
                for i in range(rp.index, len(rp.rule.expansion)):
                    s = rp.rule.expansion[i]
                    nt2 = (state2, s)
                    state2 = state2.transitions[s]
                    if nt2 not in self.reads:
                        continue
                    for j in range(i + 1, len(rp.rule.expansion)):
                        if not rp.rule.expansion[j] in self.NULLABLE:
                            break
                    else:
                        includes.append(nt2)
                # state2 is at the final state for rp.rule
                if rp.index == 0:
                    for rp2 in state2.closure:
                        if (rp2.rule == rp.rule) and rp2.is_satisfied:
                            lookback.add((state2, rp2.rule))
            for nt2 in includes:
                self.includes[nt2].add(nt)

    def compute_lookaheads(self):
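        # Two digraph passes: Read sets from the reads relation, then Follow sets from the
        # includes relation; lookback then hands each Follow set to the states that reduce.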
        read_sets = digraph(self.nonterminal_transitions, self.reads, self.directly_reads)
        follow_sets = digraph(self.nonterminal_transitions, self.includes, read_sets)

        for nt, lookbacks in self.lookback.items():
            for state, rule in lookbacks:
                for s in follow_sets[nt]:
                    state.lookaheads[s].add(rule)

    def compute_lalr1_states(self):
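        # Combine the LR(0) automaton with the computed lookaheads into the final action
        # table, resolving shift/reduce conflicts as shifts and reduce/reduce conflicts by
        # rule priority where possible (anything left over raises a GrammarError).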
        m = {}
        reduce_reduce = []
        for state in self.lr0_states:
            actions = {}
            for la, next_state in state.transitions.items():
                actions[la] = (Shift, next_state.closure)
            for la, rules in state.lookaheads.items():
                if len(rules) > 1:
                    # Try to resolve conflict based on priority
                    p = [(r.options.priority or 0, r) for r in rules]
                    p.sort(key=lambda r: r[0], reverse=True)
                    best, second_best = p[:2]
                    if best[0] > second_best[0]:
                        rules = [best[1]]
                    else:
                        reduce_reduce.append((state, la, rules))
                if la in actions:
                    if self.debug:
                        logger.warning('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.warning(' * %s', list(rules)[0])
                else:
                    actions[la] = (Reduce, list(rules)[0])
            m[state] = { k.name: v for k, v in actions.items() }

        if reduce_reduce:
            msgs = []
            for state, la, rules in reduce_reduce:
                msg = 'Reduce/Reduce collision in %s between the following rules: %s' % (la, ''.join([ '\n\t- ' + str(r) for r in rules ]))
                if self.debug:
                    msg += '\n    collision occurred in state: {%s\n    }' % ''.join(['\n\t' + str(x) for x in state.closure])
                msgs.append(msg)
            raise GrammarError('\n\n'.join(msgs))

        states = { k.closure: v for k, v in m.items() }

        # compute end states
        end_states = {}
        for state in states:
            for rp in state:
                for start in self.lr0_start_states:
                    if rp.rule.origin.name == ('$root_' + start) and rp.is_satisfied:
                        assert(not start in end_states)
                        end_states[start] = state

        _parse_table = ParseTable(states, { start: state.closure for start, state in self.lr0_start_states.items() }, end_states)

        if self.debug:
            self.parse_table = _parse_table
        else:
            self.parse_table = IntParseTable.from_ParseTable(_parse_table)

    def compute_lalr(self):
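        # Full construction pipeline: LR(0) automaton, reads/includes/lookback relations,
        # lookahead sets, and finally the parse table.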
        self.compute_lr0_states()
        self.compute_reads_relations()
        self.compute_includes_lookback()
        self.compute_lookaheads()
        self.compute_lalr1_states()