"""Data structures extracted from parsing the EDSL, which are added within the
Rust code."""

from __future__ import annotations
# mypy: disallow-untyped-defs, disallow-incomplete-defs, disallow-untyped-calls

import typing
import os

from dataclasses import dataclass
from .utils import keep_until
from .grammar import Element, Grammar, LenientNt, NtDef, Production


@dataclass(frozen=True)
class ImplFor:
    """Immutable record of three strings: `param`, `trait` and `for_type`."""
    __slots__ = ['param', 'trait', 'for_type']
    param: str
    trait: str
    for_type: str


def eq_productions(grammar: Grammar, prod1: Production, prod2: Production) -> bool:
    """Return True if both productions have the same sequence of shifted
    elements (as reported by `grammar.is_shifted_element`), ignoring every
    other element of their bodies."""
    s1 = tuple(e for e in prod1.body if grammar.is_shifted_element(e))
    s2 = tuple(e for e in prod2.body if grammar.is_shifted_element(e))
    return s1 == s2


def merge_productions(grammar: Grammar, prod1: Production, prod2: Production) -> Production:
    """Merge the bodies of two productions with the same shifted elements.

    Shifted elements are considered as non-moveable anchors, and the other
    (non-shifted) elements are inserted around them. For each anchor, the
    non-shifted elements which precede it are taken from whichever
    production provides some. The non-shifted elements which follow the last
    anchor are kept as well.

    Returns a copy of `prod1` with the merged body.

    Raises ValueError when both productions provide different non-shifted
    elements at the same location, as no ordering can be decided.
    """
    # Consider all shifted elements as non-moveable elements, and insert other
    # around these.
    assert eq_productions(grammar, prod1, prod2)
    l1 = list(prod1.body)
    l2 = list(prod2.body)
    body: typing.List[Element] = []
    while l1 != [] and l2 != []:
        # Each front is the run of elements up to and including the next
        # shifted element, or the whole remaining tail if none is left.
        front1 = list(keep_until(l1, grammar.is_shifted_element))
        front2 = list(keep_until(l2, grammar.is_shifted_element))
        # As both productions have the same shifted elements, fronts ending
        # with a shifted element always agree; only trailing non-shifted
        # elements can differ here, which we cannot order.
        if front1[-1] != front2[-1]:
            raise ValueError("We do not know how to sort operations yet.")
        l1 = l1[len(front1):]
        l2 = l2[len(front2):]
        if len(front1) == 1:
            body.extend(front2)
        elif len(front2) == 1:
            body.extend(front1)
        else:
            raise ValueError("We do not know how to sort operations yet.")
    # One body is exhausted, while the other might still hold trailing
    # non-shifted elements (it cannot hold shifted ones, as both productions
    # have the same shifted elements). Keep them instead of dropping them.
    body.extend(l1)
    body.extend(l2)
    return prod1.copy_with(body=body)


@dataclass(frozen=True)
class ExtPatch:
    "Patch an existing grammar rule by adding Code"

    # (name, namespace, nt_def), see `apply_patch` for the meaning of each.
    prod: typing.Tuple[LenientNt, str, NtDef]

    def apply_patch(
        self,
        filename: os.PathLike,
        grammar: Grammar,
        nonterminals: typing.Dict[LenientNt, NtDef]
    ) -> None:
        """Merge the patch production into the matching grammar productions.

        Every production of `nonterminals[name]` whose shifted elements match
        the patch production is merged with it (see `merge_productions`), and
        `nonterminals[name]` is replaced by the updated definition.

        `filename` is only used in the error message.

        Raises KeyError if `name` is not in `nonterminals`, and ValueError if
        no production matches the patch.
        """
        # - name: non-terminal.
        # - namespace: ":" for syntactic or "::" for lexical. Always ":" as
        #   defined by rust_nt_def.
        # - nt_def: A single non-terminal definition with a single production.
        (name, namespace, nt_def) = self.prod
        gnt_def = nonterminals[name]
        # Find a matching production in the grammar.
        assert nt_def.params == gnt_def.params
        assert len(nt_def.rhs_list) == 1
        patch_prod = nt_def.rhs_list[0]
        new_rhs_list = []
        applied = False
        for grammar_prod in gnt_def.rhs_list:
            if eq_productions(grammar, grammar_prod, patch_prod):
                grammar_prod = merge_productions(grammar, grammar_prod, patch_prod)
                applied = True
            new_rhs_list.append(grammar_prod)
        if not applied:
            raise ValueError("{}: Unable to find a matching production for {} in the grammar:\n {}"
                             .format(filename, name, grammar.production_to_str(name, patch_prod)))
        result = gnt_def.with_rhs_list(new_rhs_list)
        nonterminals[name] = result


@dataclass
class GrammarExtension:
    """A collection of grammar extensions, with added code, added traits for the
    action functions.

    """

    # NOTE(review): annotated as `None`; it is not read anywhere in this file.
    target: None
    # Patches applied in order by `apply_patch`.
    grammar: typing.List[ExtPatch]
    # File the extension comes from, reported in error messages.
    filename: os.PathLike

    def apply_patch(
        self,
        grammar: Grammar,
        nonterminals: typing.Dict[LenientNt, NtDef]
    ) -> None:
        """Apply every patch of this extension to `nonterminals`, in order.

        Raises ValueError for any extension which is not an ExtPatch.
        """
        # A grammar extension is composed of multiple production patches.
        for ext in self.grammar:
            if isinstance(ext, ExtPatch):
                ext.apply_patch(self.filename, grammar, nonterminals)
            else:
                raise ValueError("Extension of type {} not yet supported.".format(ext.__class__))