1import errno
2import hashlib
3import json
4import os
5import sys
6import tempfile
7import warnings
8
9from appdirs import AppDirs
10
11from rply.errors import ParserGeneratorError, ParserGeneratorWarning
12from rply.grammar import Grammar
13from rply.parser import LRParser
14from rply.utils import Counter, IdentityDict, iteritems, itervalues
15
16
# Sentinel used by the digraph/traverse walk below to mark nodes whose
# computation has finished (any value larger than every possible DFS depth).
LARGE_VALUE = sys.maxsize
18
19
class ParserGenerator(object):
    """
    A ParserGenerator represents a set of production rules, that define a
    sequence of terminals and non-terminals to be replaced with a non-terminal,
    which can be turned into a parser.

    :param tokens: A list of token (non-terminal) names.
    :param precedence: A list of tuples defining the order of operation for
                       avoiding ambiguity, consisting of a string defining
                       associativity (left, right or nonassoc) and a list of
                       token names with the same associativity and level of
                       precedence.
    :param cache_id: A string specifying an ID for caching.
    """
    # Embedded in the cache file name; bump whenever the serialized table
    # layout changes so incompatible caches are simply never found.
    VERSION = 1

    def __init__(self, tokens, precedence=None, cache_id=None):
        self.tokens = tokens
        self.productions = []
        # Use a None sentinel instead of the previous mutable default
        # (``precedence=[]``), which was a single list object shared by
        # every ParserGenerator that did not pass its own precedence.
        self.precedence = [] if precedence is None else precedence
        self.cache_id = cache_id
        self.error_handler = None

    def production(self, rule, precedence=None):
        """
        A decorator that defines a production rule and registers the decorated
        function to be called with the terminals and non-terminals matched by
        that rule.

        A `rule` should consist of a name defining the non-terminal returned
        by the decorated function and a sequence of non-terminals and terminals
        that are supposed to be replaced::

            replacing_non_terminal : ATERMINAL non_terminal

        The name of the non-terminal replacing the sequence is on the left,
        separated from the sequence by a colon. The whitespace around the colon
        is required.

        Knowing this we can define productions::

            pg = ParserGenerator(['NUMBER', 'ADD'])

            @pg.production('number : NUMBER')
            def expr_number(p):
                return BoxInt(int(p[0].getstr()))

            @pg.production('expr : number ADD number')
            def expr_add(p):
                return BoxInt(p[0].getint() + p[2].getint())

        If a state was passed to the parser, the decorated function is
        additionally called with that state as first argument.
        """
        parts = rule.split()
        production_name = parts[0]
        if parts[1] != ":":
            raise ParserGeneratorError("Expecting :")
        syms = parts[2:]

        def inner(func):
            # Registration is deferred until build(); we only record the
            # (name, symbols, handler, precedence) tuple here.
            self.productions.append((production_name, syms, func, precedence))
            return func
        return inner

    def error(self, func):
        """
        Sets the error handler that is called with the state (if passed to the
        parser) and the token the parser errored on.

        Currently error handlers must raise an exception. If an error handler
        is not defined, a :exc:`rply.ParsingError` will be raised.
        """
        self.error_handler = func
        return func

    def compute_grammar_hash(self, g):
        """
        Return a hex digest that identifies the grammar (start symbol,
        terminals, precedence and productions); used as part of the cache
        file name so a changed grammar never reads a stale cache.
        """
        hasher = hashlib.sha1()
        hasher.update(g.start.encode())
        hasher.update(json.dumps(sorted(g.terminals)).encode())
        for term, (assoc, level) in sorted(iteritems(g.precedence)):
            hasher.update(term.encode())
            hasher.update(assoc.encode())
            # NOTE: on Python 3, ``bytes(level)`` yields ``level`` zero
            # bytes, so the level is encoded by length; kept as-is so
            # existing cache file names remain valid.
            hasher.update(bytes(level))
        for p in g.productions:
            hasher.update(p.name.encode())
            hasher.update(json.dumps(p.prec).encode())
            hasher.update(json.dumps(p.prod).encode())
        return hasher.hexdigest()

    def serialize_table(self, table):
        """Return a JSON-serializable dict capturing an LRTable plus the
        grammar facts needed to validate the cache on reload."""
        return {
            "lr_action": table.lr_action,
            "lr_goto": table.lr_goto,
            "sr_conflicts": table.sr_conflicts,
            "rr_conflicts": table.rr_conflicts,
            "default_reductions": table.default_reductions,
            "start": table.grammar.start,
            "terminals": sorted(table.grammar.terminals),
            "precedence": table.grammar.precedence,
            "productions": [
                (p.name, p.prod, p.prec) for p in table.grammar.productions
            ],
        }

    def data_is_valid(self, g, data):
        """Check whether cached ``data`` exactly matches grammar ``g``.

        This is a belt-and-braces check on top of the grammar hash in the
        cache file name (hash collisions, hand-edited caches).
        """
        if g.start != data["start"]:
            return False
        if sorted(g.terminals) != data["terminals"]:
            return False
        if sorted(g.precedence) != sorted(data["precedence"]):
            return False
        for key, (assoc, level) in iteritems(g.precedence):
            # JSON has no tuples, so cached pairs come back as lists.
            if data["precedence"][key] != [assoc, level]:
                return False
        if len(g.productions) != len(data["productions"]):
            return False
        for p, (name, prod, (assoc, level)) in zip(g.productions, data["productions"]):
            if p.name != name:
                return False
            if p.prod != prod:
                return False
            if p.prec != (assoc, level):
                return False
        return True

    def build(self):
        """
        Build and return an :class:`rply.LRParser` from the registered
        productions, warning about unused tokens/productions and any
        shift/reduce or reduce/reduce conflicts.  If ``cache_id`` was
        given, the LR table is loaded from / written to a per-user cache.
        """
        g = Grammar(self.tokens)

        # Precedence levels are 1-based; level 0 means "no precedence".
        for level, (assoc, terms) in enumerate(self.precedence, 1):
            for term in terms:
                g.set_precedence(term, assoc, level)

        for prod_name, syms, func, precedence in self.productions:
            g.add_production(prod_name, syms, func, precedence)

        g.set_start()

        for unused_term in g.unused_terminals():
            warnings.warn(
                "Token %r is unused" % unused_term,
                ParserGeneratorWarning,
                stacklevel=2
            )
        for unused_prod in g.unused_productions():
            warnings.warn(
                "Production %r is not reachable" % unused_prod,
                ParserGeneratorWarning,
                stacklevel=2
            )

        g.build_lritems()
        g.compute_first()
        g.compute_follow()

        table = None
        if self.cache_id is not None:
            cache_dir = AppDirs("rply").user_cache_dir
            cache_file = os.path.join(
                cache_dir,
                "%s-%s-%s.json" % (
                    self.cache_id, self.VERSION, self.compute_grammar_hash(g)
                )
            )

            if os.path.exists(cache_file):
                with open(cache_file) as f:
                    data = json.load(f)
                # Validate before trusting: the hash in the file name is
                # not proof the contents match this grammar.
                if self.data_is_valid(g, data):
                    table = LRTable.from_cache(g, data)
        if table is None:
            table = LRTable.from_grammar(g)

            if self.cache_id is not None:
                self._write_cache(cache_dir, cache_file, table)

        if table.sr_conflicts:
            warnings.warn(
                "%d shift/reduce conflict%s" % (
                    len(table.sr_conflicts),
                    "s" if len(table.sr_conflicts) > 1 else ""
                ),
                ParserGeneratorWarning,
                stacklevel=2,
            )
        if table.rr_conflicts:
            warnings.warn(
                "%d reduce/reduce conflict%s" % (
                    len(table.rr_conflicts),
                    "s" if len(table.rr_conflicts) > 1 else ""
                ),
                ParserGeneratorWarning,
                stacklevel=2,
            )
        return LRParser(table, self.error_handler)

    def _write_cache(self, cache_dir, cache_file, table):
        """Atomically write the serialized table to ``cache_file``.

        Best-effort: a read-only filesystem is silently tolerated, any
        other OS error propagates.
        """
        if not os.path.exists(cache_dir):
            try:
                os.makedirs(cache_dir, mode=0o0700)
            except OSError as e:
                if e.errno == errno.EROFS:
                    return
                raise
        # Write to a temp file in the same directory, then rename, so a
        # concurrent reader never sees a half-written cache.
        with tempfile.NamedTemporaryFile(dir=cache_dir, delete=False, mode="w") as f:
            json.dump(self.serialize_table(table), f)
        os.rename(f.name, cache_file)
