1from __future__ import annotations
2# mypy: disallow-untyped-defs, disallow-incomplete-defs, disallow-untyped-calls
3
4import typing
5import dataclasses
6
7from .ordered import OrderedFrozenSet
8from .grammar import Element, ErrorSymbol, InitNt, Nt
9from . import types, grammar
10
11# Avoid circular reference between this module and parse_table.py
12if typing.TYPE_CHECKING:
13    from .parse_table import StateId
14
15
@dataclasses.dataclass(frozen=True)
class StackDiff:
    """StackDiff represents the stack mutations which have to be performed
    when executing an action.
    """
    __slots__ = ('pop', 'nt', 'replay')

    # Worked example: after shifting `b * c X Y`, we want to reduce `b * c`
    # into a `Mul` nonterminal.
    #
    # The initial LR0 pass over the grammar produces a single Reduce edge:
    #
    # action         pop          replay
    # -------------- ------       ---------
    # Reduce         3 (`b * c`)  2 (`X Y`)
    #   The parser moves `X Y` to the replay queue, pops `b * c`, builds the
    #   new `Mul` nonterminal, looks up the new state id using the stack and
    #   the parse table, and then replays the new nonterminal. `X Y` are left
    #   on the runtime replay queue; the runtime is responsible for noticing
    #   them and replaying them.
    #
    # Later on, the Reduce edge may be lowered into the sequence
    # [Unwind; FilterState; Replay], which encodes both the Reduce action and
    # the behavior expected from the runtime:
    #
    # action         pop          replay
    # -------------- ------       ---------
    # Unwind         3            2
    #   The parser moves `X Y` to the replay queue, pops `b * c`, builds the
    #   new `Mul` nonterminal, and inserts it at the front of the replay queue.
    #
    # FilterState    ---          ---
    #   Determines the new state id when it depends on the context.
    #   The stack is left untouched, hence there is no StackDiff.
    #
    # Replay         0            -3
    #   Shifts the three elements left on the replay queue back onto the
    #   stack: `(b*c) X Y`.

    # Number of elements popped from the stack; used when the stack is
    # reduced into a non-terminal. Always zero or positive.
    pop: int

    # Non-terminal pushed after all the replayed and popped elements have
    # been removed, when this action reduces the stack; None otherwise.
    nt: typing.Union[Nt, ErrorSymbol, None]

    # Number of terms moved from the stack to the runtime replay queue (not
    # counting `self.nt`). A negative value moves terms from the replay queue
    # back onto the stack instead.
    #
    # Some lookahead might have been consumed to make the parse table
    # consistent. Replayed terms are removed before any element is popped
    # from the stack, and they are added to the replay list in reversed order
    # so that they get shifted after the reduced non-terminal.
    #
    # When negative, some lookahead terms are expected to already be in the
    # replay list, and they are shifted back. This must only happen when
    # follow_edge is True.
    replay: int

    def reduce_stack(self) -> bool:
        """Return whether this diff reduces the stack, i.e. replaces the
        popped elements by a non-terminal.

        Checking `nt` is simpler than checking the type of the action, since
        a Reduce / Unwind may appear either on its own or as the last element
        of a Seq action."""
        if self.nt is None:
            return False
        return True
