1"""Emit code and parser tables in Rust."""
2
3import json
4import re
5import unicodedata
6import sys
7import itertools
8import collections
9from contextlib import contextmanager
10
11from ..runtime import (ERROR, ErrorToken, SPECIAL_CASE_TAG)
12from ..ordered import OrderedSet
13
14from ..grammar import (Some, Nt, InitNt, End, ErrorSymbol)
15from ..actions import (Accept, Action, Replay, Unwind, Reduce, CheckNotOnNewLine, FilterStates,
16                       PushFlag, PopFlag, FunCall, Seq)
17
18from .. import types
19
20
# Maps the source text of each punctuator terminal to the CamelCase name
# returned for it by RustParserWriter.terminal_name, which in turn is used as
# the variant name in the generated `TerminalId` enum.
TERMINAL_NAMES = {
    '&&=': 'LogicalAndAssign',
    '||=': 'LogicalOrAssign',
    '??=': 'CoalesceAssign',
    '{': 'OpenBrace',
    '}': 'CloseBrace',
    '(': 'OpenParenthesis',
    ')': 'CloseParenthesis',
    '[': 'OpenBracket',
    ']': 'CloseBracket',
    '+': 'Plus',
    '-': 'Minus',
    '~': 'BitwiseNot',
    '!': 'LogicalNot',
    '++': 'Increment',
    '--': 'Decrement',
    ':': 'Colon',
    '=>': 'Arrow',
    '=': 'EqualSign',
    '*=': 'MultiplyAssign',
    '/=': 'DivideAssign',
    '%=': 'RemainderAssign',
    '+=': 'AddAssign',
    '-=': 'SubtractAssign',
    '<<=': 'LeftShiftAssign',
    '>>=': 'SignedRightShiftAssign',
    '>>>=': 'UnsignedRightShiftAssign',
    '&=': 'BitwiseAndAssign',
    '^=': 'BitwiseXorAssign',
    '|=': 'BitwiseOrAssign',
    '**=': 'ExponentiateAssign',
    '.': 'Dot',
    '**': 'Exponentiate',
    '?.': 'OptionalChain',
    '?': 'QuestionMark',
    '??': 'Coalesce',
    '*': 'Star',
    '/': 'Divide',
    '%': 'Remainder',
    '<<': 'LeftShift',
    '>>': 'SignedRightShift',
    '>>>': 'UnsignedRightShift',
    '<': 'LessThan',
    '>': 'GreaterThan',
    '<=': 'LessThanOrEqualTo',
    '>=': 'GreaterThanOrEqualTo',
    '==': 'LaxEqual',
    '!=': 'LaxNotEqual',
    '===': 'StrictEqual',
    '!==': 'StrictNotEqual',
    '&': 'BitwiseAnd',
    '^': 'BitwiseXor',
    '|': 'BitwiseOr',
    '&&': 'LogicalAnd',
    '||': 'LogicalOr',
    ',': 'Comma',
    '...': 'Ellipsis',
}
79
80
@contextmanager
def indent(writer):
    """Add one indentation level to `writer` for the duration of a `with`.

    This function is meant to be used with the `with` keyword of python, so
    that the indentation of the python code is reflected in the generated code
    when `with indent(self):` is used.

    `writer` is any object with an integer `indent` attribute. The attribute
    is incremented on entry and decremented on exit, including when the
    enclosed block raises, so that a failure while generating one block does
    not leave later output mis-indented.
    """
    writer.indent += 1
    try:
        yield None
    finally:
        # Restore the previous level even if the body raised.
        writer.indent -= 1
92
def extract_ranges(iterator):
    """Yield `(first, last)` for each contiguous run of a sorted integer iterator.

    In a sorted sequence, consecutive integers advance in lock-step with their
    position, so `value - position` is constant across a contiguous run. That
    difference is used as the key which groups values into runs; runs are
    yielded in the order in which their first element was seen.
    """
    runs_by_base = {}
    for position, value in enumerate(iterator):
        runs_by_base.setdefault(value - position, []).append(value)
    for run in runs_by_base.values():
        yield (run[0], run[-1])
105
def rust_range(riter):
    """Render an iterable of `(min, max)` tuples as a Rust match pattern.

    A range reduced to a single value is written as that value, any other
    range as `min..=max`; alternatives are separated by ` | `.
    """
    alternatives = []
    for low, high in riter:
        if low == high:
            alternatives.append(str(low))
        else:
            alternatives.append("{}..={}".format(low, high))
    return " | ".join(alternatives)
