# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
# Use of this file is governed by the BSD 3-clause license that
# can be found in the LICENSE.txt file in the project root.
#/
from uuid import UUID
from io import StringIO
from typing import Callable
from antlr4.Token import Token
from antlr4.atn.ATN import ATN
from antlr4.atn.ATNType import ATNType
from antlr4.atn.ATNState import *
from antlr4.atn.Transition import *
from antlr4.atn.LexerAction import *
from antlr4.atn.ATNDeserializationOptions import ATNDeserializationOptions

# This is the earliest supported serialized UUID.
BASE_SERIALIZED_UUID = UUID("AADB8D7E-AEEF-4415-AD2B-8204D6CF042E")

# This UUID indicates the serialized ATN contains two sets of
# IntervalSets, where the second set's values are encoded as
# 32-bit integers to support the full Unicode SMP range up to U+10FFFF.
ADDED_UNICODE_SMP = UUID("59627784-3BE5-417A-B9EB-8131A7286089")

# This list contains all of the currently supported UUIDs, ordered by when
# the feature first appeared in this branch.
SUPPORTED_UUIDS = [ BASE_SERIALIZED_UUID, ADDED_UNICODE_SMP ]

SERIALIZED_VERSION = 3

# This is the current serialized UUID.
SERIALIZED_UUID = ADDED_UNICODE_SMP