85
86
class Action:
    """Base class of parser actions.

    Subclasses declare their fields in `__slots__`; equality, hashing and the
    fallback string rendering are derived from the values of those fields.
    The default method implementations describe an action which follows the
    epsilon transition and leaves the parser stack untouched; subclasses
    override them as needed.
    """
    __slots__ = ["_hash"]

    # Cached hash, computed lazily by __hash__.
    _hash: typing.Optional[int]

    def __init__(self) -> None:
        self._hash = None

    def is_inconsistent(self) -> bool:
        """Returns True if this action is inconsistent. An action can be
        inconsistent if the parameters it is given cannot be evaluated given
        its current location in the parse table, such as CheckNotOnNewLine.
        """
        return False

    def is_condition(self) -> bool:
        """Returns whether this action is an unordered condition, which
        accepts or rejects reaching the next state."""
        return False

    def condition(self) -> Action:
        "Return the conditional action."
        raise TypeError("Action.condition not implemented")

    def check_same_variable(self, other: Action) -> bool:
        "Return whether both conditionals are checking the same variable."
        assert self.is_condition()
        raise TypeError("Action.check_same_variable not implemented")

    def check_different_values(self, other: Action) -> bool:
        "Return whether these 2 conditions are mutually exclusive."
        assert self.is_condition()
        raise TypeError("Action.check_different_values not implemented")

    def follow_edge(self) -> bool:
        """Whether the execution of this action resumes following the epsilon
        transition (True), or breaks the graph epsilon transition (False) and
        returns at a different location, defined by the top of the stack."""
        return True

    def update_stack(self) -> bool:
        """Whether the execution of this action changes the parser stack."""
        return False

    def update_stack_with(self) -> StackDiff:
        """Returns a StackDiff which represents the mutation to be applied to
        the parser stack. Only valid when update_stack() is True."""
        assert self.update_stack()
        raise TypeError("Action.update_stack_with not implemented")

    def unshift_action(self, num: int) -> Action:
        """When manipulating stack operations, we have the option to unshift
        some replayed tokens which were shifted to disambiguate the grammar,
        as they might no longer be needed in some cases.

        Raises TypeError unless a subclass supports it."""
        raise TypeError("{} cannot be unshifted".format(self.__class__.__name__))

    def shifted_action(self, shifted_term: Element) -> ShiftedAction:
        """Transpose this action with shifting the given terminal or Nt.

        That is, the sequence of:
        - performing the action `self`, then
        - shifting `shifted_term`
        has the same effect as:
        - shifting `shifted_term`, then
        - performing the action `self.shifted_action(shifted_term)`.

        If the resulting shifted action would be a no-op, instead return True.

        If this is a conditional action and `shifted_term` indicates that the
        condition wasn't met, return False.
        """
        return self

    def contains_accept(self) -> bool:
        "Returns whether the current action stops the parser."
        return False

    def rewrite_state_indexes(self, state_map: typing.Dict[StateId, StateId]) -> Action:
        """If the action contains any state index, use the map to map the old
        indexes to the new indexes."""
        return self

    def fold_by_destination(self, actions: typing.List[Action]) -> typing.List[Action]:
        """If after rewriting state indexes, multiple conditions are reaching
        the same destination state, we attempt to fold them by destination.
        Not implementing this function can lead to the introduction of
        inconsistent states, as the conditions might be redundant."""

        # By default do nothing.
        return actions

    def state_refs(self) -> typing.List[StateId]:
        """List of states which are referenced by this action."""
        # By default, no state is referenced.
        return []

    def __eq__(self, other: object) -> bool:
        if self.__class__ != other.__class__:
            return False
        assert isinstance(other, Action)
        # `self.__slots__` resolves to the most-derived definition, so for
        # subclasses the cached `_hash` declared on Action is not compared.
        for s in self.__slots__:
            if getattr(self, s) != getattr(other, s):
                return False
        return True

    def __hash__(self) -> int:
        if self._hash is not None:
            return self._hash

        def hashed_content() -> typing.Iterator[object]:
            yield self.__class__
            for s in self.__slots__:
                yield repr(getattr(self, s))

        self._hash = hash(tuple(hashed_content()))
        return self._hash

    def __lt__(self, other: Action) -> bool:
        # Orders actions by hash. The hash mixes the class object with repr
        # strings, so this order is consistent within a process but may
        # differ between interpreter runs.
        return hash(self) < hash(other)

    def __str__(self) -> str:
        # Fallback for subclasses which do not define their own __str__.
        # Without it, __repr__ (which calls str(self)) and object.__str__
        # (which calls repr(self)) recurse into each other forever.
        fields = ", ".join(
            repr(getattr(self, s)) for s in self.__slots__ if s != "_hash")
        return "{}({})".format(self.__class__.__name__, fields)

    def __repr__(self) -> str:
        return str(self)

    def stable_str(self, states: typing.Any) -> str:
        return str(self)


