1"""Data structure extracted from parsing the EDSL which are added within the
2Rust code."""
3
4from __future__ import annotations
5# mypy: disallow-untyped-defs, disallow-incomplete-defs, disallow-untyped-calls
6
7import typing
8import os
9
10from dataclasses import dataclass
11from .utils import keep_until
12from .grammar import Element, Grammar, LenientNt, NtDef, Production
13
14
@dataclass(frozen=True)
class ImplFor:
    """Immutable record holding the three textual parts of an `impl` header.

    All three fields are stored as plain strings; instances are frozen and
    slotted, so they carry no per-instance `__dict__`.
    """

    __slots__ = ('param', 'trait', 'for_type')

    # Text of the parameter part.
    param: str
    # Text naming the trait.
    trait: str
    # Text naming the type the trait is for.
    for_type: str
21
22
def eq_productions(grammar: Grammar, prod1: Production, prod2: Production) -> bool:
    """Tell whether two productions shift the same sequence of elements.

    Only the body elements accepted by `grammar.is_shifted_element` are
    compared, in order; every other element of either body is ignored.
    """
    shifted_first = tuple(filter(grammar.is_shifted_element, prod1.body))
    shifted_second = tuple(filter(grammar.is_shifted_element, prod2.body))
    return shifted_first == shifted_second
27
28
def merge_productions(grammar: Grammar, prod1: Production, prod2: Production) -> Production:
    """Merge the bodies of two productions which shift the same elements.

    Shifted elements (as reported by `grammar.is_shifted_element`) are
    treated as fixed anchors. The other elements found in front of each
    anchor are inserted around it. For a given anchor, only one of the two
    productions may contribute such extra elements, since no relative order
    between the extras of both productions is defined.

    Elements which trail the last anchor of one production are preserved when
    the other production ends at that anchor; they used to be silently
    dropped.

    Args:
        grammar: Grammar used to classify elements as shifted or not.
        prod1: Production used as the base of the returned copy.
        prod2: Production merged into `prod1`.

    Returns:
        The result of `prod1.copy_with(body=...)` with the merged body.

    Raises:
        AssertionError: If the productions do not shift the same elements.
        ValueError: If both productions have extra elements in front of the
            same anchor.
    """
    # Consider all shifted elements as non-moveable elements, and insert other
    # around these.
    assert eq_productions(grammar, prod1, prod2)
    rest1 = list(prod1.body)
    rest2 = list(prod2.body)
    body: typing.List[Element] = []
    while rest1 and rest2:
        # NOTE(review): keep_until is assumed to yield every element up to and
        # including the first one matching the predicate; the assertion and
        # the slicing below both rely on it -- confirm against .utils.
        front1 = list(keep_until(rest1, grammar.is_shifted_element))
        front2 = list(keep_until(rest2, grammar.is_shifted_element))
        assert front1[-1] == front2[-1]
        rest1 = rest1[len(front1):]
        rest2 = rest2[len(front2):]
        if len(front1) == 1:
            # prod1 only has the anchor: take whatever prod2 puts before it.
            body.extend(front2)
        elif len(front2) == 1:
            # prod2 only has the anchor: take whatever prod1 puts before it.
            body.extend(front1)
        else:
            raise ValueError("We do not know how to sort operations yet.")
    # The loop stops once one body is exhausted, so at most one of these is
    # non-empty. Given the precondition asserted above, the leftover should
    # only hold non-shifted elements following the last anchor; keep them
    # instead of losing them.
    body.extend(rest1)
    body.extend(rest2)
    return prod1.copy_with(body=body)
49
50
@dataclass(frozen=True)
class ExtPatch:
    """Patch an existing grammar rule by adding Code.

    `prod` is a `(name, namespace, nt_def)` triple describing the non-terminal
    to patch and the single production carrying the additions.
    """

    prod: typing.Tuple[LenientNt, str, NtDef]

    def apply_patch(
            self,
            filename: os.PathLike,
            grammar: Grammar,
            nonterminals: typing.Dict[LenientNt, NtDef]
    ) -> None:
        """Merge this patch into the matching production(s) of `nonterminals`.

        Every production of the targeted non-terminal which shifts the same
        elements as the patch production is replaced by the merge of both.
        The entry of `nonterminals` for that non-terminal is then replaced by
        a definition using the updated list of productions.

        Raises:
            ValueError: If no production of the targeted non-terminal matches
                the patch production; `filename` is only used in the message.
        """
        # The triple is made of:
        # - the non-terminal name,
        # - the namespace, ":" for syntactic or "::" for lexical (always ":"
        #   as defined by rust_nt_def); it is not used here,
        # - a non-terminal definition holding a single production.
        name, _namespace, patch_def = self.prod
        target_def = nonterminals[name]
        assert patch_def.params == target_def.params
        assert len(patch_def.rhs_list) == 1
        patch_prod = patch_def.rhs_list[0]

        # Rebuild the list of productions, swapping in merged ones wherever
        # the patch applies.
        matched = False
        patched_rhs_list = []
        for candidate in target_def.rhs_list:
            if eq_productions(grammar, candidate, patch_prod):
                matched = True
                patched_rhs_list.append(merge_productions(grammar, candidate, patch_prod))
            else:
                patched_rhs_list.append(candidate)

        if not matched:
            raise ValueError("{}: Unable to find a matching production for {} in the grammar:\n {}"
                             .format(filename, name, grammar.production_to_str(name, patch_prod)))
        nonterminals[name] = target_def.with_rhs_list(patched_rhs_list)
85
86
@dataclass
class GrammarExtension:
    """A collection of grammar extensions, with added code, added traits for the
    action functions.

    `grammar` holds the list of patches and `filename` is forwarded to each
    patch when it is applied.
    """

    target: None
    grammar: typing.List[ExtPatch]
    filename: os.PathLike

    def apply_patch(
            self,
            grammar: Grammar,
            nonterminals: typing.Dict[LenientNt, NtDef]
    ) -> None:
        """Apply every patch of this extension, in order, to `nonterminals`.

        Raises:
            ValueError: When an entry of `self.grammar` is not an `ExtPatch`.
                Entries listed before it have already been applied by then.
        """
        # A grammar extension is composed of multiple production patches.
        for patch in self.grammar:
            if not isinstance(patch, ExtPatch):
                raise ValueError("Extension of type {} not yet supported.".format(patch.__class__))
            patch.apply_patch(self.filename, grammar, nonterminals)
109