1# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
2# Use of this file is governed by the BSD 3-clause license that
3# can be found in the LICENSE.txt file in the project root.
4#/
5from uuid import UUID
6from io import StringIO
7from typing import Callable
8from antlr4.Token import Token
9from antlr4.atn.ATN import ATN
10from antlr4.atn.ATNType import ATNType
11from antlr4.atn.ATNState import *
12from antlr4.atn.Transition import *
13from antlr4.atn.LexerAction import *
14from antlr4.atn.ATNDeserializationOptions import ATNDeserializationOptions
15
# Oldest serialized-ATN UUID this runtime is able to read.
BASE_SERIALIZED_UUID = UUID("AADB8D7E-AEEF-4415-AD2B-8204D6CF042E")

# A serialized ATN tagged with this UUID carries two groups of IntervalSets;
# the bounds in the second group are 32-bit values, which lets sets cover the
# full Unicode range up to U+10FFFF (including the SMP).
ADDED_UNICODE_SMP = UUID("59627784-3BE5-417A-B9EB-8131A7286089")

# Every UUID the deserializer accepts, ordered oldest feature first. Feature
# checks compare positions in this list, so new UUIDs must be appended.
SUPPORTED_UUIDS = [
    BASE_SERIALIZED_UUID,
    ADDED_UNICODE_SMP,
]

# The UUID of the newest serialization format (the last supported entry).
SERIALIZED_UUID = SUPPORTED_UUIDS[-1]