# Result of Action.shifted_action: either the transposed action, or a bool
# (True: the shifted action is a no-op; False: the condition was not met).
ShiftedAction = typing.Union[Action, bool]
215
216
class Replay(Action):
    """Replay terms which were previously saved by the Unwind function.

    This does not shift a term given as argument: the replay action should
    always be guaranteed, and we want to maximize the sharing of code when
    possible."""
    __slots__ = ['replay_steps']

    # State ids, one per term moved back from the replay queue to the stack.
    replay_steps: typing.Tuple[StateId, ...]

    def __init__(self, replay_steps: typing.Iterable[StateId]):
        super().__init__()
        steps = tuple(replay_steps)
        self.replay_steps = steps

    def update_stack(self) -> bool:
        return True

    def update_stack_with(self) -> StackDiff:
        # Nothing is popped and no non-terminal is produced; every saved term
        # goes back from the replay queue onto the stack (negative replay).
        count = len(self.replay_steps)
        return StackDiff(pop=0, nt=None, replay=-count)

    def rewrite_state_indexes(self, state_map: typing.Dict[StateId, StateId]) -> Replay:
        return Replay(state_map[step] for step in self.replay_steps)

    def state_refs(self) -> typing.List[StateId]:
        return [*self.replay_steps]

    def __str__(self) -> str:
        return "Replay(" + str(self.replay_steps) + ")"
244
245
class Unwind(Action):
    """Define an unwind operation which pops N elements of the stack and
    pushes one non-terminal. The replay argument of an unwind action
    corresponds to the number of stack elements which would have to be
    popped and pushed again using the parser table after executing this
    operation."""
    __slots__ = ['nt', 'replay', 'pop']

    nt: Nt
    pop: int
    replay: int

    def __init__(self, nt: Nt, pop: int, replay: int = 0) -> None:
        super().__init__()
        self.nt = nt          # Non-terminal which is reduced.
        self.pop = pop        # Number of stack elements to pop.
        self.replay = replay  # Number of terms moved to the replay queue.

    def __str__(self) -> str:
        return "Unwind({}, {}, {})".format(self.nt, self.pop, self.replay)

    def update_stack(self) -> bool:
        return True

    def update_stack_with(self) -> StackDiff:
        return StackDiff(self.pop, self.nt, self.replay)

    def unshift_action(self, num: int) -> Unwind:
        # Un-shifting `num` terms means `num` fewer terms have to be replayed.
        return Unwind(self.nt, self.pop, replay=self.replay - num)

    def shifted_action(self, shifted_term: Element) -> Unwind:
        # One more term is on the stack, so one more term has to be replayed;
        # which term was shifted does not matter.
        return Unwind(self.nt, self.pop, replay=self.replay + 1)
277
278
class Reduce(Action):
    """Prevent the fall-through to the epsilon transition and return to the
    shift table execution to resume shifting or replaying terms.

    The stack mutation itself is described by the wrapped Unwind action."""
    __slots__ = ['unwind']

    unwind: Unwind

    def __init__(self, unwind: Unwind) -> None:
        super().__init__()
        self.unwind = unwind

    def __str__(self) -> str:
        return "Reduce({})".format(str(self.unwind))

    def follow_edge(self) -> bool:
        # Unlike Unwind, a Reduce breaks the epsilon transition.
        return False

    def update_stack(self) -> bool:
        return self.unwind.update_stack()

    def update_stack_with(self) -> StackDiff:
        return self.unwind.update_stack_with()

    def unshift_action(self, num: int) -> Reduce:
        unwind = self.unwind.unshift_action(num)
        return Reduce(unwind)

    def shifted_action(self, shifted_term: Element) -> Reduce:
        unwind = self.unwind.shifted_action(shifted_term)
        return Reduce(unwind)