115
class RustActionWriter:
    """Write epsilon state transitions for a given action function.

    Every emitted line goes through the `write` method of the parser writer
    given at construction, at the indentation level stored in `self.indent`.
    Function calls whose trait is not listed in `traits` are emitted as `()`.
    """
    ast_builder = types.Type("AstBuilderDelegate", (types.Lifetime("alloc"),))

    def __init__(self, writer, mode, traits, indent):
        """Bind this action writer to a parser writer.

        writer -- object providing `states`, `write`, `parse_table`,
                  `shift_count`, `action_count`, `fallible_methods` and
                  `nonterminal_to_camel`.
        mode -- prefix used to name the generated action functions.
        traits -- collection of the trait types usable by the generated code.
        indent -- initial indentation level of the emitted lines.
        """
        self.states = writer.states
        self.writer = writer
        self.mode = mode
        self.traits = traits
        self.indent = indent
        self.has_ast_builder = self.ast_builder in traits
        # Variables read by the action currently being written, see reset.
        self.used_variables = set()
        # Names of the Rust variables holding terms which are not on the
        # parser stack, see write_state_transitions.
        self.replay_args = []

    def implement_trait(self, funcall):
        "Returns True if this function call should be encoded"
        ty = funcall.trait
        if ty.name == "AstBuilder":
            # AstBuilder methods are reached through the AstBuilderDelegate
            # trait, see write_funcall.
            return "AstBuilderDelegate<'alloc>" in map(str, self.traits)
        if ty in self.traits:
            return True
        if len(ty.args) == 0:
            # Without arguments, matching the trait name is enough.
            return any(ty.name == t.name for t in self.traits)
        return False

    def reset(self, act):
        "Traverse all action to collect preliminary information."
        self.used_variables = set(self.collect_uses(act))

    def collect_uses(self, act):
        """Generator which visit all used variables.

        Yields "value" for Reduce and Unwind actions, and for function calls
        which are encoded (see implement_trait), the stack indexes (as
        integers, shifted by the call offset) and the variable names (as
        strings) of their arguments. Seq actions are visited recursively.
        """
        assert isinstance(act, Action)
        if isinstance(act, (Reduce, Unwind)):
            yield "value"
        elif isinstance(act, FunCall):
            arg_offset = act.offset
            if arg_offset < 0:
                # See write_funcall.
                arg_offset = 0

            def map_with_offset(args):
                for a in args:
                    if isinstance(a, int):
                        yield a + arg_offset
                    elif isinstance(a, str):
                        yield a
                    elif isinstance(a, Some):
                        yield from map_with_offset([a.inner])

            if self.implement_trait(act):
                yield from map_with_offset(act.args)
        elif isinstance(act, Seq):
            for a in act.actions:
                yield from self.collect_uses(a)

    def write(self, string, *format_args):
        "Delegate to the RustParserWriter.write function"
        self.writer.write(self.indent, string, *format_args)

    def write_state_transitions(self, state, replay_args):
        """Given a state, generate the code corresponding to all outgoing
        epsilon edges.

        `replay_args` is the list of Rust variable names received as arguments
        by the function generated for this state. Any exception raised while
        generating the code is re-raised after printing the state and its
        debug context.
        """
        try:
            self.replay_args = replay_args
            assert not state.is_inconsistent()
            assert len(list(state.shifted_edges())) == 0
            for ctx in self.writer.parse_table.debug_context(state.index, None):
                self.write("// {}", ctx)
            first, dest = next(state.edges(), (None, None))
            if first is None:
                return
            self.reset(first)
            if first.is_condition():
                self.write_condition(state, first)
            else:
                assert len(list(state.edges())) == 1
                self.write_action(first, dest)
        except Exception:
            print("Error while writing code for {}\n\n".format(state))
            self.writer.parse_table.debug_info = True
            print(self.writer.parse_table.debug_context(state.index, "\n", "# "))
            # Bare raise: propagate the original exception untouched.
            raise

    def write_replay_args(self, n):
        """Emit a `parser.replay` call for each replay argument from index
        `n` onward, and return the arguments before index `n`.

        `n` is used as a slice index, thus a negative value counts from the
        end of the argument list. `self.replay_args` is left untouched.
        """
        rp_args = self.replay_args[:n]
        rp_stck = self.replay_args[n:]
        for tv in rp_stck:
            self.write("parser.replay({});", tv)
        return rp_args

    def write_epsilon_transition(self, dest):
        """Emit the code which continues the execution in the state of index
        `dest`.

        States with an index below the number of shift states are entered
        with `parser.epsilon`, the others by calling their action function.
        """
        # Replay arguments which are not accepted as input of the next state.
        dest = self.states[dest]
        rp_args = self.write_replay_args(dest.arguments)
        self.write("// --> {}", dest.index)
        if dest.index >= self.writer.shift_count:
            self.write("{}_{}(parser{})", self.mode, dest.index,
                       "".join(", " + v for v in rp_args))
        else:
            assert dest.arguments == 0
            self.write("parser.epsilon({});", dest.index)
            self.write("Ok(false)")

    def write_condition(self, state, first_act):
        "Write code to test a conditions, and dispatch to the matching destination"
        # NOTE: we already asserted that this state is consistent, this implies
        # that the first state check the same variables as all remaining
        # states. Thus we use the first action to produce the match statement.
        assert isinstance(first_act, Action)
        assert first_act.is_condition()
        if isinstance(first_act, CheckNotOnNewLine):
            # TODO: At the moment this is Action is implemented as a single
            # operation with a single destination. However, we should implement
            # it in the future as 2 branches, one which is verifying the lack
            # of new lines, and one which is shifting an extra error token.
            # This might help remove the overhead of backtracking in addition
            # to make this backtracking visible through APS.
            assert len(list(state.edges())) == 1
            act, dest = next(state.edges())
            assert len(self.replay_args) == 0
            assert -act.offset > 0
            self.write("// {}", str(act))
            self.write("if !parser.check_not_on_new_line({})? {{", -act.offset)
            with indent(self):
                self.write("return Ok(false);")
            self.write("}")
            self.write_epsilon_transition(dest)
        elif isinstance(first_act, FilterStates):
            if len(state.epsilon) == 1:
                # This is an attempt to avoid huge unending compilations.
                _, dest = next(iter(state.epsilon), (None, None))
                pattern = rust_range(extract_ranges(first_act.states))
                self.write("// parser.top_state() in ({})", pattern)
                self.write_epsilon_transition(dest)
            else:
                self.write("match parser.top_state() {")
                with indent(self):
                    # Consider the branch which has the largest number of
                    # potential top-states to be most likely, and therefore the
                    # default branch to go to if all other fail to match.
                    default_weight = max(len(act.states) for act, _ in state.edges())
                    default_states = []
                    default_dest = None
                    for act, dest in state.edges():
                        assert first_act.check_same_variable(act)
                        if default_dest is None and default_weight == len(act.states):
                            # This range has the same weight as the default
                            # branch. Ignore it and use it as the default
                            # branch which would be generated at the end.
                            default_states = act.states
                            default_dest = dest
                            continue
                        pattern = rust_range(extract_ranges(act.states))
                        self.write("{} => {{", pattern)
                        with indent(self):
                            self.write_epsilon_transition(dest)
                        self.write("}")
                    # Generate code for the default branch, which got skipped
                    # while producing the loop.
                    self.write("_ => {")
                    with indent(self):
                        pattern = rust_range(extract_ranges(default_states))
                        self.write("// {}", pattern)
                        self.write_epsilon_transition(default_dest)
                    self.write("}")
                self.write("}")
        else:
            raise ValueError("Unexpected action type")

    def write_action(self, act, dest):
        """Emit the code of the non-condition action `act`, followed by the
        transition to the state of index `dest` when the action falls through.

        When the action updates the stack, the popped terms are first bound to
        the variables `s1`, `s2`, ...; the name of a variable whose index is
        not listed in `self.used_variables` is prefixed with an underscore.
        """
        assert isinstance(act, Action)
        assert not act.is_condition()
        # Maps a variable name to whether it holds a packed TermValue, see
        # write_funcall and write_reduce.
        is_packed = {}

        # Do not pop any of the stack elements if the reduce action has an
        # accept function call. Ideally we should be returning the result
        # instead of keeping it on the parser stack.
        if act.update_stack() and not act.contains_accept():
            stack_diff = act.update_stack_with()
            start = 0
            depth = stack_diff.pop
            args = len(self.replay_args)
            replay = stack_diff.replay
            if replay < 0:
                # At the moment, we do not handle having more arguments than
                # what is being popped and replay, thus write back the extra
                # arguments and continue.
                if stack_diff.pop + replay < 0:
                    self.replay_args = self.write_replay_args(replay)
                replay = 0
            if replay + stack_diff.pop - args > 0:
                assert (replay >= 0 and args == 0) or \
                    (replay == 0 and args >= 0)
            if replay > 0:
                # At the moment, assume that arguments are only added once we
                # consumed all replayed terms. Thus the replay_args can only be
                # non-empty once replay is 0. Otherwise some of the replay_args
                # would have to be replayed.
                assert args == 0
                self.write("parser.rewind({});", replay)
                start = replay
                depth += start

            inputs = []
            for i in range(start, depth):
                name = 's{}'.format(i + 1)
                if i + 1 not in self.used_variables:
                    name = '_' + name
                inputs.append(name)
            if stack_diff.pop > 0:
                args_pop = min(len(self.replay_args), stack_diff.pop)
                # Pop by moving arguments of the action function.
                for i, name in enumerate(inputs[:args_pop]):
                    self.write("let {} = {};", name, self.replay_args[-i - 1])
                # Pop by removing elements from the parser stack.
                for name in inputs[args_pop:]:
                    self.write("let {} = parser.pop();", name)
                if args_pop > 0:
                    del self.replay_args[-args_pop:]

        if isinstance(act, Seq):
            for a in act.actions:
                self.write_single_action(a, is_packed)
                # Nothing is emitted after an action which accepts.
                if a.contains_accept():
                    break
        else:
            self.write_single_action(act, is_packed)

        # If we fallthrough the execution of the action, then generate an
        # epsilon transition.
        if act.follow_edge() and not act.contains_accept():
            assert 0 <= dest < self.writer.shift_count + self.writer.action_count
            self.write_epsilon_transition(dest)

    def write_single_action(self, act, is_packed):
        """Emit a comment describing `act`, then dispatch on its type.

        Raises ValueError for PushFlag and PopFlag, which are not implemented,
        and for any action type which is not handled here.
        """
        self.write("// {}", str(act))
        if isinstance(act, Replay):
            self.write_replay(act)
        elif isinstance(act, (Reduce, Unwind)):
            self.write_reduce(act, is_packed)
        elif isinstance(act, Accept):
            self.write_accept()
        elif isinstance(act, PushFlag):
            raise ValueError("NYI: PushFlag action")
        elif isinstance(act, PopFlag):
            raise ValueError("NYI: PopFlag action")
        elif isinstance(act, FunCall):
            self.write_funcall(act, is_packed)
        else:
            raise ValueError("Unexpected action type")

    def write_replay(self, act):
        "Emit one `parser.shift_replayed` call per replayed state."
        assert len(self.replay_args) == 0
        for shift_state in act.replay_steps:
            self.write("parser.shift_replayed({});", shift_state)

    def write_reduce(self, act, is_packed):
        """Emit the code building the `reduced` TermValue of a Reduce or
        Unwind action, and append "reduced" to the replay arguments."""
        if "value" in is_packed:
            value = "value"
            packed = is_packed[value]
        else:
            # No function call assigned the `value` variable.
            value = "None"
            packed = False

        if packed:
            # Extract the StackValue from the packed TermValue
            value = "{}.value".format(value)
        elif self.has_ast_builder:
            # Convert into a StackValue
            value = "TryIntoStack::try_into_stack({})?".format(value)
        else:
            # Convert into a StackValue (when no ast-builder)
            value = "value"

        stack_diff = act.update_stack_with()
        assert stack_diff.nt is not None
        self.write("let term = NonterminalId::{}.into();",
                   self.writer.nonterminal_to_camel(stack_diff.nt))
        if value != "value":
            self.write("let value = {};", value)
        self.write("let reduced = TermValue { term, value };")
        self.replay_args.append("reduced")

    def write_accept(self):
        "Emit the statement returning from the action function on accept."
        self.write("return Ok(true);")

    def write_funcall(self, act, is_packed):
        """Emit the call of the method named by the FunCall action `act`.

        The result is bound to the variable `act.set_to` only if this variable
        is listed in `self.used_variables`. When the call is emitted,
        `is_packed[act.set_to]` records whether the result is a packed
        TermValue. When the trait of the call is not implemented (see
        implement_trait), `()` is emitted instead and `is_packed` is left
        untouched.
        """
        arg_offset = act.offset
        if arg_offset < 0:
            # NOTE: When replacing replayed stack elements by arguments, the
            # offset is reduced by -1, and can become negative for cases where
            # we read the value associated with an argument instead of the
            # value read from the stack. However, write_action shift everything
            # as-if we had replayed all the necessary terms, and therefore
            # variables are named as-if the offset were 0.
            arg_offset = 0

        def no_unpack(val):
            return val

        def unpack(val):
            # Variables which are not recorded in is_packed are assumed to be
            # packed.
            if val in is_packed:
                packed = is_packed[val]
            else:
                packed = True
            if packed:
                return "{}.value.to_ast()?".format(val)
            return val

        def map_with_offset(args, unpack):
            get_value = "s{}"
            for a in args:
                if isinstance(a, int):
                    yield unpack(get_value.format(a + arg_offset))
                elif isinstance(a, str):
                    yield unpack(a)
                elif isinstance(a, Some):
                    yield "Some({})".format(next(map_with_offset([a.inner], unpack)))
                elif a is None:
                    yield "None"
                else:
                    raise ValueError(a)

        # If the variable is used, then generate the let binding.
        set_var = ""
        if act.set_to in self.used_variables:
            set_var = "let {} = ".format(act.set_to)

        # If the function cannot be call as the generated action function does
        # not use the trait on which this function is implemented, then replace
        # the value by `()`.
        if not self.implement_trait(act):
            self.write("{}();", set_var)
            return

        # NOTE: Currently "AstBuilder" is implemented through the
        # AstBuilderDelegate which returns a mutable reference to the
        # AstBuilder. This would call the specific special case method to get
        # the actual AstBuilder.
        delegate = ""
        if str(act.trait) == "AstBuilder":
            delegate = "ast_builder_refmut()."

        # NOTE: Currently "AstBuilder" functions are made fallible
        # using the fallible_methods taken from some Rust code
        # which extract this information to produce a JSON file.
        forward_errors = ""
        if act.fallible or act.method in self.writer.fallible_methods:
            forward_errors = "?"

        # By default generate a method call, with the method name. However,
        # there is a special case for the "id" function which is an artifact,
        # which does not have to unpack the content of its argument.
        value = "parser.{}{}({})".format(
            delegate, act.method,
            ", ".join(map_with_offset(act.args, unpack)))
        packed = False
        if act.method == "id":
            assert len(act.args) == 1
            value = next(map_with_offset(act.args, no_unpack))
            if isinstance(act.args[0], str):
                # The result is as packed as the variable it aliases.
                packed = is_packed[act.args[0]]
            else:
                assert isinstance(act.args[0], int)
                packed = True

        self.write("{}{}{};", set_var, value, forward_errors)
        is_packed[act.set_to] = packed
