"""Emit code and parser tables in Rust."""

import json
import re
import unicodedata
import sys
import itertools
import collections
from contextlib import contextmanager

from ..runtime import (ERROR, ErrorToken, SPECIAL_CASE_TAG)
from ..ordered import OrderedSet

from ..grammar import (Some, Nt, InitNt, End, ErrorSymbol)
from ..actions import (Accept, Action, Replay, Unwind, Reduce, CheckNotOnNewLine, FilterStates,
                       PushFlag, PopFlag, FunCall, Seq)

from .. import types


# Maps the spelling of punctuation terminals to the identifier used for the
# corresponding `TerminalId` variant in the generated Rust code.
TERMINAL_NAMES = {
    '&&=': 'LogicalAndAssign',
    '||=': 'LogicalOrAssign',
    '??=': 'CoalesceAssign',
    '{': 'OpenBrace',
    '}': 'CloseBrace',
    '(': 'OpenParenthesis',
    ')': 'CloseParenthesis',
    '[': 'OpenBracket',
    ']': 'CloseBracket',
    '+': 'Plus',
    '-': 'Minus',
    '~': 'BitwiseNot',
    '!': 'LogicalNot',
    '++': 'Increment',
    '--': 'Decrement',
    ':': 'Colon',
    '=>': 'Arrow',
    '=': 'EqualSign',
    '*=': 'MultiplyAssign',
    '/=': 'DivideAssign',
    '%=': 'RemainderAssign',
    '+=': 'AddAssign',
    '-=': 'SubtractAssign',
    '<<=': 'LeftShiftAssign',
    '>>=': 'SignedRightShiftAssign',
    '>>>=': 'UnsignedRightShiftAssign',
    '&=': 'BitwiseAndAssign',
    '^=': 'BitwiseXorAssign',
    '|=': 'BitwiseOrAssign',
    '**=': 'ExponentiateAssign',
    '.': 'Dot',
    '**': 'Exponentiate',
    '?.': 'OptionalChain',
    '?': 'QuestionMark',
    '??': 'Coalesce',
    '*': 'Star',
    '/': 'Divide',
    '%': 'Remainder',
    '<<': 'LeftShift',
    '>>': 'SignedRightShift',
    '>>>': 'UnsignedRightShift',
    '<': 'LessThan',
    '>': 'GreaterThan',
    '<=': 'LessThanOrEqualTo',
    '>=': 'GreaterThanOrEqualTo',
    '==': 'LaxEqual',
    '!=': 'LaxNotEqual',
    '===': 'StrictEqual',
    '!==': 'StrictNotEqual',
    '&': 'BitwiseAnd',
    '^': 'BitwiseXor',
    '|': 'BitwiseOr',
    '&&': 'LogicalAnd',
    '||': 'LogicalOr',
    ',': 'Comma',
    '...': 'Ellipsis',
}


@contextmanager
def indent(writer):
    """This function is meant to be used with the `with` keyword of python, and
    allow the user of it to add an indentation level to the code which is
    enclosed in the `with` statement.

    This has the advantage that the indentation of the python code is reflected
    to the generated code when `with indent(self):` is used.

    The indentation level of `writer` is restored when the `with` block exits,
    including when the block raises an exception."""
    writer.indent += 1
    try:
        yield None
    finally:
        # Restore the level even on error, so that a writer which is reused
        # after a failure does not keep a stale indentation.
        writer.indent -= 1


def extract_ranges(iterator):
    """Given a sorted iterator of integer, yield the contiguous ranges

    Each yielded item is a `(first, last)` tuple, with both bounds included."""
    # Identify contiguous ranges of states.
    ranges = collections.defaultdict(list)
    # A sorted list of contiguous integers implies that elements are separated
    # by 1, as well as their indexes. Thus we can categorize them into buckets
    # of contiguous integers using the base, which is the value v from which we
    # remove the index i.
    for i, v in enumerate(iterator):
        ranges[v - i].append(v)
    for bucket in ranges.values():
        yield (bucket[0], bucket[-1])


def rust_range(riter):
    """Prettify a list of tuple of (min, max) of matched ranges into Rust
    syntax.

    A range reduced to a single value is rendered as that value, other ranges
    are rendered as `min..=max`, and all of them are joined with ` | `."""
    def minmax_join(rmin, rmax):
        if rmin == rmax:
            return str(rmin)
        else:
            return "{}..={}".format(rmin, rmax)
    return " | ".join(minmax_join(rmin, rmax) for rmin, rmax in riter)