228
229
def digraph(X, R, FP):
    """For every node in ``X``, compute the union of the ``FP`` value sets
    over all nodes reachable through the relation ``R``.

    ``R(x)`` yields the nodes related to ``x``; ``FP(x)`` yields the
    initial value list for ``x``.  The actual depth-first walk (which also
    merges value sets across strongly connected components) lives in
    :func:`traverse`; this wrapper only seeds the bookkeeping and makes
    sure every node is visited exactly once.
    """
    marks = dict.fromkeys(X, 0)  # 0 == not yet visited
    pending = []                 # shared DFS stack used by traverse()
    results = {}
    for node in X:
        if marks[node] == 0:
            traverse(node, marks, pending, results, X, R, FP)
    return results
238
239
def traverse(x, N, stack, F, X, R, FP):
    """Depth-first walk used by :func:`digraph`.

    ``N`` maps each node to 0 (unvisited), its DFS stack depth (while on
    the stack), or ``LARGE_VALUE`` (finished).  ``F`` accumulates the value
    lists; every node in a strongly connected component ends up sharing the
    root's merged list.  ``X`` is unused here but kept so the signature
    mirrors :func:`digraph`'s arguments.
    """
    stack.append(x)
    d = len(stack)
    N[x] = d
    F[x] = FP(x)

    rel = R(x)
    for y in rel:
        if N[y] == 0:
            traverse(y, N, stack, F, X, R, FP)
        # Propagate the lowest stack depth reachable from x (Tarjan-style
        # lowlink), then merge y's values into x's list, keeping it
        # duplicate-free.
        N[x] = min(N[x], N[y])
        for a in F.get(y, []):
            if a not in F[x]:
                F[x].append(a)
    if N[x] == d:
        # x is the root of a strongly connected component: pop the whole
        # component, marking each member finished and sharing x's merged
        # value list with all of them.
        N[stack[-1]] = LARGE_VALUE
        F[stack[-1]] = F[x]
        element = stack.pop()
        while element != x:
            N[stack[-1]] = LARGE_VALUE
            F[stack[-1]] = F[x]
            element = stack.pop()