486
487class RustParserWriter:
488    def __init__(self, out, pt, fallible_methods):
489        self.out = out
490        self.fallible_methods = fallible_methods
491        assert pt.exec_modes is not None
492        self.parse_table = pt
493        self.states = pt.states
494        self.shift_count = pt.count_shift_states()
495        self.action_count = pt.count_action_states()
496        self.action_from_shift_count = pt.count_action_from_shift_states()
497        self.init_state_map = pt.named_goals
498        self.terminals = list(OrderedSet(pt.terminals))
499        # This extra terminal is used to represent any ErrorySymbol transition,
500        # knowing that we assert that there is only one ErrorSymbol kind per
501        # state.
502        self.terminals.append("ErrorToken")
503        self.nonterminals = list(OrderedSet(pt.nonterminals))
504
505    def emit(self):
506        self.header()
507        self.terms_id()
508        self.shift()
509        self.error_codes()
510        self.check_camel_case()
511        self.actions()
512        self.entry()
513
514    def write(self, indentation, string, *format_args):
515        if len(format_args) == 0:
516            formatted = string
517        else:
518            formatted = string.format(*format_args)
519        self.out.write("    " * indentation + formatted + "\n")
520
521    def header(self):
522        self.write(0, "// WARNING: This file is autogenerated.")
523        self.write(0, "")
524        self.write(0, "use crate::ast_builder::AstBuilderDelegate;")
525        self.write(0, "use crate::stack_value_generated::{StackValue, TryIntoStack};")
526        self.write(0, "use crate::traits::{TermValue, ParserTrait};")
527        self.write(0, "use crate::error::Result;")
528        traits = OrderedSet()
529        for mode_traits in self.parse_table.exec_modes.values():
530            traits |= mode_traits
531        traits = list(traits)
532        traits = [ty for ty in traits if ty.name != "AstBuilderDelegate"]
533        traits = [ty for ty in traits if ty.name != "ParserTrait"]
534        if traits == []:
535            pass
536        elif len(traits) == 1:
537            self.write(0, "use crate::traits::{};", traits[0].name)
538        else:
539            self.write(0, "use crate::traits::{{{}}};", ", ".join(ty.name for ty in traits))
540        self.write(0, "")
541        self.write(0, "const ERROR: i64 = {};", hex(ERROR))
542        self.write(0, "")
543
544    def terminal_name(self, value):
545        if isinstance(value, End) or value is None:
546            return "End"
547        elif isinstance(value, ErrorSymbol) or value is ErrorToken:
548            return "ErrorToken"
549        elif value in TERMINAL_NAMES:
550            return TERMINAL_NAMES[value]
551        elif value.isalpha():
552            if value.islower():
553                return value.capitalize()
554            else:
555                return value
556        else:
557            raw_name = " ".join((unicodedata.name(c) for c in value))
558            snake_case = raw_name.replace("-", " ").replace(" ", "_").lower()
559            camel_case = self.to_camel_case(snake_case)
560            return camel_case
561
562    def terminal_name_camel(self, value):
563        return self.to_camel_case(self.terminal_name(value))
564
565    def terms_id(self):
566        self.write(0, "#[derive(Copy, Clone, Debug, PartialEq)]")
567        self.write(0, "#[repr(u32)]")
568        self.write(0, "pub enum TerminalId {")
569        for i, t in enumerate(self.terminals):
570            name = self.terminal_name(t)
571            self.write(1, "{} = {}, // {}", name, i, repr(t))
572        self.write(0, "}")
573        self.write(0, "")
574        self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]")
575        self.write(0, "#[repr(u32)]")
576        self.write(0, "pub enum NonterminalId {")
577        offset = len(self.terminals)
578        for i, nt in enumerate(self.nonterminals):
579            self.write(1, "{} = {},", self.nonterminal_to_camel(nt), i + offset)
580        self.write(0, "}")
581        self.write(0, "")
582        self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]")
583        self.write(0, "pub struct Term(u32);")
584        self.write(0, "")
585        self.write(0, "impl Term {")
586        self.write(1, "pub fn is_terminal(&self) -> bool {")
587        self.write(2, "self.0 < {}", offset)
588        self.write(1, "}")
589        self.write(1, "pub fn to_terminal(&self) -> TerminalId {")
590        self.write(2, "assert!(self.is_terminal());")
591        self.write(2, "unsafe { std::mem::transmute(self.0) }")
592        self.write(1, "}")
593        self.write(0, "}")
594        self.write(0, "")
595        self.write(0, "impl From<TerminalId> for Term {")
596        self.write(1, "fn from(t: TerminalId) -> Self {")
597        self.write(2, "Term(t as _)")
598        self.write(1, "}")
599        self.write(0, "}")
600        self.write(0, "")
601        self.write(0, "impl From<NonterminalId> for Term {")
602        self.write(1, "fn from(nt: NonterminalId) -> Self {")
603        self.write(2, "Term(nt as _)")
604        self.write(1, "}")
605        self.write(0, "}")
606        self.write(0, "")
607        self.write(0, "impl From<Term> for usize {")
608        self.write(1, "fn from(term: Term) -> Self {")
609        self.write(2, "term.0 as _")
610        self.write(1, "}")
611        self.write(0, "}")
612        self.write(0, "")
613        self.write(0, "impl From<Term> for &'static str {")
614        self.write(1, "fn from(term: Term) -> Self {")
615        self.write(2, "match term.0 {")
616        for i, t in enumerate(self.terminals):
617            self.write(3, "{} => &\"{}\",", i, repr(t))
618        for j, nt in enumerate(self.nonterminals):
619            i = j + offset
620            self.write(3, "{} => &\"{}\",", i, str(nt.name))
621        self.write(3, "_ => panic!(\"unknown Term\")", i, str(nt.name))
622        self.write(2, "}")
623        self.write(1, "}")
624        self.write(0, "}")
625        self.write(0, "")
626
627    def shift(self):
628        self.write(0, "#[rustfmt::skip]")
629        width = len(self.terminals) + len(self.nonterminals)
630        num_shifted_edges = 0
631
632        def state_get(state, t):
633            nonlocal num_shifted_edges
634            res = state.get(t, "ERROR")
635            if res == "ERROR":
636                error_symbol = state.get_error_symbol()
637                if t == "ErrorToken" and error_symbol:
638                    res = state[error_symbol]
639                    num_shifted_edges += 1
640            else:
641                num_shifted_edges += 1
642            return res
643
644        self.write(0, "static SHIFT: [i64; {}] = [", self.shift_count * width)
645        assert self.terminals[-1] == "ErrorToken"
646        for i, state in enumerate(self.states[:self.shift_count]):
647            num_shifted_edges = 0
648            self.write(1, "// {}.", i)
649            for ctx in self.parse_table.debug_context(state.index, None):
650                self.write(1, "// {}", ctx)
651            self.write(1, "{}",
652                       ' '.join("{},".format(state_get(state, t)) for t in self.terminals))
653            self.write(1, "{}",
654                       ' '.join("{},".format(state_get(state, t)) for t in self.nonterminals))
655            try:
656                assert sum(1 for _ in state.shifted_edges()) == num_shifted_edges
657            except Exception:
658                print("Some edges are not encoded.")
659                print("List of terminals: {}".format(', '.join(map(repr, self.terminals))))
660                print("List of nonterminals: {}".format(', '.join(map(repr, self.nonterminals))))
661                print("State having the issue: {}".format(str(state)))
662                raise
663        self.write(0, "];")
664        self.write(0, "")
665
666    def render_action(self, action):
667        if isinstance(action, tuple):
668            if action[0] == 'IfSameLine':
669                _, a1, a2 = action
670                if a1 is None:
671                    a1 = 'ERROR'
672                if a2 is None:
673                    a2 = 'ERROR'
674                index = self.add_special_case(
675                    "if token.is_on_new_line { %s } else { %s }"
676                    % (a2, a1))
677            else:
678                raise ValueError("unrecognized kind of special case: {!r}".format(action))
679            return SPECIAL_CASE_TAG + index
680        elif action == 'ERROR':
681            return action
682        else:
683            assert isinstance(action, int)
684            return action
685
686    def emit_special_cases(self):
687        self.write(0, "static SPECIAL_CASES: [fn(&Token) -> i64; {}] = [",
688                   len(self.special_cases))
689        for i, code in enumerate(self.special_cases):
690            self.write(1, "|token| {{ {} }},", code)
691        self.write(0, "];")
692        self.write(0, "")
693
694    def error_codes(self):
695        self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]")
696        self.write(0, "pub enum ErrorCode {")
697        error_symbols = (s.get_error_symbol() for s in self.states[:self.shift_count])
698        error_codes = (e.error_code for e in error_symbols if e is not None)
699        for error_code in OrderedSet(error_codes):
700            self.write(1, "{},", self.to_camel_case(error_code))
701        self.write(0, "}")
702        self.write(0, "")
703
704        self.write(0, "static STATE_TO_ERROR_CODE: [Option<ErrorCode>; {}] = [",
705                   self.shift_count)
706        for i, state in enumerate(self.states[:self.shift_count]):
707            error_symbol = state.get_error_symbol()
708            if error_symbol is None:
709                self.write(1, "None,")
710            else:
711                self.write(1, "// {}.", i)
712                for ctx in self.parse_table.debug_context(state.index, None):
713                    self.write(1, "// {}", ctx)
714                self.write(1, "Some(ErrorCode::{}),",
715                           self.to_camel_case(error_symbol.error_code))
716        self.write(0, "];")
717        self.write(0, "")
718
719    def nonterminal_to_snake(self, ident):
720        if isinstance(ident, Nt):
721            if isinstance(ident.name, InitNt):
722                name = "Start" + ident.name.goal.name
723            else:
724                name = ident.name
725            base_name = self.to_snek_case(name)
726            args = ''.join((("_" + self.to_snek_case(name))
727                            for name, value in ident.args if value))
728            return base_name + args
729        else:
730            assert isinstance(ident, str)
731            return self.to_snek_case(ident)
732
733    def nonterminal_to_camel(self, nt):
734        return self.to_camel_case(self.nonterminal_to_snake(nt))
735
736    def to_camel_case(self, ident):
737        if '_' in ident:
738            return ''.join(word.capitalize() for word in ident.split('_'))
739        elif ident.islower():
740            return ident.capitalize()
741        else:
742            return ident
743
744    def check_camel_case(self):
745        seen = {}
746        for nt in self.nonterminals:
747            cc = self.nonterminal_to_camel(nt)
748            if cc in seen:
749                raise ValueError("{} and {} have the same camel-case spelling ({})".format(
750                    seen[cc], nt, cc))
751            seen[cc] = nt
752
753    def to_snek_case(self, ident):
754        # https://stackoverflow.com/questions/1175208
755        s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', ident)
756        return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
757
758    def type_to_rust(self, ty, namespace="", boxed=False):
759        """
760        Convert a jsparagus type (see types.py) to Rust.
761
762        Pass boxed=True if the type needs to be boxed.
763        """
764        if isinstance(ty, types.Lifetime):
765            assert not boxed
766            rty = "'" + ty.name
767        elif ty == types.UnitType:
768            assert not boxed
769            rty = '()'
770        elif ty == types.TokenType:
771            rty = "Token"
772        elif ty.name == 'Option' and len(ty.args) == 1:
773            # We auto-translate `Box<Option<T>>` to `Option<Box<T>>` since
774            # that's basically the same thing but more efficient.
775            [arg] = ty.args
776            return 'Option<{}>'.format(self.type_to_rust(arg, namespace, boxed))
777        elif ty.name == 'Vec' and len(ty.args) == 1:
778            [arg] = ty.args
779            rty = "Vec<'alloc, {}>".format(self.type_to_rust(arg, namespace, boxed=False))
780        else:
781            if namespace == "":
782                rty = ty.name
783            else:
784                rty = namespace + '::' + ty.name
785            if ty.args:
786                rty += '<{}>'.format(', '.join(self.type_to_rust(arg, namespace, boxed)
787                                               for arg in ty.args))
788        if boxed:
789            return "Box<'alloc, {}>".format(rty)
790        else:
791            return rty
792
793    def actions(self):
794        # For each execution mode, add a corresponding function which
795        # implements various traits. The trait list is used for filtering which
796        # function is added in the generated code.
797        for mode, traits in self.parse_table.exec_modes.items():
798            action_writer = RustActionWriter(self, mode, traits, 2)
799            start_at = self.shift_count
800            end_at = start_at + self.action_from_shift_count
801            assert len(self.states[self.shift_count:]) == self.action_count
802            traits_text = ' + '.join(map(self.type_to_rust, traits))
803            table_holder_name = self.to_camel_case(mode)
804            table_holder_type = table_holder_name + "<'alloc, Handler>"
805            # As we do not have default associated types yet in Rust
806            # (rust-lang#29661), we have to peak from the parameter of the
807            # ParserTrait.
808            assert list(traits)[0].name == "ParserTrait"
809            arg_type = "TermValue<" + self.type_to_rust(list(traits)[0].args[1]) + ">"
810            self.write(0, "struct {} {{", table_holder_type)
811            self.write(1, "fns: [fn(&mut Handler) -> Result<'alloc, bool>; {}]",
812                       self.action_from_shift_count)
813            self.write(0, "}")
814            self.write(0, "impl<'alloc, Handler> {}", table_holder_type)
815            self.write(0, "where")
816            self.write(1, "Handler: {}", traits_text)
817            self.write(0, "{")
818            self.write(1, "const TABLE : {} = {} {{", table_holder_type, table_holder_name)
819            self.write(2, "fns: [")
820            for state in self.states[start_at:end_at]:
821                assert state.arguments == 0
822                self.write(3, "{}_{},", mode, state.index)
823            self.write(2, "],")
824            self.write(1, "};")
825            self.write(0, "}")
826            self.write(0, "")
827            self.write(0,
828                       "pub fn {}<'alloc, Handler>(parser: &mut Handler, state: usize) "
829                       "-> Result<'alloc, bool>",
830                       mode)
831            self.write(0, "where")
832            self.write(1, "Handler: {}", traits_text)
833            self.write(0, "{")
834            self.write(1, "{}::<'alloc, Handler>::TABLE.fns[state - {}](parser)",
835                       table_holder_name, start_at)
836            self.write(0, "}")
837            self.write(0, "")
838            for state in self.states[self.shift_count:]:
839                state_args = ""
840                for i in range(state.arguments):
841                    state_args += ", v{}: {}".format(i, arg_type)
842                replay_args = ["v{}".format(i) for i in range(state.arguments)]
843                self.write(0, "#[inline]")
844                self.write(0, "#[allow(unused)]")
845                self.write(0,
846                           "pub fn {}_{}<'alloc, Handler>(parser: &mut Handler{}) "
847                           "-> Result<'alloc, bool>",
848                           mode, state.index, state_args)
849                self.write(0, "where")
850                self.write(1, "Handler: {}", ' + '.join(map(self.type_to_rust, traits)))
851                self.write(0, "{")
852                action_writer.write_state_transitions(state, replay_args)
853                self.write(0, "}")
854
855    def entry(self):
856        self.write(0, "#[derive(Clone, Copy)]")
857        self.write(0, "pub struct ParseTable<'a> {")
858        self.write(1, "pub shift_count: usize,")
859        self.write(1, "pub action_count: usize,")
860        self.write(1, "pub action_from_shift_count: usize,")
861        self.write(1, "pub shift_table: &'a [i64],")
862        self.write(1, "pub shift_width: usize,")
863        self.write(1, "pub error_codes: &'a [Option<ErrorCode>],")
864        self.write(0, "}")
865        self.write(0, "")
866
867        self.write(0, "impl<'a> ParseTable<'a> {")
868        self.write(1, "pub fn check(&self) {")
869        self.write(2, "assert_eq!(")
870        self.write(3, "self.shift_table.len(),")
871        self.write(3, "(self.shift_count * self.shift_width) as usize")
872        self.write(2, ");")
873        self.write(1, "}")
874        self.write(0, "}")
875        self.write(0, "")
876
877        self.write(0, "pub static TABLES: ParseTable<'static> = ParseTable {")
878        self.write(1, "shift_count: {},", self.shift_count)
879        self.write(1, "action_count: {},", self.action_count)
880        self.write(1, "action_from_shift_count: {},", self.action_from_shift_count)
881        self.write(1, "shift_table: &SHIFT,")
882        self.write(1, "shift_width: {},", len(self.terminals) + len(self.nonterminals))
883        self.write(1, "error_codes: &STATE_TO_ERROR_CODE,")
884        self.write(0, "};")
885        self.write(0, "")
886
887        for init_nt, index in self.init_state_map:
888            assert init_nt.args == ()
889            self.write(0, "pub static START_STATE_{}: usize = {};",
890                       self.nonterminal_to_snake(init_nt).upper(), index)
891            self.write(0, "")
892
893
def write_rust_parse_table(out, parse_table, handler_info):
    """
    Emit the Rust parser for `parse_table` to `out`.

    `handler_info` is passed to `open` and must name a JSON file with a
    "fallible-methods" entry, which is handed to `RustParserWriter`. When
    `handler_info` is falsy, a warning is printed on stderr and an empty
    list of fallible methods is used instead.
    """
    if not handler_info:
        print("WARNING: info.json is not provided", file=sys.stderr)
        fallible_methods = []
    else:
        # JSON text is UTF-8: read it as such instead of relying on the
        # default encoding of the platform.
        with open(handler_info, "r", encoding="utf-8") as json_file:
            handler_info_json = json.load(json_file)
        fallible_methods = handler_info_json["fallible-methods"]

    RustParserWriter(out, parse_table, fallible_methods).emit()
904