class RustActionWriter:
    """Write epsilon state transitions for a given action function."""
    ast_builder = types.Type("AstBuilderDelegate", (types.Lifetime("alloc"),))

    def __init__(self, writer, mode, traits, indent):
        """Create an action writer bound to a RustParserWriter.

        `writer` is the RustParserWriter used for output, `mode` is the prefix
        of the generated function names, `traits` is the collection of traits
        used to filter which function calls are encoded, and `indent` is the
        initial indentation level of the generated code."""
        self.states = writer.states
        self.writer = writer
        self.mode = mode
        self.traits = traits
        self.indent = indent
        self.has_ast_builder = self.ast_builder in traits
        # Variables read by the action currently being written, see `reset`.
        self.used_variables = set()
        # Names of the Rust variables holding replayed terms which are given
        # as arguments of the function currently being written.
        self.replay_args = []

    def implement_trait(self, funcall):
        "Returns True if this function call should be encoded"
        ty = funcall.trait
        if ty.name == "AstBuilder":
            return "AstBuilderDelegate<'alloc>" in map(str, self.traits)
        if ty in self.traits:
            return True
        if len(ty.args) == 0:
            return ty.name in map(lambda t: t.name, self.traits)
        return False

    def reset(self, act):
        "Traverse all action to collect preliminary information."
        self.used_variables = set(self.collect_uses(act))

    def collect_uses(self, act):
        "Generator which visit all used variables."
        assert isinstance(act, Action)
        if isinstance(act, (Reduce, Unwind)):
            yield "value"
        elif isinstance(act, FunCall):
            arg_offset = act.offset
            if arg_offset < 0:
                # See write_funcall.
                arg_offset = 0

            def map_with_offset(args):
                # Integers are stack offsets, strings are variable names, and
                # `Some` wraps either of them. Anything else is not a use.
                for a in args:
                    if isinstance(a, int):
                        yield a + arg_offset
                    elif isinstance(a, str):
                        yield a
                    elif isinstance(a, Some):
                        for offset in map_with_offset([a.inner]):
                            yield offset

            if self.implement_trait(act):
                for var in map_with_offset(act.args):
                    yield var
        elif isinstance(act, Seq):
            for a in act.actions:
                for var in self.collect_uses(a):
                    yield var

    def write(self, string, *format_args):
        "Delegate to the RustParserWriter.write function"
        self.writer.write(self.indent, string, *format_args)

    def write_state_transitions(self, state, replay_args):
        "Given a state, generate the code corresponding to all outgoing epsilon edges."
        try:
            self.replay_args = replay_args
            assert not state.is_inconsistent()
            assert len(list(state.shifted_edges())) == 0
            for ctx in self.writer.parse_table.debug_context(state.index, None):
                self.write("// {}", ctx)
            first, dest = next(state.edges(), (None, None))
            if first is None:
                return
            self.reset(first)
            if first.is_condition():
                self.write_condition(state, first)
            else:
                assert len(list(state.edges())) == 1
                self.write_action(first, dest)
        except Exception:
            # Report which state failed, with its debug context, before
            # forwarding the original exception to the caller.
            print("Error while writing code for {}\n\n".format(state))
            self.writer.parse_table.debug_info = True
            print(self.writer.parse_table.debug_context(state.index, "\n", "# "))
            raise

    def write_replay_args(self, n):
        """Keep the first `n` replay arguments and emit a `parser.replay` call
        for each of the remaining ones. Returns the arguments which are kept."""
        rp_args = self.replay_args[:n]
        rp_stck = self.replay_args[n:]
        for tv in rp_stck:
            self.write("parser.replay({});", tv)
        return rp_args

    def write_epsilon_transition(self, dest):
        """Write the code which transfers the execution to the state which has
        the index `dest`."""
        # Replay arguments which are not accepted as input of the next state.
        dest = self.states[dest]
        rp_args = self.write_replay_args(dest.arguments)
        self.write("// --> {}", dest.index)
        if dest.index >= self.writer.shift_count:
            # The destination is an action state: tail-call its function.
            self.write("{}_{}(parser{})", self.mode, dest.index, "".join(map(lambda v: ", " + v, rp_args)))
        else:
            assert dest.arguments == 0
            self.write("parser.epsilon({});", dest.index)
            self.write("Ok(false)")

    def write_condition(self, state, first_act):
        "Write code to test a conditions, and dispatch to the matching destination"
        # NOTE: we already asserted that this state is consistent, this implies
        # that the first state check the same variables as all remaining
        # states. Thus we use the first action to produce the match statement.
        assert isinstance(first_act, Action)
        assert first_act.is_condition()
        if isinstance(first_act, CheckNotOnNewLine):
            # TODO: At the moment this is Action is implemented as a single
            # operation with a single destination. However, we should implement
            # it in the future as 2 branches, one which is verifying the lack
            # of new lines, and one which is shifting an extra error token.
            # This might help remove the overhead of backtracking in addition
            # to make this backtracking visible through APS.
            assert len(list(state.edges())) == 1
            act, dest = next(state.edges())
            assert len(self.replay_args) == 0
            assert -act.offset > 0
            self.write("// {}", str(act))
            self.write("if !parser.check_not_on_new_line({})? {{", -act.offset)
            with indent(self):
                self.write("return Ok(false);")
            self.write("}")
            self.write_epsilon_transition(dest)
        elif isinstance(first_act, FilterStates):
            if len(state.epsilon) == 1:
                # This is an attempt to avoid huge unending compilations.
                _, dest = next(iter(state.epsilon), (None, None))
                pattern = rust_range(extract_ranges(first_act.states))
                self.write("// parser.top_state() in ({})", pattern)
                self.write_epsilon_transition(dest)
            else:
                self.write("match parser.top_state() {")
                with indent(self):
                    # Consider the branch which has the largest number of
                    # potential top-states to be most likely, and therefore the
                    # default branch to go to if all other fail to match.
                    default_weight = max(len(act.states) for act, dest in state.edges())
                    default_states = []
                    default_dest = None
                    for act, dest in state.edges():
                        assert first_act.check_same_variable(act)
                        if default_dest is None and default_weight == len(act.states):
                            # This range has the same weight as the default
                            # branch. Ignore it and use it as the default
                            # branch which would be generated at the end.
                            default_states = act.states
                            default_dest = dest
                            continue
                        pattern = rust_range(extract_ranges(act.states))
                        self.write("{} => {{", pattern)
                        with indent(self):
                            self.write_epsilon_transition(dest)
                        self.write("}")
                    # Generate code for the default branch, which got skipped
                    # while producing the loop.
                    self.write("_ => {")
                    with indent(self):
                        pattern = rust_range(extract_ranges(default_states))
                        self.write("// {}", pattern)
                        self.write_epsilon_transition(default_dest)
                    self.write("}")
                self.write("}")
        else:
            raise ValueError("Unexpected action type")

    def write_action(self, act, dest):
        """Write the code of a non-conditional action, followed by the
        transition to the state `dest` when the action falls through."""
        assert isinstance(act, Action)
        assert not act.is_condition()
        # Maps a variable name to whether it holds a packed TermValue.
        is_packed = {}

        # Do not pop any of the stack elements if the reduce action has an
        # accept function call. Ideally we should be returning the result
        # instead of keeping it on the parser stack.
        if act.update_stack() and not act.contains_accept():
            stack_diff = act.update_stack_with()
            start = 0
            depth = stack_diff.pop
            args = len(self.replay_args)
            replay = stack_diff.replay
            if replay < 0:
                # At the moment, we do not handle having more arguments than
                # what is being popped and replay, thus write back the extra
                # arguments and continue.
                if stack_diff.pop + replay < 0:
                    self.replay_args = self.write_replay_args(replay)
                replay = 0
            if replay + stack_diff.pop - args > 0:
                assert (replay >= 0 and args == 0) or \
                    (replay == 0 and args >= 0)
            if replay > 0:
                # At the moment, assume that arguments are only added once we
                # consumed all replayed terms. Thus the replay_args can only be
                # non-empty once replay is 0. Otherwise some of the replay_args
                # would have to be replayed.
                assert args == 0
                self.write("parser.rewind({});", replay)
                start = replay
                depth += start

            inputs = []
            for i in range(start, depth):
                name = 's{}'.format(i + 1)
                if i + 1 not in self.used_variables:
                    # Prefix with an underscore the bindings which are not
                    # read by the action.
                    name = '_' + name
                inputs.append(name)
            if stack_diff.pop > 0:
                args_pop = min(len(self.replay_args), stack_diff.pop)
                # Pop by moving arguments of the action function.
                for i, name in enumerate(inputs[:args_pop]):
                    self.write("let {} = {};", name, self.replay_args[-i - 1])
                # Pop by removing elements from the parser stack.
                for name in inputs[args_pop:]:
                    self.write("let {} = parser.pop();", name)
                if args_pop > 0:
                    del self.replay_args[-args_pop:]

        if isinstance(act, Seq):
            for a in act.actions:
                self.write_single_action(a, is_packed)
                if a.contains_accept():
                    break
        else:
            self.write_single_action(act, is_packed)

        # If we fallthrough the execution of the action, then generate an
        # epsilon transition.
        if act.follow_edge() and not act.contains_accept():
            assert 0 <= dest < self.writer.shift_count + self.writer.action_count
            self.write_epsilon_transition(dest)

    def write_single_action(self, act, is_packed):
        """Write the code of one action which is not a sequence, dispatching
        on its type. Raises ValueError for unsupported actions."""
        self.write("// {}", str(act))
        if isinstance(act, Replay):
            self.write_replay(act)
        elif isinstance(act, (Reduce, Unwind)):
            self.write_reduce(act, is_packed)
        elif isinstance(act, Accept):
            self.write_accept()
        elif isinstance(act, PushFlag):
            raise ValueError("NYI: PushFlag action")
        elif isinstance(act, PopFlag):
            raise ValueError("NYI: PopFlag action")
        elif isinstance(act, FunCall):
            self.write_funcall(act, is_packed)
        else:
            raise ValueError("Unexpected action type")

    def write_replay(self, act):
        "Write one `parser.shift_replayed` call per replayed step."
        assert len(self.replay_args) == 0
        for shift_state in act.replay_steps:
            self.write("parser.shift_replayed({});", shift_state)

    def write_reduce(self, act, is_packed):
        """Write the code which builds the `reduced` TermValue, and register
        it as a replay argument."""
        value = "value"
        if value in is_packed:
            packed = is_packed[value]
        else:
            packed = False
            value = "None"

        if packed:
            # Extract the StackValue from the packed TermValue
            value = "{}.value".format(value)
        elif self.has_ast_builder:
            # Convert into a StackValue
            value = "TryIntoStack::try_into_stack({})?".format(value)
        else:
            # Convert into a StackValue (when no ast-builder)
            value = "value"

        stack_diff = act.update_stack_with()
        assert stack_diff.nt is not None
        self.write("let term = NonterminalId::{}.into();",
                   self.writer.nonterminal_to_camel(stack_diff.nt))
        # Only rebind `value` when a conversion is needed.
        if value != "value":
            self.write("let value = {};", value)
        self.write("let reduced = TermValue { term, value };")
        self.replay_args.append("reduced")

    def write_accept(self):
        "Write the code which returns from the action function on accept."
        self.write("return Ok(true);")

    def write_funcall(self, act, is_packed):
        """Write the code of a function call action, and record in `is_packed`
        whether the variable which is set holds a packed value."""
        arg_offset = act.offset
        if arg_offset < 0:
            # NOTE: When replacing replayed stack elements by arguments, the
            # offset is reduced by -1, and can become negative for cases where
            # we read the value associated with an argument instead of the
            # value read from the stack. However, write_action shift everything
            # as-if we had replayed all the necessary terms, and therefore
            # variables are named as-if the offset were 0.
            arg_offset = 0

        def no_unpack(val):
            return val

        def unpack(val):
            # Variables which are not recorded in is_packed are considered as
            # being packed.
            if val in is_packed:
                packed = is_packed[val]
            else:
                packed = True
            if packed:
                return "{}.value.to_ast()?".format(val)
            return val

        def map_with_offset(args, unpack):
            get_value = "s{}"
            for a in args:
                if isinstance(a, int):
                    yield unpack(get_value.format(a + arg_offset))
                elif isinstance(a, str):
                    yield unpack(a)
                elif isinstance(a, Some):
                    yield "Some({})".format(next(map_with_offset([a.inner], unpack)))
                elif a is None:
                    yield "None"
                else:
                    raise ValueError(a)

        # If the variable is used, then generate the let binding.
        set_var = ""
        if act.set_to in self.used_variables:
            set_var = "let {} = ".format(act.set_to)

        # If the function cannot be call as the generated action function does
        # not use the trait on which this function is implemented, then replace
        # the value by `()`.
        if not self.implement_trait(act):
            self.write("{}();", set_var)
            return

        # NOTE: Currently "AstBuilder" is implemented through the
        # AstBuilderDelegate which returns a mutable reference to the
        # AstBuilder. This would call the specific special case method to get
        # the actual AstBuilder.
        delegate = ""
        if str(act.trait) == "AstBuilder":
            delegate = "ast_builder_refmut()."

        # NOTE: Currently "AstBuilder" functions are made fallible
        # using the fallible_methods taken from some Rust code
        # which extract this information to produce a JSON file.
        forward_errors = ""
        if act.fallible or act.method in self.writer.fallible_methods:
            forward_errors = "?"

        # By default generate a method call, with the method name. However,
        # there is a special case for the "id" function which is an artifact,
        # which does not have to unpack the content of its argument.
        value = "parser.{}{}({})".format(
            delegate, act.method,
            ", ".join(map_with_offset(act.args, unpack)))
        packed = False
        if act.method == "id":
            assert len(act.args) == 1
            value = next(map_with_offset(act.args, no_unpack))
            if isinstance(act.args[0], str):
                packed = is_packed[act.args[0]]
            else:
                assert isinstance(act.args[0], int)
                packed = True

        self.write("{}{}{};", set_var, value, forward_errors)
        is_packed[act.set_to] = packed


