1from collections import Counter, defaultdict
2
3from ..utils import bfs, fzset, classify
4from ..exceptions import GrammarError
5from ..grammar import Rule, Terminal, NonTerminal
6
7
class RulePtr(object):
    """A position (the "dot") inside the expansion of a grammar rule.

    Attributes:
        rule: the ``Rule`` this pointer refers to.
        index: how many symbols of ``rule.expansion`` lie before the dot;
            ranges from 0 to ``len(rule.expansion)`` inclusive.

    Instances compare and hash by ``(rule, index)``, so they can be used as
    set members and dict keys.
    """
    __slots__ = ('rule', 'index')

    def __init__(self, rule, index):
        # Internal invariants: a real Rule, and a dot that is not past the end.
        assert isinstance(rule, Rule)
        assert index <= len(rule.expansion)
        self.rule = rule
        self.index = index

    def __repr__(self):
        # Rendered as "<origin : consumed symbols * remaining symbols>".
        before = [x.name for x in self.rule.expansion[:self.index]]
        after = [x.name for x in self.rule.expansion[self.index:]]
        return '<%s : %s * %s>' % (self.rule.origin.name, ' '.join(before), ' '.join(after))

    @property
    def next(self):
        """The symbol immediately after the dot.

        Indexes ``rule.expansion`` directly, so it must not be read once the
        pointer is satisfied (there is no symbol left).
        """
        return self.rule.expansion[self.index]

    def advance(self, sym):
        """Return a new pointer with the dot moved past ``sym``.

        ``sym`` must equal the symbol currently after the dot (asserted).
        """
        assert self.next == sym
        return RulePtr(self.rule, self.index+1)

    @property
    def is_satisfied(self):
        """True when the dot is at the end of the expansion."""
        return self.index == len(self.rule.expansion)

    def __eq__(self, other):
        # Defer to Python's fallback for foreign types instead of failing with
        # AttributeError on `other.rule` (e.g. `ptr == None`).
        if not isinstance(other, RulePtr):
            return NotImplemented
        return self.rule == other.rule and self.index == other.index

    def __hash__(self):
        return hash((self.rule, self.index))
38
39
40# state generation ensures no duplicate LR0ItemSets
class LR0ItemSet(object):
    """An LR(0) state: its kernel items, its closure, and per-state tables.

    Attributes:
        kernel: the kernel items, passed through ``fzset``.
        closure: the closure items, passed through ``fzset``.
        transitions: dict, created empty here and filled in by other code.
        lookaheads: ``defaultdict(set)``, created empty here and filled in by
            other code.
    """
    __slots__ = ('kernel', 'closure', 'transitions', 'lookaheads')

    def __init__(self, kernel, closure):
        self.kernel = fzset(kernel)
        self.closure = fzset(closure)
        # Both tables start empty.
        self.lookaheads = defaultdict(set)
        self.transitions = {}

    def __repr__(self):
        # Rendered as "{kernel items | closure items}".
        kernel_text = ', '.join(map(repr, self.kernel))
        closure_text = ', '.join(map(repr, self.closure))
        return '{%s | %s}' % (kernel_text, closure_text)
52
53
def update_set(set1, set2):
    """Merge ``set2`` into ``set1`` in place and report whether ``set1`` grew.

    Returns False when ``set2`` is empty/falsy or already contained in
    ``set1``; otherwise adds the missing elements to ``set1`` and returns True.
    """
    if not set2:
        return False
    if set1 >= set2:
        # Nothing new to add.
        return False

    # At least one element of set2 is missing from set1, so the union grows it.
    set1 |= set2
    return True
61
def calculate_sets(rules):
    """Calculate the FIRST, FOLLOW and NULLABLE sets for a collection of rules.

    Each rule must expose ``origin`` and ``expansion``; each symbol must
    expose ``is_term``. ``rules`` is iterated several times.

    Returns a tuple ``(FIRST, FOLLOW, NULLABLE)``: FIRST and FOLLOW are dicts
    mapping every symbol to a set of symbols, and NULLABLE is the set of
    origins that can derive the empty sequence.

    Adapted from: http://lara.epfl.ch/w/cc09:algorithm_for_first_and_follow_sets"""
    expansion_symbols = {sym for rule in rules for sym in rule.expansion}
    origin_symbols = {rule.origin for rule in rules}
    symbols = expansion_symbols | origin_symbols

    # foreach grammar rule X ::= Y(1) ... Y(k)
    # if k=0 or {Y(1),...,Y(k)} subset of NULLABLE then
    #   NULLABLE = NULLABLE union {X}
    # for i = 1 to k
    #   if i=1 or {Y(1),...,Y(i-1)} subset of NULLABLE then
    #     FIRST(X) = FIRST(X) union FIRST(Y(i))
    #   for j = i+1 to k
    #     if i=k or {Y(i+1),...Y(k)} subset of NULLABLE then
    #       FOLLOW(Y(i)) = FOLLOW(Y(i)) union FOLLOW(X)
    #     if i+1=j or {Y(i+1),...,Y(j-1)} subset of NULLABLE then
    #       FOLLOW(Y(i)) = FOLLOW(Y(i)) union FIRST(Y(j))
    # until none of NULLABLE,FIRST,FOLLOW changed in last iteration

    nullable = set()
    first = {}
    follow = {}
    for sym in symbols:
        # A terminal's FIRST set is just itself.
        first[sym] = {sym} if sym.is_term else set()
        follow[sym] = set()

    # Pass 1: iterate NULLABLE and FIRST to a fixed point.
    progress = True
    while progress:
        progress = False

        for rule in rules:
            origin = rule.origin
            expansion = rule.expansion

            if nullable.issuperset(expansion):
                progress = update_set(nullable, {origin}) or progress

            for pos, sym in enumerate(expansion):
                if not nullable.issuperset(expansion[:pos]):
                    # A non-nullable prefix blocks every later symbol too.
                    break
                progress = update_set(first[origin], first[sym]) or progress

    # Pass 2: iterate FOLLOW to a fixed point (NULLABLE and FIRST are final).
    progress = True
    while progress:
        progress = False

        for rule in rules:
            expansion = rule.expansion
            for pos, sym in enumerate(expansion):
                tail = expansion[pos + 1:]

                # Everything after sym can vanish (or nothing follows it):
                # whatever follows the origin also follows sym.
                if nullable.issuperset(tail):
                    progress = update_set(follow[sym], follow[rule.origin]) or progress

                # Each later symbol reachable across a nullable gap
                # contributes its FIRST set.
                for gap, later in enumerate(tail):
                    if nullable.issuperset(tail[:gap]):
                        progress = update_set(follow[sym], first[later]) or progress

    return first, follow, nullable