314
315
class Accept(Action):
    """Terminate the parser, accepting all the content consumed so far."""
    __slots__: typing.List[str] = []

    # No fields of its own: the inherited Action.__init__ is sufficient.

    def __str__(self) -> str:
        return "Accept()"

    def contains_accept(self) -> bool:
        "Returns whether the current action stops the parser."
        return True

    def shifted_action(self, shifted_term: Element) -> Accept:
        # Accepting does not depend on the shifted term.
        return Accept()
333
334
class Lookahead(Action):
    """Define a Lookahead assertion on the next shifted terminal.

    With `accept=True` the terminal must be one of `terms`; with
    `accept=False` it must not be. Shifting a non-terminal leaves the
    assertion as a no-op."""
    __slots__ = ['terms', 'accept']

    terms: typing.FrozenSet[str]
    accept: bool

    def __init__(self, terms: typing.FrozenSet[str], accept: bool):
        super().__init__()
        self.terms = terms
        self.accept = accept

    def is_inconsistent(self) -> bool:
        # A lookahead restriction cannot be encoded in code, it has to be
        # solved using fix_with_lookahead, which encodes the lookahead
        # resolution in the generated parse table.
        return True

    def is_condition(self) -> bool:
        return True

    def condition(self) -> Lookahead:
        return self

    def check_same_variable(self, other: Action) -> bool:
        raise TypeError("Lookahead.check_same_variable: Lookahead are always inconsistent")

    def check_different_values(self, other: Action) -> bool:
        raise TypeError("Lookahead.check_different_values: Lookahead are always inconsistent")

    def __str__(self) -> str:
        return "Lookahead({}, {})".format(self.terms, self.accept)

    def shifted_action(self, shifted_term: Element) -> ShiftedAction:
        # Always resolves to a bool: True when the assertion holds (or is a
        # no-op for a non-terminal), False when it is violated.
        if isinstance(shifted_term, Nt):
            return True
        if shifted_term in self.terms:
            return self.accept
        return not self.accept
375
376
class CheckNotOnNewLine(Action):
    """Condition checking whether the terminal at the given stack offset is on
    a new line or not. If the check fails this produces an Error, otherwise
    this rule is shifted."""
    __slots__ = ['offset']

    offset: int

    def __init__(self, offset: int = 0) -> None:
        # NOTE: offsets smaller than -1 may not be supported on all backends;
        # an assertion enforcing this is currently disabled.
        super().__init__()
        self.offset = offset

    def is_inconsistent(self) -> bool:
        # Only terminals already on the stack can be inspected. An offset of 0
        # designates the next terminal, which is not shifted yet, so the
        # action stays inconsistent as long as the terminal is not on the
        # stack.
        return not self.offset < 0

    def is_condition(self) -> bool:
        return True

    def condition(self) -> CheckNotOnNewLine:
        return self

    def check_same_variable(self, other: Action) -> bool:
        if not isinstance(other, CheckNotOnNewLine):
            return False
        return self.offset == other.offset

    def check_different_values(self, other: Action) -> bool:
        # Two of these checks are never considered mutually exclusive.
        return False

    def shifted_action(self, shifted_term: Element) -> ShiftedAction:
        if not isinstance(shifted_term, Nt):
            # Shifting a terminal moves the checked position back by one.
            return CheckNotOnNewLine(self.offset - 1)
        # Shifting a non-terminal turns this check into a no-op.
        return True

    def __str__(self) -> str:
        return "CheckNotOnNewLine({})".format(self.offset)