class RustParserWriter:
    """Write the Rust source which contains the parse tables and the action
    functions of a parse table."""

    def __init__(self, out, pt, fallible_methods):
        """`out` is the object on which the generated code is written, `pt`
        is the parse table, and `fallible_methods` is the collection of method
        names whose calls are followed by the `?` operator."""
        self.out = out
        self.fallible_methods = fallible_methods
        assert pt.exec_modes is not None
        self.parse_table = pt
        self.states = pt.states
        self.shift_count = pt.count_shift_states()
        self.action_count = pt.count_action_states()
        self.action_from_shift_count = pt.count_action_from_shift_states()
        self.init_state_map = pt.named_goals
        self.terminals = list(OrderedSet(pt.terminals))
        # This extra terminal is used to represent any ErrorSymbol transition,
        # knowing that we assert that there is only one ErrorSymbol kind per
        # state.
        self.terminals.append("ErrorToken")
        self.nonterminals = list(OrderedSet(pt.nonterminals))

    def emit(self):
        "Write all the sections of the generated file, in order."
        self.header()
        self.terms_id()
        self.shift()
        self.error_codes()
        self.check_camel_case()
        self.actions()
        self.entry()

    def write(self, indentation, string, *format_args):
        """Write one line of generated code at the given indentation level.

        `string` is formatted with `str.format` only when `format_args` are
        given. Otherwise it is written verbatim, which is why callers do not
        double the braces of strings written without arguments."""
        if len(format_args) == 0:
            formatted = string
        else:
            formatted = string.format(*format_args)
        # NOTE(review): the indentation unit is a single space per level as
        # written here; confirm this matches the intended generated style.
        self.out.write(" " * indentation + formatted + "\n")

    def header(self):
        "Write the warning comment, the `use` declarations and the ERROR constant."
        self.write(0, "// WARNING: This file is autogenerated.")
        self.write(0, "")
        self.write(0, "use crate::ast_builder::AstBuilderDelegate;")
        self.write(0, "use crate::stack_value_generated::{StackValue, TryIntoStack};")
        self.write(0, "use crate::traits::{TermValue, ParserTrait};")
        self.write(0, "use crate::error::Result;")
        traits = OrderedSet()
        for mode_traits in self.parse_table.exec_modes.values():
            traits |= mode_traits
        # These 2 traits are already imported by the lines written above.
        traits = [ty for ty in traits
                  if ty.name != "AstBuilderDelegate" and ty.name != "ParserTrait"]
        if len(traits) == 1:
            self.write(0, "use crate::traits::{};", traits[0].name)
        elif traits:
            self.write(0, "use crate::traits::{{{}}};", ", ".join(ty.name for ty in traits))
        self.write(0, "")
        self.write(0, "const ERROR: i64 = {};", hex(ERROR))
        self.write(0, "")

    def terminal_name(self, value):
        "Returns the name of the `TerminalId` variant used for a terminal."
        if isinstance(value, End) or value is None:
            return "End"
        elif isinstance(value, ErrorSymbol) or value is ErrorToken:
            return "ErrorToken"
        elif value in TERMINAL_NAMES:
            return TERMINAL_NAMES[value]
        elif value.isalpha():
            if value.islower():
                return value.capitalize()
            else:
                return value
        else:
            # Build a name out of the unicode names of each character.
            raw_name = " ".join((unicodedata.name(c) for c in value))
            snake_case = raw_name.replace("-", " ").replace(" ", "_").lower()
            camel_case = self.to_camel_case(snake_case)
            return camel_case

    def terminal_name_camel(self, value):
        "Returns the camel-case spelling of the name of a terminal."
        return self.to_camel_case(self.terminal_name(value))

    def terms_id(self):
        """Write the `TerminalId` and `NonterminalId` enums, as well as the
        `Term` type and its conversions.

        Terminals are numbered from 0, and nonterminals are numbered after
        the last terminal."""
        self.write(0, "#[derive(Copy, Clone, Debug, PartialEq)]")
        self.write(0, "#[repr(u32)]")
        self.write(0, "pub enum TerminalId {")
        for i, t in enumerate(self.terminals):
            name = self.terminal_name(t)
            self.write(1, "{} = {}, // {}", name, i, repr(t))
        self.write(0, "}")
        self.write(0, "")
        self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]")
        self.write(0, "#[repr(u32)]")
        self.write(0, "pub enum NonterminalId {")
        offset = len(self.terminals)
        for i, nt in enumerate(self.nonterminals):
            self.write(1, "{} = {},", self.nonterminal_to_camel(nt), i + offset)
        self.write(0, "}")
        self.write(0, "")
        self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]")
        self.write(0, "pub struct Term(u32);")
        self.write(0, "")
        self.write(0, "impl Term {")
        self.write(1, "pub fn is_terminal(&self) -> bool {")
        self.write(2, "self.0 < {}", offset)
        self.write(1, "}")
        self.write(1, "pub fn to_terminal(&self) -> TerminalId {")
        self.write(2, "assert!(self.is_terminal());")
        self.write(2, "unsafe { std::mem::transmute(self.0) }")
        self.write(1, "}")
        self.write(0, "}")
        self.write(0, "")
        self.write(0, "impl From<TerminalId> for Term {")
        self.write(1, "fn from(t: TerminalId) -> Self {")
        self.write(2, "Term(t as _)")
        self.write(1, "}")
        self.write(0, "}")
        self.write(0, "")
        self.write(0, "impl From<NonterminalId> for Term {")
        self.write(1, "fn from(nt: NonterminalId) -> Self {")
        self.write(2, "Term(nt as _)")
        self.write(1, "}")
        self.write(0, "}")
        self.write(0, "")
        self.write(0, "impl From<Term> for usize {")
        self.write(1, "fn from(term: Term) -> Self {")
        self.write(2, "term.0 as _")
        self.write(1, "}")
        self.write(0, "}")
        self.write(0, "")
        self.write(0, "impl From<Term> for &'static str {")
        self.write(1, "fn from(term: Term) -> Self {")
        self.write(2, "match term.0 {")
        for i, t in enumerate(self.terminals):
            self.write(3, "{} => &\"{}\",", i, repr(t))
        for j, nt in enumerate(self.nonterminals):
            self.write(3, "{} => &\"{}\",", j + offset, str(nt.name))
        # This line takes no format argument: it must not depend on the
        # variables of the loops above, which are unbound when there is no
        # nonterminal.
        self.write(3, "_ => panic!(\"unknown Term\")")
        self.write(2, "}")
        self.write(1, "}")
        self.write(0, "}")
        self.write(0, "")

    def shift(self):
        """Write the `SHIFT` table, which has one row per shift state and one
        column per terminal and nonterminal."""
        self.write(0, "#[rustfmt::skip]")
        width = len(self.terminals) + len(self.nonterminals)
        # Number of edges encoded in the row being written, used to check that
        # all shifted edges of a state are encoded.
        num_shifted_edges = 0

        def state_get(state, t):
            nonlocal num_shifted_edges
            res = state.get(t, "ERROR")
            if res == "ERROR":
                # The ErrorToken column holds the destination of the error
                # symbol of the state, if any.
                error_symbol = state.get_error_symbol()
                if t == "ErrorToken" and error_symbol:
                    res = state[error_symbol]
                    num_shifted_edges += 1
            else:
                num_shifted_edges += 1
            return res

        self.write(0, "static SHIFT: [i64; {}] = [", self.shift_count * width)
        assert self.terminals[-1] == "ErrorToken"
        for i, state in enumerate(self.states[:self.shift_count]):
            num_shifted_edges = 0
            self.write(1, "// {}.", i)
            for ctx in self.parse_table.debug_context(state.index, None):
                self.write(1, "// {}", ctx)
            self.write(1, "{}",
                       ' '.join("{},".format(state_get(state, t)) for t in self.terminals))
            self.write(1, "{}",
                       ' '.join("{},".format(state_get(state, t)) for t in self.nonterminals))
            try:
                assert sum(1 for _ in state.shifted_edges()) == num_shifted_edges
            except Exception:
                print("Some edges are not encoded.")
                print("List of terminals: {}".format(', '.join(map(repr, self.terminals))))
                print("List of nonterminals: {}".format(', '.join(map(repr, self.nonterminals))))
                print("State having the issue: {}".format(str(state)))
                raise
        self.write(0, "];")
        self.write(0, "")

    def render_action(self, action):
        """Returns the table entry for an action, which is either an integer,
        the 'ERROR' string, or an ('IfSameLine', a1, a2) tuple.

        NOTE(review): this relies on `self.add_special_case`, which is not
        defined in the visible part of this class, and `emit` does not call
        this method. Presumably unused legacy code, to be confirmed."""
        if isinstance(action, tuple):
            if action[0] == 'IfSameLine':
                _, a1, a2 = action
                if a1 is None:
                    a1 = 'ERROR'
                if a2 is None:
                    a2 = 'ERROR'
                index = self.add_special_case(
                    "if token.is_on_new_line { %s } else { %s }"
                    % (a2, a1))
            else:
                raise ValueError("unrecognized kind of special case: {!r}".format(action))
            return SPECIAL_CASE_TAG + index
        elif action == 'ERROR':
            return action
        else:
            assert isinstance(action, int)
            return action

    def emit_special_cases(self):
        """Write the `SPECIAL_CASES` table.

        NOTE(review): this relies on `self.special_cases`, which is not set in
        the visible part of this class, and `emit` does not call this method.
        Presumably unused legacy code, to be confirmed."""
        self.write(0, "static SPECIAL_CASES: [fn(&Token) -> i64; {}] = [",
                   len(self.special_cases))
        for code in self.special_cases:
            self.write(1, "|token| {{ {} }},", code)
        self.write(0, "];")
        self.write(0, "")

    def error_codes(self):
        """Write the `ErrorCode` enum and the `STATE_TO_ERROR_CODE` table,
        which has one entry per shift state."""
        self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]")
        self.write(0, "pub enum ErrorCode {")
        error_symbols = (s.get_error_symbol() for s in self.states[:self.shift_count])
        error_codes = (e.error_code for e in error_symbols if e is not None)
        for error_code in OrderedSet(error_codes):
            self.write(1, "{},", self.to_camel_case(error_code))
        self.write(0, "}")
        self.write(0, "")

        self.write(0, "static STATE_TO_ERROR_CODE: [Option<ErrorCode>; {}] = [",
                   self.shift_count)
        for i, state in enumerate(self.states[:self.shift_count]):
            error_symbol = state.get_error_symbol()
            if error_symbol is None:
                self.write(1, "None,")
            else:
                self.write(1, "// {}.", i)
                for ctx in self.parse_table.debug_context(state.index, None):
                    self.write(1, "// {}", ctx)
                self.write(1, "Some(ErrorCode::{}),",
                           self.to_camel_case(error_symbol.error_code))
        self.write(0, "];")
        self.write(0, "")

    def nonterminal_to_snake(self, ident):
        """Returns the snake-case name of a nonterminal, given as an Nt or as
        a string. For an Nt, the names of the arguments which have a truthy
        value are appended to the name."""
        if isinstance(ident, Nt):
            if isinstance(ident.name, InitNt):
                name = "Start" + ident.name.goal.name
            else:
                name = ident.name
            base_name = self.to_snek_case(name)
            args = ''.join((("_" + self.to_snek_case(name))
                            for name, value in ident.args if value))
            return base_name + args
        else:
            assert isinstance(ident, str)
            return self.to_snek_case(ident)

    def nonterminal_to_camel(self, nt):
        "Returns the camel-case name of a nonterminal."
        return self.to_camel_case(self.nonterminal_to_snake(nt))

    def to_camel_case(self, ident):
        """Convert a snake-case identifier to camel-case. An identifier which
        has neither underscores nor only lower case letters is returned
        unchanged."""
        if '_' in ident:
            return ''.join(word.capitalize() for word in ident.split('_'))
        elif ident.islower():
            return ident.capitalize()
        else:
            return ident

    def check_camel_case(self):
        """Raise a ValueError if 2 nonterminals have the same camel-case
        spelling, as they would map to the same `NonterminalId` variant."""
        seen = {}
        for nt in self.nonterminals:
            cc = self.nonterminal_to_camel(nt)
            if cc in seen:
                raise ValueError("{} and {} have the same camel-case spelling ({})".format(
                    seen[cc], nt, cc))
            seen[cc] = nt

    def to_snek_case(self, ident):
        "Convert a camel-case identifier to snake-case."
        # https://stackoverflow.com/questions/1175208
        s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', ident)
        return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

    def type_to_rust(self, ty, namespace="", boxed=False):
        """
        Convert a jsparagus type (see types.py) to Rust.

        Pass boxed=True if the type needs to be boxed.
        """
        if isinstance(ty, types.Lifetime):
            assert not boxed
            rty = "'" + ty.name
        elif ty == types.UnitType:
            assert not boxed
            rty = '()'
        elif ty == types.TokenType:
            rty = "Token"
        elif ty.name == 'Option' and len(ty.args) == 1:
            # We auto-translate `Box<Option<T>>` to `Option<Box<T>>` since
            # that's basically the same thing but more efficient.
            [arg] = ty.args
            return 'Option<{}>'.format(self.type_to_rust(arg, namespace, boxed))
        elif ty.name == 'Vec' and len(ty.args) == 1:
            [arg] = ty.args
            rty = "Vec<'alloc, {}>".format(self.type_to_rust(arg, namespace, boxed=False))
        else:
            if namespace == "":
                rty = ty.name
            else:
                rty = namespace + '::' + ty.name
            if ty.args:
                rty += '<{}>'.format(', '.join(self.type_to_rust(arg, namespace, boxed)
                                               for arg in ty.args))
        if boxed:
            return "Box<'alloc, {}>".format(rty)
        else:
            return rty

    def actions(self):
        """Write, for each execution mode, the table of action functions, its
        dispatch function, and one function per action state."""
        # For each execution mode, add a corresponding function which
        # implements various traits. The trait list is used for filtering which
        # function is added in the generated code.
        for mode, traits in self.parse_table.exec_modes.items():
            action_writer = RustActionWriter(self, mode, traits, 2)
            start_at = self.shift_count
            end_at = start_at + self.action_from_shift_count
            assert len(self.states[self.shift_count:]) == self.action_count
            traits_text = ' + '.join(map(self.type_to_rust, traits))
            table_holder_name = self.to_camel_case(mode)
            table_holder_type = table_holder_name + "<'alloc, Handler>"
            # As we do not have default associated types yet in Rust
            # (rust-lang#29661), we have to peek from the parameter of the
            # ParserTrait.
            assert list(traits)[0].name == "ParserTrait"
            arg_type = "TermValue<" + self.type_to_rust(list(traits)[0].args[1]) + ">"
            self.write(0, "struct {} {{", table_holder_type)
            self.write(1, "fns: [fn(&mut Handler) -> Result<'alloc, bool>; {}]",
                       self.action_from_shift_count)
            self.write(0, "}")
            self.write(0, "impl<'alloc, Handler> {}", table_holder_type)
            self.write(0, "where")
            self.write(1, "Handler: {}", traits_text)
            self.write(0, "{")
            self.write(1, "const TABLE : {} = {} {{", table_holder_type, table_holder_name)
            self.write(2, "fns: [")
            for state in self.states[start_at:end_at]:
                assert state.arguments == 0
                self.write(3, "{}_{},", mode, state.index)
            self.write(2, "],")
            self.write(1, "};")
            self.write(0, "}")
            self.write(0, "")
            self.write(0,
                       "pub fn {}<'alloc, Handler>(parser: &mut Handler, state: usize) "
                       "-> Result<'alloc, bool>",
                       mode)
            self.write(0, "where")
            self.write(1, "Handler: {}", traits_text)
            self.write(0, "{")
            self.write(1, "{}::<'alloc, Handler>::TABLE.fns[state - {}](parser)",
                       table_holder_name, start_at)
            self.write(0, "}")
            self.write(0, "")
            for state in self.states[self.shift_count:]:
                # Each state argument is a replayed term given as a parameter
                # of the generated function.
                state_args = "".join(", v{}: {}".format(i, arg_type)
                                     for i in range(state.arguments))
                replay_args = ["v{}".format(i) for i in range(state.arguments)]
                self.write(0, "#[inline]")
                self.write(0, "#[allow(unused)]")
                self.write(0,
                           "pub fn {}_{}<'alloc, Handler>(parser: &mut Handler{}) "
                           "-> Result<'alloc, bool>",
                           mode, state.index, state_args)
                self.write(0, "where")
                self.write(1, "Handler: {}", traits_text)
                self.write(0, "{")
                action_writer.write_state_transitions(state, replay_args)
                self.write(0, "}")

    def entry(self):
        """Write the `ParseTable` structure, the `TABLES` static and one
        `START_STATE_*` static per named goal."""
        self.write(0, "#[derive(Clone, Copy)]")
        self.write(0, "pub struct ParseTable<'a> {")
        self.write(1, "pub shift_count: usize,")
        self.write(1, "pub action_count: usize,")
        self.write(1, "pub action_from_shift_count: usize,")
        self.write(1, "pub shift_table: &'a [i64],")
        self.write(1, "pub shift_width: usize,")
        self.write(1, "pub error_codes: &'a [Option<ErrorCode>],")
        self.write(0, "}")
        self.write(0, "")

        self.write(0, "impl<'a> ParseTable<'a> {")
        self.write(1, "pub fn check(&self) {")
        self.write(2, "assert_eq!(")
        self.write(3, "self.shift_table.len(),")
        self.write(3, "(self.shift_count * self.shift_width) as usize")
        self.write(2, ");")
        self.write(1, "}")
        self.write(0, "}")
        self.write(0, "")

        self.write(0, "pub static TABLES: ParseTable<'static> = ParseTable {")
        self.write(1, "shift_count: {},", self.shift_count)
        self.write(1, "action_count: {},", self.action_count)
        self.write(1, "action_from_shift_count: {},", self.action_from_shift_count)
        self.write(1, "shift_table: &SHIFT,")
        self.write(1, "shift_width: {},", len(self.terminals) + len(self.nonterminals))
        self.write(1, "error_codes: &STATE_TO_ERROR_CODE,")
        self.write(0, "};")
        self.write(0, "")

        for init_nt, index in self.init_state_map:
            assert init_nt.args == ()
            self.write(0, "pub static START_STATE_{}: usize = {};",
                       self.nonterminal_to_snake(init_nt).upper(), index)
            self.write(0, "")


def write_rust_parse_table(out, parse_table, handler_info):
    """Write the Rust code of `parse_table` on `out`.

    `handler_info` is the path of a JSON file which has a "fallible-methods"
    entry. When it is falsy, a warning is printed on stderr and no method is
    considered as fallible."""
    if not handler_info:
        print("WARNING: info.json is not provided", file=sys.stderr)
        fallible_methods = []
    else:
        # JSON files are expected to be UTF-8 encoded, do not depend on the
        # default encoding of the platform.
        with open(handler_info, "r", encoding="utf-8") as json_file:
            handler_info_json = json.load(json_file)
        fallible_methods = handler_info_json["fallible-methods"]

    RustParserWriter(out, parse_table, fallible_methods).emit()