262
263
class LRTable(object):
    """An LALR(1) parse table together with the conflicts recorded while it
    was built.

    Instances are created either from a cached JSON payload
    (:meth:`from_cache`) or from scratch (:meth:`from_grammar`).
    """
    def __init__(self, grammar, lr_action, lr_goto, default_reductions,
                 sr_conflicts, rr_conflicts):
        self.grammar = grammar
        # lr_action[state][terminal]: > 0 shift to that state, < 0 reduce
        # by production number -n, 0 accept.
        self.lr_action = lr_action
        # lr_goto[state][nonterminal] -> successor state.
        self.lr_goto = lr_goto
        # default_reductions[state] is the (negative) reduction applied
        # unconditionally when it is the state's only action, else 0.
        self.default_reductions = default_reductions
        self.sr_conflicts = sr_conflicts
        self.rr_conflicts = rr_conflicts

    @classmethod
    def from_cache(cls, grammar, data):
        """Rebuild a table from the payload produced by
        ``ParserGenerator.serialize_table``.

        Keys are forced to native ``str`` (JSON decoding can yield unicode
        keys on Python 2).
        """
        lr_action = [
            dict([(str(k), v) for k, v in iteritems(action)])
            for action in data["lr_action"]
        ]
        lr_goto = [
            dict([(str(k), v) for k, v in iteritems(goto)])
            for goto in data["lr_goto"]
        ]
        return LRTable(
            grammar,
            lr_action,
            lr_goto,
            data["default_reductions"],
            data["sr_conflicts"],
            data["rr_conflicts"]
        )

    @classmethod
    def from_grammar(cls, grammar):
        """Construct the LALR(1) table for ``grammar``: build the LR(0)
        item sets, attach LALR lookaheads, then fill the action/goto
        tables, resolving conflicts via precedence/associativity.
        """
        cidhash = IdentityDict()
        goto_cache = {}
        add_count = Counter()
        C = cls.lr0_items(grammar, add_count, cidhash, goto_cache)

        cls.add_lalr_lookaheads(grammar, C, add_count, cidhash, goto_cache)

        lr_action = [None] * len(C)
        lr_goto = [None] * len(C)
        sr_conflicts = []
        rr_conflicts = []
        for st, I in enumerate(C):
            st_action = {}
            st_actionp = {}  # parallel map: terminal -> production that set the action
            st_goto = {}
            for p in I:
                if p.getlength() == p.lr_index + 1:
                    # Dot at the end of the production: reduce (or accept).
                    if p.name == "S'":
                        # Start symbol. Accept!
                        st_action["$end"] = 0
                        st_actionp["$end"] = p
                    else:
                        laheads = p.lookaheads[st]
                        for a in laheads:
                            if a in st_action:
                                r = st_action[a]
                                if r > 0:
                                    # Existing shift vs. new reduce:
                                    # reduce wins on strictly higher
                                    # precedence, or on a tie that is
                                    # left-associative.
                                    sprec, slevel = grammar.productions[st_actionp[a].number].prec
                                    rprec, rlevel = grammar.precedence.get(a, ("right", 0))
                                    if (slevel < rlevel) or (slevel == rlevel and rprec == "left"):
                                        st_action[a] = -p.number
                                        st_actionp[a] = p
                                        if not slevel and not rlevel:
                                            sr_conflicts.append((st, repr(a), "reduce"))
                                        grammar.productions[p.number].reduced += 1
                                    elif not (slevel == rlevel and rprec == "nonassoc"):
                                        if not rlevel:
                                            sr_conflicts.append((st, repr(a), "shift"))
                                elif r < 0:
                                    # Reduce/reduce conflict: keep the
                                    # production with the lower number
                                    # (defined earlier) and record it.
                                    oldp = grammar.productions[-r]
                                    pp = grammar.productions[p.number]
                                    if oldp.number > pp.number:
                                        st_action[a] = -p.number
                                        st_actionp[a] = p
                                        chosenp, rejectp = pp, oldp
                                        grammar.productions[p.number].reduced += 1
                                        grammar.productions[oldp.number].reduced -= 1
                                    else:
                                        chosenp, rejectp = oldp, pp
                                    rr_conflicts.append((st, repr(chosenp), repr(rejectp)))
                                else:
                                    raise ParserGeneratorError("Unknown conflict in state %d" % st)
                            else:
                                st_action[a] = -p.number
                                st_actionp[a] = p
                                grammar.productions[p.number].reduced += 1
                else:
                    # Dot before a symbol: shift if it is a terminal.
                    i = p.lr_index
                    a = p.prod[i + 1]
                    if a in grammar.terminals:
                        g = cls.lr0_goto(I, a, add_count, goto_cache)
                        j = cidhash.get(g, -1)
                        if j >= 0:
                            if a in st_action:
                                r = st_action[a]
                                if r > 0:
                                    if r != j:
                                        raise ParserGeneratorError("Shift/shift conflict in state %d" % st)
                                elif r < 0:
                                    # Existing reduce vs. new shift: shift
                                    # wins on strictly higher precedence,
                                    # or on a right-associative tie.
                                    rprec, rlevel = grammar.productions[st_actionp[a].number].prec
                                    sprec, slevel = grammar.precedence.get(a, ("right", 0))
                                    if (slevel > rlevel) or (slevel == rlevel and rprec == "right"):
                                        grammar.productions[st_actionp[a].number].reduced -= 1
                                        st_action[a] = j
                                        st_actionp[a] = p
                                        if not rlevel:
                                            sr_conflicts.append((st, repr(a), "shift"))
                                    elif not (slevel == rlevel and rprec == "nonassoc"):
                                        if not slevel and not rlevel:
                                            sr_conflicts.append((st, repr(a), "reduce"))
                                else:
                                    raise ParserGeneratorError("Unknown conflict in state %d" % st)
                            else:
                                st_action[a] = j
                                st_actionp[a] = p
            # Goto entries for every nonterminal appearing after a dot in
            # this item set.
            nkeys = set()
            for ii in I:
                for s in ii.unique_syms:
                    if s in grammar.nonterminals:
                        nkeys.add(s)
            for n in nkeys:
                g = cls.lr0_goto(I, n, add_count, goto_cache)
                j = cidhash.get(g, -1)
                if j >= 0:
                    st_goto[n] = j

            lr_action[st] = st_action
            lr_goto[st] = st_goto

        # A state whose only possible action is a single reduction can
        # reduce without consulting the next token.
        default_reductions = [0] * len(lr_action)
        for state, actions in enumerate(lr_action):
            actions = set(itervalues(actions))
            if len(actions) == 1 and next(iter(actions)) < 0:
                default_reductions[state] = next(iter(actions))
        return LRTable(grammar, lr_action, lr_goto, default_reductions, sr_conflicts, rr_conflicts)

    @classmethod
    def lr0_items(cls, grammar, add_count, cidhash, goto_cache):
        """Compute the canonical collection of LR(0) item sets.

        Starts from the closure of the augmented start production and keeps
        taking gotos until no new item sets appear.  ``cidhash`` maps each
        item set (by identity) to its state number.
        """
        C = [cls.lr0_closure([grammar.productions[0].lr_next], add_count)]
        for i, I in enumerate(C):
            cidhash[I] = i

        # C grows while we iterate; index manually so new sets are
        # processed too.
        i = 0
        while i < len(C):
            I = C[i]
            i += 1

            asyms = set()
            for ii in I:
                asyms.update(ii.unique_syms)
            for x in asyms:
                g = cls.lr0_goto(I, x, add_count, goto_cache)
                if not g:
                    continue
                if g in cidhash:
                    continue
                cidhash[g] = len(C)
                C.append(g)
        return C

    @classmethod
    def lr0_closure(cls, I, add_count):
        """Return the LR(0) closure of item set ``I``.

        ``add_count`` is a generation counter: marking items with the
        current generation (``lr0_added``) avoids re-adding them without
        clearing per-item flags between calls.
        """
        add_count.incr()

        J = I[:]
        added = True
        while added:
            added = False
            for j in J:
                for x in j.lr_after:
                    if x.lr0_added == add_count.value:
                        continue
                    # Appending to J during iteration is deliberate; the
                    # added flag keeps looping until a fixpoint.
                    J.append(x.lr_next)
                    x.lr0_added = add_count.value
                    added = True
        return J

    @classmethod
    def lr0_goto(cls, I, x, add_count, goto_cache):
        """Return goto(I, x): the closure of the items in ``I`` with the
        dot advanced over ``x``.

        Results are memoized in ``goto_cache`` as a trie keyed first by the
        symbol and then by the successive kernel items, with the final item
        set stored under the ``"$end"`` key.  When there are no kernel
        items the return value is falsy (None on the first call, the cached
        empty list afterwards).
        """
        s = goto_cache.setdefault(x, IdentityDict())

        gs = []
        for p in I:
            n = p.lr_next
            if n and n.lr_before == x:
                s1 = s.get(n)
                if not s1:
                    s1 = {}
                    s[n] = s1
                gs.append(n)
                s = s1
        g = s.get("$end")
        if not g:
            if gs:
                g = cls.lr0_closure(gs, add_count)
                s["$end"] = g
            else:
                s["$end"] = gs
        return g

    @classmethod
    def add_lalr_lookaheads(cls, grammar, C, add_count, cidhash, goto_cache):
        """Attach LALR(1) lookahead sets to the reduction items in ``C``.

        The reads/includes/lookback structure of the helpers below matches
        the DeRemer-Pennello lookahead algorithm.
        """
        nullable = cls.compute_nullable_nonterminals(grammar)
        trans = cls.find_nonterminal_transitions(grammar, C)
        readsets = cls.compute_read_sets(grammar, C, trans, nullable, add_count, cidhash, goto_cache)
        lookd, included = cls.compute_lookback_includes(grammar, C, trans, nullable, add_count, cidhash, goto_cache)
        followsets = cls.compute_follow_sets(trans, readsets, included)
        cls.add_lookaheads(lookd, followsets)

    @classmethod
    def compute_nullable_nonterminals(cls, grammar):
        """Return the set of nonterminals that can derive the empty string,
        computed by iterating to a fixpoint.
        """
        nullable = set()
        num_nullable = 0
        while True:
            # Skip productions[0] (the augmented start production).
            for p in grammar.productions[1:]:
                if p.getlength() == 0:
                    nullable.add(p.name)
                    continue
                for t in p.prod:
                    if t not in nullable:
                        break
                else:
                    # Every right-hand-side symbol is nullable.
                    nullable.add(p.name)
            if len(nullable) == num_nullable:
                break
            num_nullable = len(nullable)
        return nullable

    @classmethod
    def find_nonterminal_transitions(cls, grammar, C):
        """Return every (state, nonterminal) goto transition, in discovery
        order and without duplicates.
        """
        trans = []
        for idx, state in enumerate(C):
            for p in state:
                if p.lr_index < p.getlength() - 1:
                    t = (idx, p.prod[p.lr_index + 1])
                    if t[1] in grammar.nonterminals and t not in trans:
                        trans.append(t)
        return trans

    @classmethod
    def compute_read_sets(cls, grammar, C, ntrans, nullable, add_count, cidhash, goto_cache):
        """Compute the Read sets for every nonterminal transition via the
        digraph algorithm over the ``reads`` relation, seeded with the
        direct-read terminals.
        """
        return digraph(
            ntrans,
            R=lambda x: cls.reads_relation(C, x, nullable, add_count, cidhash, goto_cache),
            FP=lambda x: cls.dr_relation(grammar, C, x, nullable, add_count, goto_cache)
        )

    @classmethod
    def compute_follow_sets(cls, ntrans, readsets, includesets):
        """Compute Follow sets from the Read sets via the digraph
        algorithm over the ``includes`` relation.
        """
        return digraph(
            ntrans,
            R=lambda x: includesets.get(x, []),
            FP=lambda x: readsets[x],
        )

    @classmethod
    def dr_relation(cls, grammar, C, trans, nullable, add_count, goto_cache):
        """Direct-read terminals of the transition ``(state, N)``: the
        terminals that appear immediately after the dot in the state
        reached by taking the goto on N.
        """
        state, N = trans
        terms = []

        g = cls.lr0_goto(C[state], N, add_count, goto_cache)
        for p in g:
            if p.lr_index < p.getlength() - 1:
                a = p.prod[p.lr_index + 1]
                if a in grammar.terminals and a not in terms:
                    terms.append(a)
        # The transition out of the start state on the real start symbol
        # can also directly read end-of-input.
        if state == 0 and N == grammar.productions[0].prod[0]:
            terms.append("$end")
        return terms

    @classmethod
    def reads_relation(cls, C, trans, empty, add_count, cidhash, goto_cache):
        """The ``reads`` relation for ``trans``: transitions out of the
        reached state whose symbol is nullable (in ``empty``).
        """
        rel = []
        state, N = trans

        g = cls.lr0_goto(C[state], N, add_count, goto_cache)
        j = cidhash.get(g, -1)
        for p in g:
            if p.lr_index < p.getlength() - 1:
                a = p.prod[p.lr_index + 1]
                if a in empty:
                    rel.append((j, a))
        return rel

    @classmethod
    def compute_lookback_includes(cls, grammar, C, trans, nullable, add_count, cidhash, goto_cache):
        """Build the ``lookback`` and ``includes`` relations for every
        nonterminal transition in a single pass over the item sets.
        """
        lookdict = {}     # (state, N) -> [(final state, reduction item), ...]
        includedict = {}  # transition -> transitions whose Follow includes it

        # Dict keyed by the transitions, for O(1) membership tests.
        dtrans = dict.fromkeys(trans, 1)

        for state, N in trans:
            lookb = []
            includes = []
            for p in C[state]:
                if p.name != N:
                    continue

                # Walk the dot across the remainder of the production,
                # following goto transitions state by state.
                lr_index = p.lr_index
                j = state
                while lr_index < p.getlength() - 1:
                    lr_index += 1
                    t = p.prod[lr_index]

                    if (j, t) in dtrans:
                        # (j, t) is itself a nonterminal transition; it is
                        # "included" when everything after t in this
                        # production is nullable.
                        li = lr_index + 1
                        while li < p.getlength():
                            if p.prod[li] in grammar.terminals:
                                break
                            if p.prod[li] not in nullable:
                                break
                            li += 1
                        else:
                            includes.append((j, t))

                    g = cls.lr0_goto(C[j], t, add_count, goto_cache)
                    j = cidhash.get(g, -1)

                # j should now be the state reached at the end of the
                # production; collect its matching reduction item(s) for
                # the lookback relation.
                for r in C[j]:
                    if r.name != p.name:
                        continue
                    if r.getlength() != p.getlength():
                        continue
                    i = 0
                    while i < r.lr_index:
                        if r.prod[i] != p.prod[i + 1]:
                            break
                        i += 1
                    else:
                        lookb.append((j, r))

            for i in includes:
                includedict.setdefault(i, []).append((state, N))
            lookdict[state, N] = lookb
        return lookdict, includedict

    @classmethod
    def add_lookaheads(cls, lookbacks, followset):
        """Copy the Follow sets onto the lookahead lists of the reduction
        items named by the ``lookback`` relation, skipping duplicates.
        """
        for trans, lb in iteritems(lookbacks):
            for state, p in lb:
                f = followset.get(trans, [])
                laheads = p.lookaheads.setdefault(state, [])
                for a in f:
                    if a not in laheads:
                        laheads.append(a)
610