1import ast
2import contextlib
3import re
4from abc import abstractmethod
5from typing import (
6    IO,
7    AbstractSet,
8    Any,
9    Dict,
10    Iterable,
11    Iterator,
12    List,
13    Optional,
14    Set,
15    Text,
16    Tuple,
17    Union,
18)
19
20from pegen import sccutils
21from pegen.grammar import (
22    Alt,
23    Cut,
24    Forced,
25    Gather,
26    Grammar,
27    GrammarError,
28    GrammarVisitor,
29    Group,
30    Lookahead,
31    NamedItem,
32    NameLeaf,
33    Opt,
34    Plain,
35    Repeat0,
36    Repeat1,
37    Rhs,
38    Rule,
39    StringLeaf,
40)
41
42
class RuleCollectorVisitor(GrammarVisitor):
    """Visitor that invokes a provided callmaker visitor with just the NamedItem nodes.

    Each rule is flattened before it is traversed, and every ``NamedItem``
    reached during the traversal is handed to the callmaker visitor.
    """

    def __init__(self, rules: Dict[str, Rule], callmakervisitor: GrammarVisitor) -> None:
        """Store the rule table and the visitor that receives each NamedItem.

        Args:
            rules: Mapping from rule name to ``Rule``.
            callmakervisitor: Visitor whose ``visit()`` is called for every
                ``NamedItem`` encountered.
        """
        self.rules = rules
        # ``rulses`` is the historical (misspelled) attribute name; it is kept
        # as an alias of ``rules`` so code reading the old name keeps working.
        self.rulses = rules
        self.callmaker = callmakervisitor

    def visit_Rule(self, rule: Rule) -> None:
        """Traverse the flattened form of *rule*."""
        self.visit(rule.flatten())

    def visit_NamedItem(self, item: NamedItem) -> None:
        """Delegate *item* to the callmaker visitor."""
        self.callmaker.visit(item)
55
56
class KeywordCollectorVisitor(GrammarVisitor):
    """Visitor that collects all the keywords and soft keywords in the Grammar.

    Identifier-like string literals written with single quotes are recorded in
    ``keywords`` (mapped to a type number obtained from the generator); all
    other identifier-like string literals are recorded in ``soft_keywords``.
    Both containers are mutated in place.
    """

    def __init__(self, gen: "ParserGenerator", keywords: Dict[str, int], soft_keywords: Set[str]):
        """Remember the generator and the two containers to fill.

        Args:
            gen: Generator whose ``keyword_type()`` supplies new type numbers.
            keywords: Mapping from keyword text (unquoted) to its type number.
            soft_keywords: Set of soft keyword texts.
        """
        self.generator = gen
        self.keywords = keywords
        self.soft_keywords = soft_keywords

    def visit_StringLeaf(self, node: StringLeaf) -> None:
        """Record *node* as a keyword or soft keyword if it looks like an identifier."""
        val = ast.literal_eval(node.value)
        if not re.match(r"[a-zA-Z_]\w*\Z", val):
            # Not identifier-like (e.g. an operator), so not a keyword.
            return
        if node.value.endswith("'"):
            # ``keywords`` is keyed by the unquoted text, so membership must be
            # tested with ``val``.  Testing the quoted ``node.value`` never
            # matched, which allocated a fresh type number on every occurrence
            # of the same keyword.
            if val not in self.keywords:
                self.keywords[val] = self.generator.keyword_type()
        else:
            self.soft_keywords.add(node.value.replace('"', ""))