# Rebuilds an ATN from its serialized string form. The entry point is
# deserialize(); the read* methods each consume one section of the
# serialized data, in the order deserialize() calls them.
class ATNDeserializer (object):

    # @param options The deserialization options to use; when None, the
    # shared ATNDeserializationOptions.defaultOptions instance is used.
    def __init__(self, options : ATNDeserializationOptions = None):
        if options is None:
            options = ATNDeserializationOptions.defaultOptions
        self.deserializationOptions = options

    # Determines if a particular serialized representation of an ATN supports
    # a particular feature, identified by the {@link UUID} used for serializing
    # the ATN at the time the feature was first introduced.
    #
    # @param feature The {@link UUID} marking the first time the feature was
    # supported in the serialized ATN.
    # @param actualUuid The {@link UUID} of the actual serialized ATN which is
    # currently being deserialized.
    # @return {@code true} if the {@code actualUuid} value represents a
    # serialized ATN at or after the feature identified by {@code feature} was
    # introduced; otherwise, {@code false}. Also {@code false} when either
    # UUID is not listed in SUPPORTED_UUIDS.

    def isFeatureSupported(self, feature : UUID , actualUuid : UUID ):
        # list.index raises ValueError for a missing element (it never
        # returns a negative index), so membership is tested explicitly.
        if feature not in SUPPORTED_UUIDS or actualUuid not in SUPPORTED_UUIDS:
            return False
        return SUPPORTED_UUIDS.index(actualUuid) >= SUPPORTED_UUIDS.index(feature)

    # Deserializes a complete ATN from its serialized string form.
    #
    # @param data The serialized ATN, one serialized value per character.
    # @return The reconstructed ATN.
    # Raises Exception when the version or UUID is unsupported, or when the
    # ATN fails verification (if verification is enabled in the options).
    def deserialize(self, data : str):
        self.reset(data)
        self.checkVersion()
        self.checkUUID()
        atn = self.readATN()
        self.readStates(atn)
        self.readRules(atn)
        self.readModes(atn)
        sets = []
        # First, read all sets with 16-bit Unicode code points <= U+FFFF.
        self.readSets(atn, sets, self.readInt)
        # Next, if the ATN was serialized with the Unicode SMP feature,
        # deserialize sets with 32-bit arguments <= U+10FFFF.
        if self.isFeatureSupported(ADDED_UNICODE_SMP, self.uuid):
            self.readSets(atn, sets, self.readInt32)
        self.readEdges(atn, sets)
        self.readDecisions(atn)
        self.readLexerActions(atn)
        self.markPrecedenceDecisions(atn)
        self.verifyATN(atn)
        if self.deserializationOptions.generateRuleBypassTransitions \
                and atn.grammarType == ATNType.PARSER:
            self.generateRuleBypassTransitions(atn)
            # re-verify after modification
            self.verifyATN(atn)
        return atn

    # Decodes the serialized string into a list of integers stored in
    # self.data and rewinds the read position (self.pos) to 0.
    # Every code point is shifted down by 2 (0 and 1 wrap around to 65533
    # and 65534), except the first one, which is kept verbatim.
    def reset(self, data:str):
        def adjust(c):
            v = ord(c)
            return v-2 if v>1 else v + 65533
        temp = [ adjust(c) for c in data ]
        # don't adjust the first value since that's the version number
        temp[0] = ord(data[0])
        self.data = temp
        self.pos = 0

    # Reads the version number and raises Exception unless it equals
    # SERIALIZED_VERSION.
    def checkVersion(self):
        version = self.readInt()
        if version != SERIALIZED_VERSION:
            raise Exception("Could not deserialize ATN with version " + str(version) + " (expected " + str(SERIALIZED_VERSION) + ").")

    # Reads the UUID, raises Exception unless it is one of SUPPORTED_UUIDS,
    # and remembers it in self.uuid for later feature checks.
    def checkUUID(self):
        uuid = self.readUUID()
        if uuid not in SUPPORTED_UUIDS:
            raise Exception("Could not deserialize ATN with UUID: " + str(uuid) + \
                            " (expected " + str(SERIALIZED_UUID) + " or a legacy UUID).", uuid, SERIALIZED_UUID)
        self.uuid = uuid

    # Reads the grammar type ordinal and the maximum token type and returns
    # a new, otherwise empty, ATN built from them.
    def readATN(self):
        idx = self.readInt()
        grammarType = ATNType.fromOrdinal(idx)
        maxTokenType = self.readInt()
        return ATN(grammarType, maxTokenType)

    # Reads the state table, then the lists of non-greedy and precedence
    # state numbers, and adds every state to the ATN.
    # A state of type ATNState.INVALID_TYPE is recorded as a None placeholder.
    def readStates(self, atn:ATN):
        loopBackStateNumbers = []
        endStateNumbers = []
        nstates = self.readInt()
        for i in range(0, nstates):
            stype = self.readInt()
            # ignore bad type of states
            if stype==ATNState.INVALID_TYPE:
                atn.addState(None)
                continue
            ruleIndex = self.readInt()
            # 0xFFFF is the serialized form of "no rule" (-1)
            if ruleIndex == 0xFFFF:
                ruleIndex = -1

            s = self.stateFactory(stype, ruleIndex)
            if stype == ATNState.LOOP_END: # special case
                loopBackStateNumber = self.readInt()
                loopBackStateNumbers.append((s, loopBackStateNumber))
            elif isinstance(s, BlockStartState):
                endStateNumber = self.readInt()
                endStateNumbers.append((s, endStateNumber))

            atn.addState(s)

        # delay the assignment of loop back and end states until we know all the state instances have been initialized
        for pair in loopBackStateNumbers:
            pair[0].loopBackState = atn.states[pair[1]]

        for pair in endStateNumbers:
            pair[0].endState = atn.states[pair[1]]

        numNonGreedyStates = self.readInt()
        for i in range(0, numNonGreedyStates):
            stateNumber = self.readInt()
            atn.states[stateNumber].nonGreedy = True

        numPrecedenceStates = self.readInt()
        for i in range(0, numPrecedenceStates):
            stateNumber = self.readInt()
            atn.states[stateNumber].isPrecedenceRule = True

    # Reads the start state of every rule (plus, for lexers, the token type
    # of every rule) and fills atn.ruleToStartState, atn.ruleToTokenType and
    # atn.ruleToStopState. Each rule start state is linked to its stop state.
    def readRules(self, atn:ATN):
        nrules = self.readInt()
        if atn.grammarType == ATNType.LEXER:
            atn.ruleToTokenType = [0] * nrules

        atn.ruleToStartState = [0] * nrules
        for i in range(0, nrules):
            s = self.readInt()
            startState = atn.states[s]
            atn.ruleToStartState[i] = startState
            if atn.grammarType == ATNType.LEXER:
                tokenType = self.readInt()
                # 0xFFFF is the serialized form of Token.EOF
                if tokenType == 0xFFFF:
                    tokenType = Token.EOF

                atn.ruleToTokenType[i] = tokenType

        atn.ruleToStopState = [0] * nrules
        for state in atn.states:
            if not isinstance(state, RuleStopState):
                continue
            atn.ruleToStopState[state.ruleIndex] = state
            atn.ruleToStartState[state.ruleIndex].stopState = state

    # Reads the start state number of every mode and appends the
    # corresponding state to atn.modeToStartState.
    def readModes(self, atn:ATN):
        nmodes = self.readInt()
        for i in range(0, nmodes):
            s = self.readInt()
            atn.modeToStartState.append(atn.states[s])

    # Reads a group of interval sets and appends them to 'sets'.
    #
    # @param sets The list receiving the new IntervalSet instances.
    # @param readUnicode The reader used for each interval bound
    # (deserialize() passes readInt for 16-bit and readInt32 for 32-bit bounds).
    def readSets(self, atn:ATN, sets:list, readUnicode:Callable[[], int]):
        m = self.readInt()
        for i in range(0, m):
            iset = IntervalSet()
            sets.append(iset)
            n = self.readInt()
            containsEof = self.readInt()
            if containsEof!=0:
                iset.addOne(-1)
            for j in range(0, n):
                i1 = readUnicode()
                i2 = readUnicode()
                iset.addRange(range(i1, i2 + 1)) # range upper limit is exclusive

    # Reads all serialized transitions and attaches them to their source
    # states, then derives the transitions and back-links that are not
    # serialized: rule stop state edges, block end -> block start links and
    # loop back states.
    # Raises Exception("IllegalState") when a block start state has no end
    # state or an end state is claimed by two block start states.
    def readEdges(self, atn:ATN, sets:list):
        nedges = self.readInt()
        for i in range(0, nedges):
            src = self.readInt()
            trg = self.readInt()
            ttype = self.readInt()
            arg1 = self.readInt()
            arg2 = self.readInt()
            arg3 = self.readInt()
            trans = self.edgeFactory(atn, ttype, src, trg, arg1, arg2, arg3, sets)
            srcState = atn.states[src]
            srcState.addTransition(trans)

        # edges for rule stop states can be derived, so they aren't serialized
        for state in atn.states:
            # readStates stores None for states of invalid type
            if state is None:
                continue
            for i in range(0, len(state.transitions)):
                t = state.transitions[i]
                if not isinstance(t, RuleTransition):
                    continue
                outermostPrecedenceReturn = -1
                if atn.ruleToStartState[t.target.ruleIndex].isPrecedenceRule:
                    if t.precedence == 0:
                        outermostPrecedenceReturn = t.target.ruleIndex
                trans = EpsilonTransition(t.followState, outermostPrecedenceReturn)
                atn.ruleToStopState[t.target.ruleIndex].addTransition(trans)

        for state in atn.states:
            if isinstance(state, BlockStartState):
                # we need to know the end state to set its start state
                if state.endState is None:
                    raise Exception("IllegalState")
                # block end states can only be associated to a single block start state
                if state.endState.startState is not None:
                    raise Exception("IllegalState")
                state.endState.startState = state

            if isinstance(state, PlusLoopbackState):
                for t in state.transitions:
                    target = t.target
                    if isinstance(target, PlusBlockStartState):
                        target.loopBackState = state
            elif isinstance(state, StarLoopbackState):
                for t in state.transitions:
                    target = t.target
                    if isinstance(target, StarLoopEntryState):
                        target.loopBackState = state

    # Reads the decision state numbers, appends each state to
    # atn.decisionToState and stores its decision number on the state.
    def readDecisions(self, atn:ATN):
        ndecisions = self.readInt()
        for i in range(0, ndecisions):
            s = self.readInt()
            decState = atn.states[s]
            atn.decisionToState.append(decState)
            decState.decision = i

    # For lexer ATNs only: reads the lexer action table into
    # atn.lexerActions. Parser ATNs have no such section, so nothing is read.
    def readLexerActions(self, atn:ATN):
        if atn.grammarType == ATNType.LEXER:
            count = self.readInt()
            atn.lexerActions = [ None ] * count
            for i in range(0, count):
                actionType = self.readInt()
                data1 = self.readInt()
                # 0xFFFF is the serialized form of -1
                if data1 == 0xFFFF:
                    data1 = -1
                data2 = self.readInt()
                if data2 == 0xFFFF:
                    data2 = -1
                lexerAction = self.lexerActionFactory(actionType, data1, data2)
                atn.lexerActions[i] = lexerAction

    # Assigns every rule a token type above atn.maxTokenType
    # (maxTokenType + ruleIndex + 1) and adds a bypass alternative to each rule.
    def generateRuleBypassTransitions(self, atn:ATN):

        count = len(atn.ruleToStartState)
        atn.ruleToTokenType = [ 0 ] * count
        for i in range(0, count):
            atn.ruleToTokenType[i] = atn.maxTokenType + i + 1

        for i in range(0, count):
            self.generateRuleBypassTransition(atn, i)

    # Wraps the body of rule 'idx' in a new block (bypassStart .. bypassStop)
    # and adds an alternative to that block which matches the single token
    # atn.ruleToTokenType[idx] instead of the rule body.
    # Raises Exception when the rule is a precedence rule and the end of its
    # prefix section cannot be found.
    def generateRuleBypassTransition(self, atn:ATN, idx:int):

        bypassStart = BasicBlockStartState()
        bypassStart.ruleIndex = idx
        atn.addState(bypassStart)

        bypassStop = BlockEndState()
        bypassStop.ruleIndex = idx
        atn.addState(bypassStop)

        bypassStart.endState = bypassStop
        atn.defineDecisionState(bypassStart)

        bypassStop.startState = bypassStart

        excludeTransition = None

        if atn.ruleToStartState[idx].isPrecedenceRule:
            # wrap from the beginning of the rule to the StarLoopEntryState
            endState = None
            for state in atn.states:
                if self.stateIsEndStateFor(state, idx) is not None:
                    endState = state
                    excludeTransition = state.loopBackState.transitions[0]
                    break

            if excludeTransition is None:
                raise Exception("Couldn't identify final state of the precedence rule prefix section.")

        else:

            endState = atn.ruleToStopState[idx]

        # all non-excluded transitions that currently target end state need to target blockEnd instead
        for state in atn.states:
            # readStates stores None for states of invalid type
            if state is None:
                continue
            for transition in state.transitions:
                if transition == excludeTransition:
                    continue
                if transition.target == endState:
                    transition.target = bypassStop

        # all transitions leaving the rule start state need to leave blockStart instead
        # (moved last-to-first; popping shrinks the list so the loop terminates)
        ruleToStartState = atn.ruleToStartState[idx]
        while ruleToStartState.transitions:
            bypassStart.addTransition(ruleToStartState.transitions.pop())

        # link the new states
        atn.ruleToStartState[idx].addTransition(EpsilonTransition(bypassStart))
        bypassStop.addTransition(EpsilonTransition(endState))

        matchState = BasicState()
        atn.addState(matchState)
        matchState.addTransition(AtomTransition(bypassStop, atn.ruleToTokenType[idx]))
        bypassStart.addTransition(EpsilonTransition(matchState))


    # Returns 'state' when it belongs to rule 'idx', is a StarLoopEntryState,
    # and its last transition leads to a LoopEndState that has only epsilon
    # transitions, the first of which targets a RuleStopState.
    # Returns None otherwise (including for a None placeholder state).
    def stateIsEndStateFor(self, state:ATNState, idx:int):
        if state is None or state.ruleIndex != idx:
            return None
        if not isinstance(state, StarLoopEntryState):
            return None

        maybeLoopEndState = state.transitions[-1].target
        if not isinstance(maybeLoopEndState, LoopEndState):
            return None

        if maybeLoopEndState.epsilonOnlyTransitions and \
                isinstance(maybeLoopEndState.transitions[0].target, RuleStopState):
            return state
        else:
            return None


    #
    # Analyze the {@link StarLoopEntryState} states in the specified ATN to set
    # the {@link StarLoopEntryState#isPrecedenceDecision} field to the
    # correct value.
    #
    # @param atn The ATN.
    #
    def markPrecedenceDecisions(self, atn:ATN):
        for state in atn.states:
            if not isinstance(state, StarLoopEntryState):
                continue

            # We analyze the ATN to determine if this ATN decision state is the
            # decision for the closure block that determines whether a
            # precedence rule should continue or complete.
            #
            if atn.ruleToStartState[state.ruleIndex].isPrecedenceRule:
                if self.stateIsEndStateFor(state, state.ruleIndex) is not None:
                    state.isPrecedenceDecision = True

    # Checks structural invariants of every non-None state. Does nothing
    # when verification is disabled in the deserialization options.
    # Raises Exception("IllegalState") on the first violated invariant.
    def verifyATN(self, atn:ATN):
        if not self.deserializationOptions.verifyATN:
            return
        # verify assumptions
        for state in atn.states:
            if state is None:
                continue

            self.checkCondition(state.epsilonOnlyTransitions or len(state.transitions) <= 1)

            if isinstance(state, PlusBlockStartState):
                self.checkCondition(state.loopBackState is not None)

            if isinstance(state, StarLoopEntryState):
                self.checkCondition(state.loopBackState is not None)
                self.checkCondition(len(state.transitions) == 2)

                if isinstance(state.transitions[0].target, StarBlockStartState):
                    self.checkCondition(isinstance(state.transitions[1].target, LoopEndState))
                    self.checkCondition(not state.nonGreedy)
                elif isinstance(state.transitions[0].target, LoopEndState):
                    self.checkCondition(isinstance(state.transitions[1].target, StarBlockStartState))
                    self.checkCondition(state.nonGreedy)
                else:
                    raise Exception("IllegalState")

            if isinstance(state, StarLoopbackState):
                self.checkCondition(len(state.transitions) == 1)
                self.checkCondition(isinstance(state.transitions[0].target, StarLoopEntryState))

            if isinstance(state, LoopEndState):
                self.checkCondition(state.loopBackState is not None)

            if isinstance(state, RuleStartState):
                self.checkCondition(state.stopState is not None)

            if isinstance(state, BlockStartState):
                self.checkCondition(state.endState is not None)

            if isinstance(state, BlockEndState):
                self.checkCondition(state.startState is not None)

            if isinstance(state, DecisionState):
                self.checkCondition(len(state.transitions) <= 1 or state.decision >= 0)
            else:
                self.checkCondition(len(state.transitions) <= 1 or isinstance(state, RuleStopState))

    # Raises Exception(message) when 'condition' is false; the message
    # defaults to "IllegalState".
    def checkCondition(self, condition:bool, message=None):
        if not condition:
            if message is None:
                message = "IllegalState"
            raise Exception(message)

    # Returns the next serialized value and advances the read position.
    def readInt(self):
        i = self.data[self.pos]
        self.pos += 1
        return i

    # Reads two values and combines them, low 16 bits first.
    def readInt32(self):
        low = self.readInt()
        high = self.readInt()
        return low | (high << 16)

    # Reads two 32-bit values and combines them, low 32 bits first.
    def readLong(self):
        low = self.readInt32()
        high = self.readInt32()
        return (low & 0x00000000FFFFFFFF) | (high << 32)

    # Reads two 64-bit values (low half first) and returns them as a UUID.
    def readUUID(self):
        low = self.readLong()
        high = self.readLong()
        allBits = (low & 0xFFFFFFFFFFFFFFFF) | (high << 64)
        return UUID(int=allBits)

    # Transition constructors indexed by serialized transition type.
    # Index 0 is not a valid transition type, hence the None placeholder.
    edgeFactories = [ None,
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : EpsilonTransition(target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        RangeTransition(target, Token.EOF, arg2) if arg3 != 0 else RangeTransition(target, arg1, arg2),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        RuleTransition(atn.states[arg1], arg2, arg3, target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        PredicateTransition(target, arg1, arg2, arg3 != 0),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        AtomTransition(target, Token.EOF) if arg3 != 0 else AtomTransition(target, arg1),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        ActionTransition(target, arg1, arg2, arg3 != 0),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        SetTransition(target, sets[arg1]),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        NotSetTransition(target, sets[arg1]),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        WildcardTransition(target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        PrecedencePredicateTransition(target, arg1)
                      ]

    # Builds the transition of serialized type 'type' leading to state 'trg'.
    # Raises Exception when 'type' is not a valid index into edgeFactories.
    def edgeFactory(self, atn:ATN, type:int, src:int, trg:int, arg1:int, arg2:int, arg3:int, sets:list):
        target = atn.states[trg]
        # reject negative values too: they would silently index from the end
        if type < 0 or type >= len(self.edgeFactories) or self.edgeFactories[type] is None:
            raise Exception("The specified transition type: " + str(type) + " is not valid.")
        else:
            return self.edgeFactories[type](atn, src, trg, arg1, arg2, arg3, sets, target)

    # State constructors indexed by serialized state type.
    # Index 0 yields None instead of a state.
    stateFactories = [  lambda : None,
                        lambda : BasicState(),
                        lambda : RuleStartState(),
                        lambda : BasicBlockStartState(),
                        lambda : PlusBlockStartState(),
                        lambda : StarBlockStartState(),
                        lambda : TokensStartState(),
                        lambda : RuleStopState(),
                        lambda : BlockEndState(),
                        lambda : StarLoopbackState(),
                        lambda : StarLoopEntryState(),
                        lambda : PlusLoopbackState(),
                        lambda : LoopEndState()
                    ]

    # Builds the state of serialized type 'type' and assigns its rule index.
    # Returns None when the factory for 'type' produces no state.
    # Raises Exception when 'type' is not a valid index into stateFactories.
    def stateFactory(self, type:int, ruleIndex:int):
        # reject negative values too: they would silently index from the end
        if type < 0 or type >= len(self.stateFactories) or self.stateFactories[type] is None:
            raise Exception("The specified state type " + str(type) + " is not valid.")
        else:
            s = self.stateFactories[type]()
            if s is not None:
                s.ruleIndex = ruleIndex
        return s

    CHANNEL = 0     #The type of a {@link LexerChannelAction} action.
    CUSTOM = 1      #The type of a {@link LexerCustomAction} action.
    MODE = 2        #The type of a {@link LexerModeAction} action.
    MORE = 3        #The type of a {@link LexerMoreAction} action.
    POP_MODE = 4    #The type of a {@link LexerPopModeAction} action.
    PUSH_MODE = 5   #The type of a {@link LexerPushModeAction} action.
    SKIP = 6        #The type of a {@link LexerSkipAction} action.
    TYPE = 7        #The type of a {@link LexerTypeAction} action.

    # Lexer action constructors indexed by the action type constants above.
    actionFactories = [ lambda data1, data2: LexerChannelAction(data1),
                        lambda data1, data2: LexerCustomAction(data1, data2),
                        lambda data1, data2: LexerModeAction(data1),
                        lambda data1, data2: LexerMoreAction.INSTANCE,
                        lambda data1, data2: LexerPopModeAction.INSTANCE,
                        lambda data1, data2: LexerPushModeAction(data1),
                        lambda data1, data2: LexerSkipAction.INSTANCE,
                        lambda data1, data2: LexerTypeAction(data1)
                      ]

    # Builds the lexer action of serialized type 'type' from its two data values.
    # Raises Exception when 'type' is not a valid index into actionFactories.
    def lexerActionFactory(self, type:int, data1:int, data2:int):

        # reject negative values too: they would silently index from the end
        if type < 0 or type >= len(self.actionFactories) or self.actionFactories[type] is None:
            raise Exception("The specified lexer action type " + str(type) + " is not valid.")
        else:
            return self.actionFactories[type](data1, data2)