416
417
class FilterStates(Action):
    """Condition checking whether the state on the stack at a given depth is
    one of `states`; if so, transition to the destination, otherwise check
    other states."""
    __slots__ = ['states']

    # Set of states which can follow this transition, kept in sorted order.
    states: OrderedFrozenSet[StateId]

    def __init__(self, states: typing.Iterable[StateId]):
        super().__init__()
        ordered = sorted(states)
        self.states = OrderedFrozenSet(ordered)

    def is_condition(self) -> bool:
        return True

    def condition(self) -> FilterStates:
        return self

    def check_same_variable(self, other: Action) -> bool:
        # Any two FilterStates conditions are considered to test the same
        # variable.
        return isinstance(other, FilterStates)

    def check_different_values(self, other: Action) -> bool:
        assert isinstance(other, FilterStates)
        return self.states.is_disjoint(other.states)

    def rewrite_state_indexes(self, state_map: typing.Dict[StateId, StateId]) -> FilterStates:
        return FilterStates([state_map[s] for s in self.states])

    def fold_by_destination(self, actions: typing.List[Action]) -> typing.List[Action]:
        filters = [a for a in actions if isinstance(a, FilterStates)]
        if len(filters) != len(actions):
            # Some action is not a FilterStates (the state is inconsistent):
            # leave the list untouched.
            return actions
        return [FilterStates(s for f in filters for s in f.states)]

    def state_refs(self) -> typing.List[StateId]:
        return [*self.states]

    def __str__(self) -> str:
        return "FilterStates({})".format(self.states)
461
462
class FilterFlag(Action):
    """Condition which continues to the next state only when the top of the
    stack of `flag` holds the expected `value`."""
    __slots__ = ['flag', 'value']

    flag: str
    value: object

    def __init__(self, flag: str, value: object) -> None:
        super().__init__()
        self.flag, self.value = flag, value

    def is_condition(self) -> bool:
        return True

    def condition(self) -> FilterFlag:
        return self

    def check_same_variable(self, other: Action) -> bool:
        # Two filters test the same variable when they inspect the same flag.
        if not isinstance(other, FilterFlag):
            return False
        return self.flag == other.flag

    def check_different_values(self, other: Action) -> bool:
        assert isinstance(other, FilterFlag)
        return self.value != other.value

    def __str__(self) -> str:
        return f"FilterFlag({self.flag}, {self.value})"
491
492
class PushFlag(Action):
    """Push `value` on the stack dedicated to `flag`.

    This other stack corresponds to another parse stack which lives next to
    the default state machine, and is popped by PopFlag as if this was another
    reduce action. This is particularly useful to raise the parse table from
    LR(0) to LR(k) without needing as many state duplications."""
    __slots__ = ['flag', 'value']

    flag: str
    value: object

    def __init__(self, flag: str, value: object) -> None:
        super().__init__()
        self.flag, self.value = flag, value

    def __str__(self) -> str:
        return f"PushFlag({self.flag}, {self.value})"
511
512
class PopFlag(Action):
    """Pop the top value from the stack dedicated to `flag`."""
    __slots__ = ['flag']

    flag: str  # Name of the flag whose stack is popped.

    def __init__(self, flag: str) -> None:
        super().__init__()
        self.flag = flag

    def __str__(self) -> str:
        return f"PopFlag({self.flag})"
525
526
# OutputExpr: An expression mini-language that compiles very directly to code
# in the output language (Rust or Python). An OutputExpr is one of:
#
# str - an identifier in the generated code
# int - an index into the runtime stack
# None or grammar.Some(...) - an optional value
#
OutputExpr = typing.Union[str, int, None, grammar.Some]