72
73
class RuleCheckingVisitor(GrammarVisitor):
    """Visitor that rejects dangling name references and underscore-prefixed item names."""

    def __init__(self, rules: Dict[str, Rule], tokens: Set[str]):
        """Keep the rule table and the token names that references are checked against."""
        self.tokens = tokens
        self.rules = rules

    def visit_NameLeaf(self, node: NameLeaf) -> None:
        """Raise GrammarError if *node* names neither a known rule nor a known token."""
        referenced = node.value
        if referenced in self.rules or referenced in self.tokens:
            return
        raise GrammarError(f"Dangling reference to rule {node.value!r}")

    def visit_NamedItem(self, node: NamedItem) -> None:
        """Reject item names starting with an underscore, then check the wrapped item."""
        label = node.name
        if label and label.startswith("_"):
            raise GrammarError(f"Variable names cannot start with underscore: '{node.name}'")
        self.visit(node.item)
87
88
class ParserGenerator:
    """Base class holding the state and helpers shared by parser generators.

    It validates the grammar on construction, keeps the keyword tables and the
    table of rules (including artificial rules created during generation), and
    offers indentation-aware printing to the output file.  ``generate()`` must
    be supplied by a subclass, and ``callmakervisitor`` must be set before
    ``collect_rules()`` is called.
    """

    # Visitor handed to RuleCollectorVisitor in collect_rules(); only declared
    # here, never assigned by this class.
    callmakervisitor: GrammarVisitor

    def __init__(self, grammar: Grammar, tokens: Set[str], file: Optional[IO[Text]]):
        """Validate *grammar* and initialise the generator state.

        Args:
            grammar: The grammar to generate a parser for.
            tokens: Names accepted as token references in the grammar.
            file: Destination passed to ``print(..., file=...)`` by ``print()``.

        Raises:
            GrammarError: If a rule name starts with an underscore, if there is
                neither a ``trailer`` meta nor a ``start`` rule, or if the rule
                checker rejects a rule.
        """
        self.grammar = grammar
        self.tokens = tokens
        self.keywords: Dict[str, int] = {}
        self.soft_keywords: Set[str] = set()
        self.rules = grammar.rules
        self.validate_rule_names()
        if "trailer" not in grammar.metas and "start" not in self.rules:
            raise GrammarError("Grammar without a trailer must have a 'start' rule")
        checker = RuleCheckingVisitor(self.rules, self.tokens)
        for rule in self.rules.values():
            checker.visit(rule)
        self.file = file
        self.level = 0
        self.first_graph, self.first_sccs = compute_left_recursives(self.rules)
        self.counter = 0  # For name_rule()/name_loop()
        self.keyword_counter = 499  # For keyword_type()
        self.all_rules: Dict[str, Rule] = self.rules.copy()  # Rules + temporal rules
        self._local_variable_stack: List[List[str]] = []

    def validate_rule_names(self) -> None:
        """Raise GrammarError if any rule name starts with an underscore."""
        for rule in self.rules:
            if rule.startswith("_"):
                raise GrammarError(f"Rule names cannot start with underscore: '{rule}'")

    @contextlib.contextmanager
    def local_variable_context(self) -> Iterator[None]:
        """Push a fresh list of local variable names for the duration of the block."""
        self._local_variable_stack.append([])
        try:
            yield
        finally:
            # Pop even when the body raises so the stack cannot be left with a
            # stale frame (mirrors the try/finally used by indent()).
            self._local_variable_stack.pop()

    @property
    def local_variable_names(self) -> List[str]:
        """The innermost list of local variable names.

        Raises IndexError when accessed outside ``local_variable_context()``.
        """
        return self._local_variable_stack[-1]

    # NOTE: this class does not derive from abc.ABC, so the decorator alone
    # does not prevent instantiation; the NotImplementedError is what callers
    # actually hit.
    @abstractmethod
    def generate(self, filename: str) -> None:
        """Generate the parser; must be overridden by subclasses."""
        raise NotImplementedError

    @contextlib.contextmanager
    def indent(self) -> Iterator[None]:
        """Increase the indentation level used by ``print()`` inside the block."""
        self.level += 1
        try:
            yield
        finally:
            self.level -= 1

    def print(self, *args: object) -> None:
        """Print *args* to the output file at the current indentation level.

        With no arguments a bare newline is written, without indentation.
        """
        if not args:
            print(file=self.file)
        else:
            print("    " * self.level, end="", file=self.file)
            print(*args, file=self.file)

    def printblock(self, lines: str) -> None:
        """Print each line of *lines* through ``print()``."""
        for line in lines.splitlines():
            self.print(line)

    def collect_rules(self) -> None:
        """Collect keywords, then run the callmaker visitor over every rule.

        Visiting a rule may add new entries to ``all_rules``, so the second
        phase repeats until a pass finds no rule that has not been visited.
        """
        keyword_collector = KeywordCollectorVisitor(self, self.keywords, self.soft_keywords)
        for rule in self.all_rules.values():
            keyword_collector.visit(rule)

        rule_collector = RuleCollectorVisitor(self.rules, self.callmakervisitor)
        done: Set[str] = set()
        while True:
            # Snapshot the names first: all_rules can grow while visiting.
            computed_rules = list(self.all_rules)
            todo = [i for i in computed_rules if i not in done]
            if not todo:
                break
            done = set(self.all_rules)
            for rulename in todo:
                rule_collector.visit(self.all_rules[rulename])

    def keyword_type(self) -> int:
        """Return the next unused keyword type number."""
        self.keyword_counter += 1
        return self.keyword_counter

    def artifical_rule_from_rhs(self, rhs: Rhs) -> str:
        """Register a ``_tmp_<n>`` rule wrapping *rhs* and return its name."""
        self.counter += 1
        name = f"_tmp_{self.counter}"  # TODO: Pick a nicer name.
        self.all_rules[name] = Rule(name, None, rhs)
        return name

    def artificial_rule_from_repeat(self, node: Plain, is_repeat1: bool) -> str:
        """Register a ``_loop1_<n>`` or ``_loop0_<n>`` rule for *node* and return its name.

        The ``_loop1_`` prefix is used when *is_repeat1* is true, ``_loop0_``
        otherwise.
        """
        self.counter += 1
        if is_repeat1:
            prefix = "_loop1_"
        else:
            prefix = "_loop0_"
        name = f"{prefix}{self.counter}"
        self.all_rules[name] = Rule(name, None, Rhs([Alt([NamedItem(None, node)])]))
        return name

    def artifical_rule_from_gather(self, node: Gather) -> str:
        """Register the two rules implementing a gather and return the outer name.

        The helper rule ``_loop0_<m>`` has the single alternative
        ``separator elem`` with action ``elem``; the returned ``_gather_<n>``
        rule has the single alternative ``elem seq`` where ``seq`` refers to
        the helper rule.
        """
        self.counter += 1
        name = f"_gather_{self.counter}"
        self.counter += 1
        extra_function_name = f"_loop0_{self.counter}"
        extra_function_alt = Alt(
            [NamedItem(None, node.separator), NamedItem("elem", node.node)],
            action="elem",
        )
        self.all_rules[extra_function_name] = Rule(
            extra_function_name,
            None,
            Rhs([extra_function_alt]),
        )
        alt = Alt(
            [NamedItem("elem", node.node), NamedItem("seq", NameLeaf(extra_function_name))],
        )
        self.all_rules[name] = Rule(
            name,
            None,
            Rhs([alt]),
        )
        return name

    def dedupe(self, name: str) -> str:
        """Return *name*, suffixed with ``_<k>`` if needed to be unique in the current context.

        The returned name is appended to ``local_variable_names``.
        """
        origname = name
        counter = 0
        while name in self.local_variable_names:
            counter += 1
            name = f"{origname}_{counter}"
        self.local_variable_names.append(name)
        return name