# Version number expected as the first value of a serialized ATN.
SERIALIZED_VERSION = 3
32
class ATNDeserializer (object):
    """Rebuild an :class:`ATN` from its serialized string representation.

    :meth:`deserialize` first decodes the string into a list of integers
    (:meth:`reset`) and then reads, in this order: the version, the UUID, the
    grammar type and max token type, the states, rules, modes, interval sets,
    edges, decisions and, for lexer grammars, the lexer actions. The decoded
    data and read cursor are kept on the instance (``self.data``,
    ``self.pos``), as is the UUID that was read (``self.uuid``).
    """

    def __init__(self, options : ATNDeserializationOptions = None):
        """Create a deserializer.

        @param options Deserialization options; when ``None``,
        ``ATNDeserializationOptions.defaultOptions`` is used.
        """
        if options is None:
            options = ATNDeserializationOptions.defaultOptions
        self.deserializationOptions = options

    def isFeatureSupported(self, feature : UUID , actualUuid : UUID ):
        """Determine whether a serialized ATN supports a particular feature.

        A feature is identified by the UUID that was current when the feature
        first appeared; SUPPORTED_UUIDS lists these oldest first.

        @param feature The UUID marking the first time the feature was
        supported in the serialized ATN.
        @param actualUuid The UUID of the serialized ATN currently being
        deserialized.
        @return True if ``actualUuid`` is at or after ``feature`` in
        SUPPORTED_UUIDS; False otherwise, including when either UUID is not
        in SUPPORTED_UUIDS.
        """
        # list.index raises ValueError for a missing item (it never returns
        # -1), so test membership explicitly instead of crashing.
        if feature not in SUPPORTED_UUIDS or actualUuid not in SUPPORTED_UUIDS:
            return False
        return SUPPORTED_UUIDS.index(actualUuid) >= SUPPORTED_UUIDS.index(feature)

    def deserialize(self, data : str):
        """Decode ``data`` and return the reconstructed :class:`ATN`.

        Raises ``Exception`` for an unsupported version or UUID, for invalid
        state/transition/lexer-action types and, when the options enable
        ``verifyATN``, for an ATN that fails the structural checks.
        """
        self.reset(data)
        self.checkVersion()
        self.checkUUID()
        atn = self.readATN()
        self.readStates(atn)
        self.readRules(atn)
        self.readModes(atn)
        sets = []
        # First, read all sets with 16-bit Unicode code points <= U+FFFF.
        self.readSets(atn, sets, self.readInt)
        # Next, if the ATN was serialized with the Unicode SMP feature,
        # deserialize sets with 32-bit arguments <= U+10FFFF.
        if self.isFeatureSupported(ADDED_UNICODE_SMP, self.uuid):
            self.readSets(atn, sets, self.readInt32)
        self.readEdges(atn, sets)
        self.readDecisions(atn)
        self.readLexerActions(atn)
        self.markPrecedenceDecisions(atn)
        self.verifyATN(atn)
        if self.deserializationOptions.generateRuleBypassTransitions \
                and atn.grammarType == ATNType.PARSER:
            self.generateRuleBypassTransitions(atn)
            # re-verify after modification
            self.verifyATN(atn)
        return atn

    def reset(self, data:str):
        """Decode ``data`` into ``self.data`` and rewind the cursor to 0.

        Every value except the first (the version number) is shifted back by
        2 modulo 0x10000: 0 -> 0xFFFE, 1 -> 0xFFFF, n -> n - 2 otherwise.
        (Presumably the serializer stores each value as value + 2 wrapped to
        16 bits.)
        """
        def adjust(c):
            v = ord(c)
            # Wrap-around inverse of the +2 offset. Using 65533 here would map
            # 0 and 0xFFFF to the same value and make the 0xFFFF sentinels
            # tested by readStates/readRules/readLexerActions unreachable.
            return v-2 if v>1 else v + 65534
        temp = [ adjust(c) for c in data ]
        # don't adjust the first value since that's the version number
        temp[0] = ord(data[0])
        self.data = temp
        self.pos = 0

    def checkVersion(self):
        """Read the version number and raise unless it is SERIALIZED_VERSION."""
        version = self.readInt()
        if version != SERIALIZED_VERSION:
            raise Exception("Could not deserialize ATN with version " + str(version) + " (expected " + str(SERIALIZED_VERSION) + ").")

    def checkUUID(self):
        """Read the UUID, raise unless it is supported, and remember it."""
        uuid = self.readUUID()
        if not uuid in SUPPORTED_UUIDS:
            raise Exception("Could not deserialize ATN with UUID: " + str(uuid) + \
                            " (expected " + str(SERIALIZED_UUID) + " or a legacy UUID).", uuid, SERIALIZED_UUID)
        self.uuid = uuid

    def readATN(self):
        """Read the grammar type and max token type and create an empty ATN."""
        idx = self.readInt()
        grammarType = ATNType.fromOrdinal(idx)
        maxTokenType = self.readInt()
        return ATN(grammarType, maxTokenType)

    def readStates(self, atn:ATN):
        """Read all states plus the non-greedy and precedence-rule markers.

        States of ``ATNState.INVALID_TYPE`` are kept as ``None`` placeholders so
        state numbers stay aligned with their positions in ``atn.states``.
        """
        loopBackStateNumbers = []
        endStateNumbers = []
        nstates = self.readInt()
        for i in range(0, nstates):
            stype = self.readInt()
            # ignore bad type of states
            if stype==ATNState.INVALID_TYPE:
                atn.addState(None)
                continue
            ruleIndex = self.readInt()
            # 0xFFFF encodes "no rule"
            if ruleIndex == 0xFFFF:
                ruleIndex = -1

            s = self.stateFactory(stype, ruleIndex)
            if stype == ATNState.LOOP_END: # special case
                loopBackStateNumber = self.readInt()
                loopBackStateNumbers.append((s, loopBackStateNumber))
            elif isinstance(s, BlockStartState):
                endStateNumber = self.readInt()
                endStateNumbers.append((s, endStateNumber))

            atn.addState(s)

        # delay the assignment of loop back and end states until we know all the state instances have been initialized
        for pair in loopBackStateNumbers:
            pair[0].loopBackState = atn.states[pair[1]]

        for pair in endStateNumbers:
            pair[0].endState = atn.states[pair[1]]

        numNonGreedyStates = self.readInt()
        for i in range(0, numNonGreedyStates):
            stateNumber = self.readInt()
            atn.states[stateNumber].nonGreedy = True

        numPrecedenceStates = self.readInt()
        for i in range(0, numPrecedenceStates):
            stateNumber = self.readInt()
            atn.states[stateNumber].isPrecedenceRule = True

    def readRules(self, atn:ATN):
        """Read each rule's start state (and, for lexers, its token type).

        Also fills ``atn.ruleToStopState`` and links every rule start state to
        its stop state. A serialized token type of 0xFFFF maps to Token.EOF.
        """
        nrules = self.readInt()
        if atn.grammarType == ATNType.LEXER:
            atn.ruleToTokenType = [0] * nrules

        atn.ruleToStartState = [0] * nrules
        for i in range(0, nrules):
            s = self.readInt()
            startState = atn.states[s]
            atn.ruleToStartState[i] = startState
            if atn.grammarType == ATNType.LEXER:
                tokenType = self.readInt()
                if tokenType == 0xFFFF:
                    tokenType = Token.EOF

                atn.ruleToTokenType[i] = tokenType

        atn.ruleToStopState = [0] * nrules
        for state in atn.states:
            if not isinstance(state, RuleStopState):
                continue
            atn.ruleToStopState[state.ruleIndex] = state
            atn.ruleToStartState[state.ruleIndex].stopState = state

    def readModes(self, atn:ATN):
        """Read the start state of each lexer mode into ``atn.modeToStartState``."""
        nmodes = self.readInt()
        for i in range(0, nmodes):
            s = self.readInt()
            atn.modeToStartState.append(atn.states[s])

    def readSets(self, atn:ATN, sets:list, readUnicode:Callable[[], int]):
        """Read a group of IntervalSets and append them to ``sets``.

        Each set is stored as: interval count, an EOF flag (non-zero adds -1 to
        the set), then that many inclusive ``[lo, hi]`` pairs, each bound read
        with ``readUnicode`` (16-bit or 32-bit reader).
        """
        m = self.readInt()
        for i in range(0, m):
            iset = IntervalSet()
            sets.append(iset)
            n = self.readInt()
            containsEof = self.readInt()
            if containsEof!=0:
                iset.addOne(-1)
            for j in range(0, n):
                i1 = readUnicode()
                i2 = readUnicode()
                iset.addRange(range(i1, i2 + 1)) # range upper limit is exclusive

    def readEdges(self, atn:ATN, sets:list):
        """Read the serialized transitions, then derive the implicit links.

        Afterwards: adds the follow-state epsilon transitions leaving each rule
        stop state (these are not serialized), connects every block end state
        to its block start state, and points PlusBlockStartState /
        StarLoopEntryState targets at their loop-back states.
        """
        nedges = self.readInt()
        for i in range(0, nedges):
            src = self.readInt()
            trg = self.readInt()
            ttype = self.readInt()
            arg1 = self.readInt()
            arg2 = self.readInt()
            arg3 = self.readInt()
            trans = self.edgeFactory(atn, ttype, src, trg, arg1, arg2, arg3, sets)
            srcState = atn.states[src]
            srcState.addTransition(trans)

        # edges for rule stop states can be derived, so they aren't serialized
        for state in atn.states:
            if state is None:
                # placeholder for a state serialized as INVALID_TYPE
                continue
            # Iterate over a snapshot: addTransition below appends to rule
            # stop states' transition lists.
            for t in list(state.transitions):
                if not isinstance(t, RuleTransition):
                    continue
                outermostPrecedenceReturn = -1
                if atn.ruleToStartState[t.target.ruleIndex].isPrecedenceRule:
                    if t.precedence == 0:
                        outermostPrecedenceReturn = t.target.ruleIndex
                trans = EpsilonTransition(t.followState, outermostPrecedenceReturn)
                atn.ruleToStopState[t.target.ruleIndex].addTransition(trans)

        for state in atn.states:
            if isinstance(state, BlockStartState):
                # we need to know the end state to set its start state
                if state.endState is None:
                    raise Exception("IllegalState")
                # block end states can only be associated to a single block start state
                if state.endState.startState is not None:
                    raise Exception("IllegalState")
                state.endState.startState = state

            if isinstance(state, PlusLoopbackState):
                for transition in state.transitions:
                    target = transition.target
                    if isinstance(target, PlusBlockStartState):
                        target.loopBackState = state
            elif isinstance(state, StarLoopbackState):
                for transition in state.transitions:
                    target = transition.target
                    if isinstance(target, StarLoopEntryState):
                        target.loopBackState = state

    def readDecisions(self, atn:ATN):
        """Read the decision states, numbering them in the order they are read."""
        ndecisions = self.readInt()
        for i in range(0, ndecisions):
            s = self.readInt()
            decState = atn.states[s]
            atn.decisionToState.append(decState)
            decState.decision = i

    def readLexerActions(self, atn:ATN):
        """For lexer grammars, read ``atn.lexerActions``; a no-op otherwise.

        Each action is (type, data1, data2); a data value of 0xFFFF maps to -1.
        """
        if atn.grammarType == ATNType.LEXER:
            count = self.readInt()
            atn.lexerActions = [ None ] * count
            for i in range(0, count):
                actionType = self.readInt()
                data1 = self.readInt()
                if data1 == 0xFFFF:
                    data1 = -1
                data2 = self.readInt()
                if data2 == 0xFFFF:
                    data2 = -1
                lexerAction = self.lexerActionFactory(actionType, data1, data2)
                atn.lexerActions[i] = lexerAction

    def generateRuleBypassTransitions(self, atn:ATN):
        """Give each rule a synthetic token type and add its bypass path.

        Rule ``i`` is assigned token type ``atn.maxTokenType + i + 1``.
        """
        count = len(atn.ruleToStartState)
        atn.ruleToTokenType = [ 0 ] * count
        for i in range(0, count):
            atn.ruleToTokenType[i] = atn.maxTokenType + i + 1

        for i in range(0, count):
            self.generateRuleBypassTransition(atn, i)

    def generateRuleBypassTransition(self, atn:ATN, idx:int):
        """Insert a bypass block around the body of rule ``idx``.

        A new block start state takes over all transitions of the rule start
        state and gains one more alternative that matches the rule's
        synthetic token type. Transitions that used to reach the rule's end
        state are redirected to a new block end state, which then continues
        to that end state. For precedence rules the "end state" is the
        StarLoopEntryState of the precedence loop, and its loop-back
        transition is left untouched.
        """
        bypassStart = BasicBlockStartState()
        bypassStart.ruleIndex = idx
        atn.addState(bypassStart)

        bypassStop = BlockEndState()
        bypassStop.ruleIndex = idx
        atn.addState(bypassStop)

        bypassStart.endState = bypassStop
        atn.defineDecisionState(bypassStart)

        bypassStop.startState = bypassStart

        excludeTransition = None

        if atn.ruleToStartState[idx].isPrecedenceRule:
            # wrap from the beginning of the rule to the StarLoopEntryState
            endState = None
            for state in atn.states:
                if self.stateIsEndStateFor(state, idx) is not None:
                    endState = state
                    excludeTransition = state.loopBackState.transitions[0]
                    break

            if endState is None:
                raise Exception("Couldn't identify final state of the precedence rule prefix section.")

        else:

            endState = atn.ruleToStopState[idx]

        # all non-excluded transitions that currently target end state need to target blockEnd instead
        for state in atn.states:
            if state is None:
                continue
            for transition in state.transitions:
                if transition == excludeTransition:
                    continue
                if transition.target == endState:
                    transition.target = bypassStop

        # all transitions leaving the rule start state need to leave blockStart instead
        ruleStartState = atn.ruleToStartState[idx]
        # Move them last-to-first, popping each one off the rule start state.
        # (The previous loop never decremented its counter and indexed past
        # the shrinking list, raising IndexError.)
        while ruleStartState.transitions:
            bypassStart.addTransition(ruleStartState.transitions.pop())

        # link the new states
        ruleStartState.addTransition(EpsilonTransition(bypassStart))
        bypassStop.addTransition(EpsilonTransition(endState))

        matchState = BasicState()
        atn.addState(matchState)
        matchState.addTransition(AtomTransition(bypassStop, atn.ruleToTokenType[idx]))
        bypassStart.addTransition(EpsilonTransition(matchState))

    def stateIsEndStateFor(self, state:ATNState, idx:int):
        """Return ``state`` if it closes the precedence loop of rule ``idx``.

        That is: ``state`` is a StarLoopEntryState in rule ``idx`` whose last
        transition targets a LoopEndState that has only epsilon transitions
        and whose first transition targets a RuleStopState. Returns None
        otherwise (including for ``None`` placeholders in ``atn.states``).
        """
        if state is None or state.ruleIndex != idx:
            return None
        if not isinstance(state, StarLoopEntryState):
            return None

        maybeLoopEndState = state.transitions[len(state.transitions) - 1].target
        if not isinstance(maybeLoopEndState, LoopEndState):
            return None

        if maybeLoopEndState.epsilonOnlyTransitions and \
                isinstance(maybeLoopEndState.transitions[0].target, RuleStopState):
            return state
        else:
            return None

    def markPrecedenceDecisions(self, atn:ATN):
        """Set ``isPrecedenceDecision`` on the StarLoopEntryStates that decide
        whether a precedence rule continues or completes.

        @param atn The ATN.
        """
        for state in atn.states:
            if not isinstance(state, StarLoopEntryState):
                continue
            # Same shape test as stateIsEndStateFor, applied only to states of
            # precedence rules.
            if atn.ruleToStartState[state.ruleIndex].isPrecedenceRule \
                    and self.stateIsEndStateFor(state, state.ruleIndex) is not None:
                state.isPrecedenceDecision = True

    def verifyATN(self, atn:ATN):
        """Check structural invariants of ``atn`` when ``verifyATN`` is enabled.

        Raises ``Exception("IllegalState")`` on the first violated condition;
        ``None`` placeholders are skipped.
        """
        if not self.deserializationOptions.verifyATN:
            return
        # verify assumptions
        for state in atn.states:
            if state is None:
                continue

            self.checkCondition(state.epsilonOnlyTransitions or len(state.transitions) <= 1)

            if isinstance(state, PlusBlockStartState):
                self.checkCondition(state.loopBackState is not None)

            if isinstance(state, StarLoopEntryState):
                self.checkCondition(state.loopBackState is not None)
                self.checkCondition(len(state.transitions) == 2)

                if isinstance(state.transitions[0].target, StarBlockStartState):
                    self.checkCondition(isinstance(state.transitions[1].target, LoopEndState))
                    self.checkCondition(not state.nonGreedy)
                elif isinstance(state.transitions[0].target, LoopEndState):
                    self.checkCondition(isinstance(state.transitions[1].target, StarBlockStartState))
                    self.checkCondition(state.nonGreedy)
                else:
                    raise Exception("IllegalState")

            if isinstance(state, StarLoopbackState):
                self.checkCondition(len(state.transitions) == 1)
                self.checkCondition(isinstance(state.transitions[0].target, StarLoopEntryState))

            if isinstance(state, LoopEndState):
                self.checkCondition(state.loopBackState is not None)

            if isinstance(state, RuleStartState):
                self.checkCondition(state.stopState is not None)

            if isinstance(state, BlockStartState):
                self.checkCondition(state.endState is not None)

            if isinstance(state, BlockEndState):
                self.checkCondition(state.startState is not None)

            if isinstance(state, DecisionState):
                self.checkCondition(len(state.transitions) <= 1 or state.decision >= 0)
            else:
                self.checkCondition(len(state.transitions) <= 1 or isinstance(state, RuleStopState))

    def checkCondition(self, condition:bool, message=None):
        """Raise ``Exception(message)`` (default "IllegalState") if ``condition`` is false."""
        if not condition:
            if message is None:
                message = "IllegalState"
            raise Exception(message)

    def readInt(self):
        """Return the next decoded value and advance the cursor by one."""
        i = self.data[self.pos]
        self.pos += 1
        return i

    def readInt32(self):
        """Read a 32-bit value stored as two 16-bit values, low half first."""
        low = self.readInt()
        high = self.readInt()
        return low | (high << 16)

    def readLong(self):
        """Read a 64-bit value stored as two 32-bit values, low half first."""
        low = self.readInt32()
        high = self.readInt32()
        return (low & 0x00000000FFFFFFFF) | (high << 32)

    def readUUID(self):
        """Read a UUID stored as two 64-bit values, low half first."""
        low = self.readLong()
        high = self.readLong()
        allBits = (low & 0xFFFFFFFFFFFFFFFF) | (high << 64)
        return UUID(int=allBits)

    # Indexed by serialized transition type. Index 0 is not a valid type; it
    # is None so that edgeFactory reports it as invalid.
    edgeFactories = [ None,
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : EpsilonTransition(target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        RangeTransition(target, Token.EOF, arg2) if arg3 != 0 else RangeTransition(target, arg1, arg2),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        RuleTransition(atn.states[arg1], arg2, arg3, target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        PredicateTransition(target, arg1, arg2, arg3 != 0),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        AtomTransition(target, Token.EOF) if arg3 != 0 else AtomTransition(target, arg1),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        ActionTransition(target, arg1, arg2, arg3 != 0),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        SetTransition(target, sets[arg1]),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        NotSetTransition(target, sets[arg1]),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        WildcardTransition(target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        PrecedencePredicateTransition(target, arg1)
                      ]

    def edgeFactory(self, atn:ATN, type:int, src:int, trg:int, arg1:int, arg2:int, arg3:int, sets:list):
        """Build the transition of serialized ``type`` that targets state ``trg``.

        Raises ``Exception`` if ``type`` is out of range or not a valid type.
        """
        target = atn.states[trg]
        # `>=`, not `>`: type == len(edgeFactories) is also out of range.
        if not 0 <= type < len(self.edgeFactories) or self.edgeFactories[type] is None:
            raise Exception("The specified transition type: " + str(type) + " is not valid.")
        return self.edgeFactories[type](atn, src, trg, arg1, arg2, arg3, sets, target)

    # Indexed by serialized state type; index 0 produces no state.
    stateFactories = [  lambda : None,
                        lambda : BasicState(),
                        lambda : RuleStartState(),
                        lambda : BasicBlockStartState(),
                        lambda : PlusBlockStartState(),
                        lambda : StarBlockStartState(),
                        lambda : TokensStartState(),
                        lambda : RuleStopState(),
                        lambda : BlockEndState(),
                        lambda : StarLoopbackState(),
                        lambda : StarLoopEntryState(),
                        lambda : PlusLoopbackState(),
                        lambda : LoopEndState()
                    ]

    def stateFactory(self, type:int, ruleIndex:int):
        """Create a state of serialized ``type`` belonging to rule ``ruleIndex``.

        Returns None for type 0. Raises ``Exception`` if ``type`` is out of range.
        """
        if not 0 <= type < len(self.stateFactories) or self.stateFactories[type] is None:
            raise Exception("The specified state type " + str(type) + " is not valid.")
        s = self.stateFactories[type]()
        if s is not None:
            s.ruleIndex = ruleIndex
        return s

    CHANNEL = 0     #The type of a {@link LexerChannelAction} action.
    CUSTOM = 1      #The type of a {@link LexerCustomAction} action.
    MODE = 2        #The type of a {@link LexerModeAction} action.
    MORE = 3        #The type of a {@link LexerMoreAction} action.
    POP_MODE = 4    #The type of a {@link LexerPopModeAction} action.
    PUSH_MODE = 5   #The type of a {@link LexerPushModeAction} action.
    SKIP = 6        #The type of a {@link LexerSkipAction} action.
    TYPE = 7        #The type of a {@link LexerTypeAction} action.

    # Indexed by the lexer action type constants above.
    actionFactories = [ lambda data1, data2: LexerChannelAction(data1),
                        lambda data1, data2: LexerCustomAction(data1, data2),
                        lambda data1, data2: LexerModeAction(data1),
                        lambda data1, data2: LexerMoreAction.INSTANCE,
                        lambda data1, data2: LexerPopModeAction.INSTANCE,
                        lambda data1, data2: LexerPushModeAction(data1),
                        lambda data1, data2: LexerSkipAction.INSTANCE,
                        lambda data1, data2: LexerTypeAction(data1)
                      ]

    def lexerActionFactory(self, type:int, data1:int, data2:int):
        """Build the lexer action of ``type`` from its two data values.

        Raises ``Exception`` if ``type`` is out of range.
        """
        if not 0 <= type < len(self.actionFactories) or self.actionFactories[type] is None:
            raise Exception("The specified lexer action type " + str(type) + " is not valid.")
        return self.actionFactories[type](data1, data2)
529