122
123
class GrammarAnalyzer(object):
    """Validates a grammar and precomputes the tables built in ``__init__``.

    Attributes set by ``__init__``:
        debug: the ``debug`` flag as given.
        rules_by_origin: result of ``classify`` over the grammar rules plus
            one ``$root_<start> -> <start> $END`` rule per start symbol;
            indexed by origin symbol in ``expand_rule``.
        start_states / end_states: per start symbol, the expanded initial
            items of its root rule and the single fully-advanced root item.
        lr0_rules_by_origin / lr0_start_states: the same construction with
            root rules that omit ``$END``, wrapped in ``LR0ItemSet``.
        FIRST, FOLLOW, NULLABLE: result of ``calculate_sets`` over the rules
            that include the ``$END`` root rules.

    Raises:
        GrammarError: if a rule occurs more than once, or an expansion uses a
            non-terminal that has no rule.
    """

    def __init__(self, parser_conf, debug=False):
        self.debug = debug

        def origin_of(rule):
            return rule.origin

        # One augmented root rule per start symbol: $root_X -> X $END
        root_rules = {}
        for start in parser_conf.start:
            root_rules[start] = Rule(NonTerminal('$root_' + start),
                                     [NonTerminal(start), Terminal('$END')])

        rules = parser_conf.rules + list(root_rules.values())
        self.rules_by_origin = classify(rules, origin_of)

        if len(set(rules)) != len(rules):
            occurrences = Counter(rules)
            duplicates = [rule for rule in occurrences if occurrences[rule] > 1]
            raise GrammarError("Rules defined twice: %s" % ', '.join(map(str, duplicates)))

        # Every non-terminal used in an expansion must have at least one rule.
        for rule in rules:
            for sym in rule.expansion:
                if sym.is_term:
                    continue
                if sym not in self.rules_by_origin:
                    raise GrammarError("Using an undefined rule: %s" % sym)

        start_states = {}
        for start, root_rule in root_rules.items():
            start_states[start] = self.expand_rule(root_rule.origin)
        self.start_states = start_states

        end_states = {}
        for start, root_rule in root_rules.items():
            completed = RulePtr(root_rule, len(root_rule.expansion))
            end_states[start] = fzset({completed})
        self.end_states = end_states

        # LR(0) variant: root rules without the trailing $END terminal.
        lr0_root_rules = {}
        for start in parser_conf.start:
            lr0_root_rules[start] = Rule(NonTerminal('$root_' + start), [NonTerminal(start)])

        lr0_rules = parser_conf.rules + list(lr0_root_rules.values())
        assert len(set(lr0_rules)) == len(lr0_rules)

        self.lr0_rules_by_origin = classify(lr0_rules, origin_of)

        lr0_start_states = {}
        for start, root_rule in lr0_root_rules.items():
            kernel = [RulePtr(root_rule, 0)]
            closure = self.expand_rule(root_rule.origin, self.lr0_rules_by_origin)
            lr0_start_states[start] = LR0ItemSet(kernel, closure)
        self.lr0_start_states = lr0_start_states

        self.FIRST, self.FOLLOW, self.NULLABLE = calculate_sets(rules)

    def expand_rule(self, source_rule, rules_by_origin=None):
        """Collect the initial (index 0) RulePtr of every rule reachable from ``source_rule``.

        ``source_rule`` is a non-terminal symbol. A rule is reachable when its
        origin is ``source_rule`` or is the leading non-terminal of an
        already reachable rule. ``rules_by_origin`` defaults to
        ``self.rules_by_origin``. Returns the pointers wrapped by ``fzset``.
        """
        if rules_by_origin is None:
            rules_by_origin = self.rules_by_origin

        collected = set()

        def _initial_items(nonterm):
            # Records the initial pointer of each rule of `nonterm` and yields
            # the leading non-terminals for `bfs` to visit next.
            assert not nonterm.is_term, nonterm

            for candidate in rules_by_origin[nonterm]:
                ptr = RulePtr(candidate, 0)
                collected.add(ptr)

                if not candidate.expansion:
                    # Empty rule: nothing follows the dot.
                    continue

                upcoming = ptr.next
                if not upcoming.is_term:
                    yield upcoming

        # Drain the traversal; only the side effect on `collected` matters.
        for _ in bfs([source_rule], _initial_items):
            pass

        return fzset(collected)
186