219
220
class NullableVisitor(GrammarVisitor):
    """Visitor that determines which rules and named items may match nothing.

    Every ``visit_*`` method returns True when the visited node is considered
    nullable.  Rules and NamedItems found to be nullable are accumulated in
    ``self.nullables``.
    """

    def __init__(self, rules: Dict[str, Rule]) -> None:
        """Start with an empty result for the given rule table."""
        self.rules = rules
        self.visited: Set[Any] = set()
        self.nullables: Set[Union[Rule, NamedItem]] = set()

    def visit_Rule(self, rule: Rule) -> bool:
        """Return whether *rule* is nullable, analysing it at most once."""
        if rule in self.visited:
            # A rule that was fully analysed reports its recorded answer.  A
            # rule still being analysed (a recursive reference) is not in
            # ``nullables`` yet, so it reports False, which breaks the cycle.
            return rule in self.nullables
        self.visited.add(rule)
        if self.visit(rule.rhs):
            self.nullables.add(rule)
        return rule in self.nullables

    def visit_Rhs(self, rhs: Rhs) -> bool:
        """Nullable if any alternative is; stops at the first nullable one."""
        for alt in rhs.alts:
            if self.visit(alt):
                return True
        return False

    def visit_Alt(self, alt: Alt) -> bool:
        """Nullable only if every item is; stops at the first non-nullable item."""
        for item in alt.items:
            if not self.visit(item):
                return False
        return True

    def visit_Forced(self, force: Forced) -> bool:
        return True

    # NOTE(review): the class imported from pegen.grammar is spelled
    # ``Lookahead``.  If GrammarVisitor dispatches on the node's exact class
    # name, this method is never selected -- confirm against GrammarVisitor.
    def visit_LookAhead(self, lookahead: Lookahead) -> bool:
        return True

    def visit_Opt(self, opt: Opt) -> bool:
        return True

    def visit_Repeat0(self, repeat: Repeat0) -> bool:
        return True

    def visit_Repeat1(self, repeat: Repeat1) -> bool:
        return False

    def visit_Gather(self, gather: Gather) -> bool:
        return False

    def visit_Cut(self, cut: Cut) -> bool:
        return False

    def visit_Group(self, group: Group) -> bool:
        """A group is nullable exactly when its right-hand side is."""
        return self.visit(group.rhs)

    def visit_NamedItem(self, item: NamedItem) -> bool:
        """Record *item* as nullable when the item it wraps is."""
        if self.visit(item.item):
            self.nullables.add(item)
        return item in self.nullables

    def visit_NameLeaf(self, node: NameLeaf) -> bool:
        """A reference is nullable when it names a nullable rule."""
        if node.value in self.rules:
            return self.visit(self.rules[node.value])
        # Token or unknown; never empty.
        return False

    def visit_StringLeaf(self, node: StringLeaf) -> bool:
        """Return True when the string literal denotes the empty string."""
        # The string token '' is considered empty.
        if not node.value:
            return True
        # ``node.value`` is the literal including its quotes (this module
        # parses it with ast.literal_eval elsewhere), so the text itself is
        # never falsy for ''; evaluate it to look at the actual string.
        try:
            return not ast.literal_eval(node.value)
        except (ValueError, SyntaxError):
            # Not a parsable literal: treat as non-empty, as before.
            return False