class FunCall(Action):
    """Call `method` of `trait` with `args`, and set the temporary variable
    named `set_to` with the result.

    Arguments are OutputExpr values; integer arguments index the runtime
    stack and are adjusted by `offset`. `fallible` marks calls which can
    fail."""
    __slots__ = ['trait', 'method', 'offset', 'args', 'fallible', 'set_to']

    trait: types.Type
    method: str
    offset: int
    args: typing.Tuple[OutputExpr, ...]
    fallible: bool
    set_to: str

    def __init__(
            self,
            method: str,
            args: typing.Tuple[OutputExpr, ...],
            trait: types.Type = types.Type("AstBuilder"),
            fallible: bool = False,
            set_to: str = "val",
            offset: int = 0,
    ) -> None:
        super().__init__()
        self.trait = trait        # Trait on which this method is implemented.
        self.method = method      # Name of the method to call.
        self.fallible = fallible  # Whether the function call can fail.
        self.offset = offset      # Offset to add to each argument offset.
        self.args = args          # Tuple of arguments (OutputExpr values).
        self.set_to = set_to      # Temporary variable name to set with the result.

    def __str__(self) -> str:
        return "{} = {}::{}({}){} [off: {}]".format(
            self.set_to, self.trait, self.method,
            ", ".join(map(str, self.args)),
            '?' if self.fallible else '',
            self.offset)

    def __repr__(self) -> str:
        # The field order does not follow the constructor signature. Action
        # hashing uses repr() of field values, so this output feeds the hash
        # of actions containing a FunCall (e.g. Seq); keep it stable.
        return "FunCall({})".format(', '.join(map(repr, [
            self.trait, self.method, self.fallible,
            self.args, self.set_to, self.offset
        ])))

    def _with_offset(self, offset: int) -> FunCall:
        """Return a copy of this call with `offset` as its argument offset."""
        return FunCall(self.method, self.args,
                       trait=self.trait,
                       fallible=self.fallible,
                       set_to=self.set_to,
                       offset=offset)

    def unshift_action(self, num: int) -> FunCall:
        # Un-shifting `num` terms decreases the argument offset by `num`.
        return self._with_offset(self.offset - num)

    def shifted_action(self, shifted_term: Element) -> FunCall:
        # Shifting one term increases the argument offset by one; which term
        # was shifted does not matter.
        return self._with_offset(self.offset + 1)
595
596
class Seq(Action):
    """Aggregate multiple actions in one statement. Note that the aggregated
    actions should not contain any condition, nor any nested Seq. Only the
    last aggregated action can break the epsilon transition or update the
    parser stack."""
    __slots__ = ['actions']

    actions: typing.Tuple[Action, ...]

    def __init__(self, actions: typing.Iterable[Action]) -> None:
        super().__init__()
        # Ordered list of actions to execute. The invariants below are checked
        # on the stored tuple, so any iterable (even a one-shot one) works.
        self.actions = tuple(actions)
        leading = self.actions[:-1]
        assert all(not a.is_condition() for a in self.actions)
        assert all(not isinstance(a, Seq) for a in self.actions)
        assert all(a.follow_edge() for a in leading)
        assert all(not a.update_stack() for a in leading)

    def __str__(self) -> str:
        return "{{ {} }}".format("; ".join(map(str, self.actions)))

    def __repr__(self) -> str:
        return "Seq({})".format(repr(self.actions))

    def follow_edge(self) -> bool:
        return self.actions[-1].follow_edge()

    def update_stack(self) -> bool:
        return self.actions[-1].update_stack()

    def update_stack_with(self) -> StackDiff:
        return self.actions[-1].update_stack_with()

    def unshift_action(self, num: int) -> Seq:
        return Seq([a.unshift_action(num) for a in self.actions])

    def shifted_action(self, shift: Element) -> ShiftedAction:
        # Transpose every aggregated action. A False result makes the whole
        # sequence fail; True results are no-ops and are dropped.
        actions: typing.List[Action] = []
        for a in self.actions:
            b = a.shifted_action(shift)
            if b is False:
                return False
            if not isinstance(b, bool):
                actions.append(b)
        return Seq(actions)

    def contains_accept(self) -> bool:
        return any(a.contains_accept() for a in self.actions)

    def rewrite_state_indexes(self, state_map: typing.Dict[StateId, StateId]) -> Seq:
        return Seq([a.rewrite_state_indexes(state_map) for a in self.actions])

    def state_refs(self) -> typing.List[StateId]:
        return [s for a in self.actions for s in a.state_refs()]
652