285
286
def compute_nullables(rules: Dict[str, Rule]) -> Set[Any]:
    """Return the set of nullable nodes found in a grammar's rules.

    A single NullableVisitor is run over every rule in *rules* and the
    ``nullables`` set it accumulated is returned.

    Thanks to TatSu (tatsu/leftrec.py) for inspiration.
    """
    visitor = NullableVisitor(rules)
    for name in rules:
        visitor.visit(rules[name])
    return visitor.nullables
296
297
class InitialNamesVisitor(GrammarVisitor):
    """Visitor that gathers the names a node may reference at its initial position.

    Every ``visit_*`` method returns a set of names.  Nullability information
    is computed from the rule table when the visitor is constructed.
    """

    def __init__(self, rules: Dict[str, Rule]) -> None:
        """Keep the rule table and precompute the nullable nodes."""
        self.rules = rules
        self.nullables = compute_nullables(rules)

    def generic_visit(self, node: Iterable[Any], *args: Any, **kwargs: Any) -> Set[Any]:
        """Union the results of visiting every child of *node*.

        A child that is a list contributes each of its elements.
        """
        collected: Set[str] = set()
        for child in node:
            # Treat a lone child as a one-element group so both shapes share
            # the same inner loop.
            group = child if isinstance(child, list) else [child]
            for member in group:
                collected |= self.visit(member, *args, **kwargs)
        return collected

    def visit_Alt(self, alt: Alt) -> Set[Any]:
        """Collect names from the leading items, up to and including the first non-nullable one."""
        collected: Set[str] = set()
        for item in alt.items:
            collected |= self.visit(item)
            if item in self.nullables:
                # A nullable item may match nothing, so the next item can
                # also sit at the initial position.
                continue
            break
        return collected

    def visit_Forced(self, force: Forced) -> Set[Any]:
        return set()

    # NOTE(review): the class imported from pegen.grammar is spelled
    # ``Lookahead``.  If GrammarVisitor dispatches on the node's exact class
    # name, this method is never selected -- confirm against GrammarVisitor.
    def visit_LookAhead(self, lookahead: Lookahead) -> Set[Any]:
        return set()

    def visit_Cut(self, cut: Cut) -> Set[Any]:
        return set()

    def visit_NameLeaf(self, node: NameLeaf) -> Set[Any]:
        """A name reference contributes exactly its own name."""
        return {node.value}

    def visit_StringLeaf(self, node: StringLeaf) -> Set[Any]:
        return set()
335
336
def compute_left_recursives(
    rules: Dict[str, Rule]
) -> Tuple[Dict[str, AbstractSet[str]], List[AbstractSet[str]]]:
    """Flag left-recursive rules and choose a leader for each recursive group.

    The first graph of *rules* is split into strongly connected components.
    Every rule in a component with more than one member is flagged
    ``left_recursive`` and one member, contained in every cycle found in the
    component, is flagged ``leader``.  A single-member component is flagged
    (both attributes) only when the rule has an edge to itself.

    Returns:
        The first graph and the list of its strongly connected components.

    Raises:
        ValueError: If no member of a component lies on all of its cycles.
    """
    graph = make_first_graph(rules)
    sccs = list(sccutils.strongly_connected_components(graph.keys(), graph))
    for scc in sccs:
        if len(scc) <= 1:
            name = min(scc)  # The only element.
            if name in graph[name]:
                # Directly self-recursive rule: it is its own leader.
                rule = rules[name]
                rule.left_recursive = True
                rule.leader = True
            continue

        for name in scc:
            rules[name].left_recursive = True
        # Try to find a leader such that all cycles go through it.
        candidates = set(scc)
        for start in scc:
            for cycle in sccutils.find_cycles_in_scc(graph, scc, start):
                # Candidates always come from the SCC, so dropping the SCC
                # members absent from this cycle is the same as keeping only
                # the candidates that are on it.
                candidates &= set(cycle)
                if not candidates:
                    raise ValueError(
                        f"SCC {scc} has no leadership candidate (no element is included in all cycles)"
                    )
        # Pick an arbitrary leader from the candidates.
        rules[min(candidates)].leader = True
    return graph, sccs
365
366
def make_first_graph(rules: Dict[str, Rule]) -> Dict[str, AbstractSet[str]]:
    """Build the graph of left-invocations.

    The graph has an edge from A to B when A may invoke B at its initial
    position.  Every name that appears as an edge target also gets a vertex of
    its own, with an empty edge set if it is not a rule in *rules*.

    Nullability is computed by the InitialNamesVisitor created here.
    """
    first_names = InitialNamesVisitor(rules)
    graph: Dict[str, AbstractSet[str]] = {
        rulename: first_names.visit(rule) for rulename, rule in rules.items()
    }
    targets: Set[str] = set()
    for names in graph.values():
        targets |= names
    for target in targets:
        if target not in graph:
            graph[target] = set()
    return graph
384