1#
2# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
3# Use of this file is governed by the BSD 3-clause license that
4# can be found in the LICENSE.txt file in the project root.
5#/
6from io import StringIO
7from antlr4.RuleContext import RuleContext
8from antlr4.error.Errors import IllegalStateException
9
class PredictionContext(object):
    """Base class for the graph-structured stacks used during ATN prediction.

    A node records one or more (parent, returnState) entries. Concrete
    subclasses provide ``getParent``/``getReturnState`` and override
    ``__len__``; this base class only stores the precomputed hash code.
    """

    # Represents {@code $} in local context prediction, which means wildcard:
    # {@code * + x = *}. The real instance is assigned later in this module,
    # once EmptyPredictionContext exists.
    EMPTY = None

    # Represents {@code $} in an array in full context mode, when {@code $}
    # doesn't mean wildcard: {@code $ + x = [$,x]}. Here,
    # {@code $} = {@link #EMPTY_RETURN_STATE}.
    EMPTY_RETURN_STATE = 0x7FFFFFFF

    globalNodeCount = 1
    id = globalNodeCount

    # The hash code is computed once at construction and returned by __hash__.
    # The Java runtime's reference algorithm folds MurmurHash over every parent,
    # then every return state, and finishes with 2 * size(). This Python port
    # instead receives a tuple-based hash from its subclasses
    # (calculateHashCode / calculateListsHashCode).

    def __init__(self, cachedHashCode):
        self.cachedHashCode = cachedHashCode

    def __hash__(self):
        return self.cachedHashCode

    def __len__(self):
        """Number of (parent, returnState) entries; the base class has none."""
        return 0

    def __str__(self):
        # Delegates to the subclass's __unicode__ via the Python 2 builtin.
        return unicode(self)

    def isEmpty(self):
        """True only when this object is the shared EMPTY context."""
        return self is self.EMPTY

    def hasEmptyPath(self):
        """True if the last entry's return state is EMPTY_RETURN_STATE."""
        lastIndex = len(self) - 1
        return self.getReturnState(lastIndex) == self.EMPTY_RETURN_STATE

    def getReturnState(self, index):
        """Abstract accessor; concrete subclasses override it."""
        raise IllegalStateException("illegal!")
68
69
def calculateHashCode(parent, returnState):
    """Hash a single (parent, returnState) pair.

    A missing parent always hashes to ``hash("")``, whatever ``returnState``
    is. Otherwise the parent's hash and the return state are combined through
    a tuple hash.
    """
    if parent is None:
        return hash("")
    return hash((hash(parent), returnState))
72
def calculateListsHashCode(parents, returnStates):
    """Fold parallel ``parents``/``returnStates`` sequences into one hash.

    Pairs are combined left to right. Surplus items in the longer sequence
    are ignored (``zip`` semantics), and empty input yields 0.
    """
    combined = 0
    for entry in zip(parents, returnStates):
        pairHash = calculateHashCode(entry[0], entry[1])
        combined = hash((combined, pairHash))
    return combined
78
#  Used to cache {@link PredictionContext} objects. It's used for the shared
#  context cache associated with contexts in DFA states. This cache
#  can be used for both lexers and parsers.
82
class PredictionContextCache(object):
    """Interning table that maps each context to its canonical instance."""

    def __init__(self):
        # Keys and values are the same context objects; lookups rely on the
        # contexts' own __eq__/__hash__.
        self.cache = {}

    def __len__(self):
        return len(self.cache)

    def get(self, ctx):
        """Return the cached instance equal to ``ctx``, or None if absent."""
        return self.cache.get(ctx, None)

    def add(self, ctx):
        """Intern ``ctx`` and return the canonical instance.

        A context comparing equal to EMPTY is never stored; EMPTY itself is
        returned. If an equal context is already cached, that earlier
        instance is returned and the cache is left unchanged.

        NOTE(review): nothing here locks. If the cache is shared across
        threads, confirm that callers serialize access.
        """
        if ctx == PredictionContext.EMPTY:
            return PredictionContext.EMPTY
        return self.cache.setdefault(ctx, ctx)
106
107
class SingletonPredictionContext(PredictionContext):
    """A prediction-context node with exactly one (parent, returnState) entry."""

    @staticmethod
    def create(parent, returnState):
        """Build a singleton context, mapping ``(None, $)`` to the shared EMPTY.

        Someone can pass in the bits of an array context that mean ``$``
        (a None parent with EMPTY_RETURN_STATE). Returning EMPTY for that
        combination keeps the empty context unique.
        """
        if returnState == PredictionContext.EMPTY_RETURN_STATE and parent is None:
            return SingletonPredictionContext.EMPTY
        else:
            return SingletonPredictionContext(parent, returnState)

    def __init__(self, parent, returnState):
        hashCode = calculateHashCode(parent, returnState)
        super(SingletonPredictionContext, self).__init__(hashCode)
        self.parentCtx = parent
        self.returnState = returnState

    def __len__(self):
        return 1

    def getParent(self, index):
        """Return the only parent; ``index`` is ignored."""
        return self.parentCtx

    def getReturnState(self, index):
        """Return the only return state; ``index`` is ignored."""
        return self.returnState

    def __eq__(self, other):
        """Value equality: same return state and equal parent context."""
        if self is other:
            return True
        elif not isinstance(other, SingletonPredictionContext):
            # Also covers ``other is None``.
            return False
        else:
            return self.returnState == other.returnState and self.parentCtx==other.parentCtx

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__. Without this, two
        # value-equal contexts would still compare unequal under !=.
        return not self.__eq__(other)

    def __hash__(self):
        return self.cachedHashCode

    def __unicode__(self):
        """Render as ``"<returnState> <parent>"``.

        When the parent renders as empty text (or is None), only the state is
        shown, with ``$`` standing for EMPTY_RETURN_STATE.
        """
        up = "" if self.parentCtx is None else unicode(self.parentCtx)
        if len(up)==0:
            if self.returnState == self.EMPTY_RETURN_STATE:
                return u"$"
            else:
                return unicode(self.returnState)
        else:
            return unicode(self.returnState) + u" " + up
155
156
class EmptyPredictionContext(SingletonPredictionContext):
    """The empty context ``$``: no parent, EMPTY_RETURN_STATE on top.

    Equality is identity-based, so only the very same instance compares equal.
    """

    def __init__(self):
        super(EmptyPredictionContext, self).__init__(None, self.EMPTY_RETURN_STATE)

    def __hash__(self):
        return self.cachedHashCode

    def __eq__(self, other):
        # Identity, not value: a plain singleton with the same fields is not EMPTY.
        return self is other

    def isEmpty(self):
        return True

    def __unicode__(self):
        return "$"


# Install the shared empty context on the base class now that its type exists.
PredictionContext.EMPTY = EmptyPredictionContext()
176
class ArrayPredictionContext(PredictionContext):
    """A prediction-context node holding several (parent, returnState) entries.

    ``parents`` and ``returnStates`` are parallel lists. A parent can be None
    only in full-context mode, when an array is made from EMPTY and a
    non-empty context. EMPTY is then encoded as a None parent paired with
    EMPTY_RETURN_STATE.
    """

    def __init__(self, parents, returnStates):
        super(ArrayPredictionContext, self).__init__(calculateListsHashCode(parents, returnStates))
        self.parents = parents
        self.returnStates = returnStates

    def isEmpty(self):
        """True when the first return state is EMPTY_RETURN_STATE."""
        # Since EMPTY_RETURN_STATE can only appear in the last position, we
        # don't need to verify that size==1.
        return self.returnStates[0]==PredictionContext.EMPTY_RETURN_STATE

    def __len__(self):
        return len(self.returnStates)

    def getParent(self, index):
        return self.parents[index]

    def getReturnState(self, index):
        return self.returnStates[index]

    def __eq__(self, other):
        """Value equality: same return states and equal parents, in order."""
        if self is other:
            return True
        elif not isinstance(other, ArrayPredictionContext):
            return False
        elif hash(self) != hash(other):
            return False # can't be same if hash is different
        else:
            return self.returnStates==other.returnStates and self.parents==other.parents

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self.__eq__(other)

    def __unicode__(self):
        """Render as ``[s1 parent1, s2 parent2, ...]``.

        ``$`` marks EMPTY_RETURN_STATE and ``null`` marks a missing parent.
        An empty context renders as ``[]``.
        """
        if self.isEmpty():
            return "[]"
        with StringIO() as buf:
            buf.write(u"[")
            for i, state in enumerate(self.returnStates):
                if i > 0:
                    buf.write(u", ")
                if state == PredictionContext.EMPTY_RETURN_STATE:
                    buf.write(u"$")
                    continue
                # io.StringIO.write only accepts text; writing the raw int
                # raised TypeError.
                buf.write(unicode(state))
                parent = self.parents[i]
                if parent is not None:
                    buf.write(u' ')
                    buf.write(unicode(parent))
                else:
                    buf.write(u"null")
            buf.write(u"]")
            return buf.getvalue()

    def __hash__(self):
        return self.cachedHashCode
233
234
235
236#  Convert a {@link RuleContext} tree to a {@link PredictionContext} graph.
237#  Return {@link #EMPTY} if {@code outerContext} is empty or null.
238#/
def PredictionContextFromRuleContext(atn, outerContext=None):
    """Convert a RuleContext chain into a PredictionContext graph.

    Each invoking state is mapped to the follow state of its first
    transition, whose state number becomes a singleton's return state.

    :param atn: ATN whose ``states`` are indexed by ``invokingState``.
    :param outerContext: innermost rule context; None means RuleContext.EMPTY.
    :return: PredictionContext.EMPTY for a parentless (start-rule) or empty
        context; otherwise a chain of singleton contexts ending in EMPTY.
    """
    ctx = RuleContext.EMPTY if outerContext is None else outerContext

    # Nobody called us (start rule, or already empty): nothing to return to.
    if ctx.parentCtx is None or ctx is RuleContext.EMPTY:
        return PredictionContext.EMPTY

    # Convert the caller chain first, then push this call's follow state on top.
    parent = PredictionContextFromRuleContext(atn, ctx.parentCtx)
    invokingState = atn.states[ctx.invokingState]
    followState = invokingState.transitions[0].followState
    return SingletonPredictionContext.create(parent, followState.stateNumber)
253
254
def merge(a, b, rootIsWildcard, mergeCache):
    """Merge two prediction-context graphs into one.

    :param a: left context.
    :param b: right context.
    :param rootIsWildcard: True for a local-context merge, where ``$`` acts as
        a wildcard; False for a full-context merge.
    :param mergeCache: optional mapping passed through to the singleton/array
        merge helpers for memoization, or None.
    :return: the merged context, which may be ``a`` or ``b`` itself.
    """
    # Value-equal graphs are shared as-is.
    if a == b:
        return a

    singletons = (isinstance(a, SingletonPredictionContext)
                  and isinstance(b, SingletonPredictionContext))
    if singletons:
        return mergeSingletons(a, b, rootIsWildcard, mergeCache)

    # At least one side is an array. In a local-context merge $ is a
    # wildcard, so if either side is $ the result is $ (* + x = *).
    if rootIsWildcard:
        for side in (a, b):
            if isinstance(side, EmptyPredictionContext):
                return side

    # Normalize any singleton into a one-element array so mergeArrays only
    # ever sees arrays.
    left, right = a, b
    if isinstance(left, SingletonPredictionContext):
        left = ArrayPredictionContext([left.parentCtx], [left.returnState])
    if isinstance(right, SingletonPredictionContext):
        right = ArrayPredictionContext([right.parentCtx], [right.returnState])
    return mergeArrays(left, right, rootIsWildcard, mergeCache)
278
279
280#
281# Merge two {@link SingletonPredictionContext} instances.
282#
283# <p>Stack tops equal, parents merge is same; return left graph.<br>
284# <embed src="images/SingletonMerge_SameRootSamePar.svg" type="image/svg+xml"/></p>
285#
286# <p>Same stack top, parents differ; merge parents giving array node, then
287# remainders of those graphs. A new root node is created to point to the
288# merged parents.<br>
289# <embed src="images/SingletonMerge_SameRootDiffPar.svg" type="image/svg+xml"/></p>
290#
291# <p>Different stack tops pointing to same parent. Make array node for the
292# root where both element in the root point to the same (original)
293# parent.<br>
294# <embed src="images/SingletonMerge_DiffRootSamePar.svg" type="image/svg+xml"/></p>
295#
296# <p>Different stack tops pointing to different parents. Make array node for
297# the root where each element points to the corresponding original
298# parent.<br>
299# <embed src="images/SingletonMerge_DiffRootDiffPar.svg" type="image/svg+xml"/></p>
300#
301# @param a the first {@link SingletonPredictionContext}
302# @param b the second {@link SingletonPredictionContext}
303# @param rootIsWildcard {@code true} if this is a local-context merge,
304# otherwise false to indicate a full-context merge
305# @param mergeCache
306#/
def mergeSingletons(a, b, rootIsWildcard, mergeCache):
    """Merge two SingletonPredictionContext instances.

    Results are memoized in ``mergeCache`` (when not None) under the key
    ``(a, b)``; lookups try both ``(a, b)`` and ``(b, a)``. The "parent merge
    collapsed to an existing parent" case returns ``a``/``b`` without caching.
    """
    def remember(result):
        # Record under (a, b) only; the lookup below also checks (b, a).
        if mergeCache is not None:
            mergeCache[(a, b)] = result
        return result

    if mergeCache is not None:
        for key in ((a, b), (b, a)):
            cached = mergeCache.get(key, None)
            if cached is not None:
                return cached

    # Cases involving EMPTY ($) are handled separately.
    rootMerge = mergeRoot(a, b, rootIsWildcard)
    if rootMerge is not None:
        return remember(rootMerge)

    if a.returnState == b.returnState:
        mergedParent = merge(a.parentCtx, b.parentCtx, rootIsWildcard, mergeCache)
        # If the merged parent equals an existing parent, reuse that node.
        if mergedParent == a.parentCtx:
            return a  # ax + bx = ax, if a=b
        if mergedParent == b.parentCtx:
            return b  # ax + bx = bx, if a=b
        # ax + ay = a'[x,y]: a new singleton points at the merged parents.
        return remember(SingletonPredictionContext.create(mergedParent, a.returnState))

    # Different stack tops. If both share one parent: ax + bx = [a,b]x.
    sharedParent = None
    if a is b or (a.parentCtx is not None and a.parentCtx == b.parentCtx):
        sharedParent = a.parentCtx
    if sharedParent is not None:
        orderedStates = sorted([a.returnState, b.returnState])
        return remember(ArrayPredictionContext([sharedParent, sharedParent], orderedStates))

    # Parents differ and can't be merged; pack both, sorted by payload:
    # ax + by = [ax,by]
    if a.returnState > b.returnState:
        first, second = b, a
    else:
        first, second = a, b
    return remember(ArrayPredictionContext([first.parentCtx, second.parentCtx],
                                           [first.returnState, second.returnState]))
364
365
366#
367# Handle case where at least one of {@code a} or {@code b} is
368# {@link #EMPTY}. In the following diagrams, the symbol {@code $} is used
369# to represent {@link #EMPTY}.
370#
371# <h2>Local-Context Merges</h2>
372#
373# <p>These local-context merge operations are used when {@code rootIsWildcard}
374# is true.</p>
375#
376# <p>{@link #EMPTY} is superset of any graph; return {@link #EMPTY}.<br>
377# <embed src="images/LocalMerge_EmptyRoot.svg" type="image/svg+xml"/></p>
378#
379# <p>{@link #EMPTY} and anything is {@code #EMPTY}, so merged parent is
380# {@code #EMPTY}; return left graph.<br>
381# <embed src="images/LocalMerge_EmptyParent.svg" type="image/svg+xml"/></p>
382#
383# <p>Special case of last merge if local context.<br>
384# <embed src="images/LocalMerge_DiffRoots.svg" type="image/svg+xml"/></p>
385#
386# <h2>Full-Context Merges</h2>
387#
388# <p>These full-context merge operations are used when {@code rootIsWildcard}
389# is false.</p>
390#
391# <p><embed src="images/FullMerge_EmptyRoots.svg" type="image/svg+xml"/></p>
392#
393# <p>Must keep all contexts; {@link #EMPTY} in array is a special value (and
394# null parent).<br>
395# <embed src="images/FullMerge_EmptyRoot.svg" type="image/svg+xml"/></p>
396#
397# <p><embed src="images/FullMerge_SameRoot.svg" type="image/svg+xml"/></p>
398#
399# @param a the first {@link SingletonPredictionContext}
400# @param b the second {@link SingletonPredictionContext}
401# @param rootIsWildcard {@code true} if this is a local-context merge,
402# otherwise false to indicate a full-context merge
403#/
def mergeRoot(a, b, rootIsWildcard):
    """Handle a singleton merge where at least one side is EMPTY (``$``).

    :return: the merged context, or None when neither side compares equal to
        EMPTY (the caller must then merge normally).
    """
    empty = PredictionContext.EMPTY
    if rootIsWildcard:
        # Local context: $ is a wildcard and absorbs the other side.
        # $ + b = $,  a + $ = $
        if a == empty or b == empty:
            return empty
        return None

    aIsEmpty = a == empty
    bIsEmpty = b == empty
    if aIsEmpty and bIsEmpty:
        return empty  # $ + $ = $
    if not (aIsEmpty or bIsEmpty):
        return None
    # Full context: keep both. $ is stored as a None parent paired with
    # EMPTY_RETURN_STATE, always in the last slot:
    # $ + x = [x,$]  and  x + $ = [x,$]
    other = b if aIsEmpty else a
    return ArrayPredictionContext([other.parentCtx, None],
                                  [other.returnState, PredictionContext.EMPTY_RETURN_STATE])
422
423
424#
425# Merge two {@link ArrayPredictionContext} instances.
426#
427# <p>Different tops, different parents.<br>
428# <embed src="images/ArrayMerge_DiffTopDiffPar.svg" type="image/svg+xml"/></p>
429#
430# <p>Shared top, same parents.<br>
431# <embed src="images/ArrayMerge_ShareTopSamePar.svg" type="image/svg+xml"/></p>
432#
433# <p>Shared top, different parents.<br>
434# <embed src="images/ArrayMerge_ShareTopDiffPar.svg" type="image/svg+xml"/></p>
435#
436# <p>Shared top, all shared parents.<br>
437# <embed src="images/ArrayMerge_ShareTopSharePar.svg" type="image/svg+xml"/></p>
438#
439# <p>Equal tops, merge parents and reduce top to
440# {@link SingletonPredictionContext}.<br>
441# <embed src="images/ArrayMerge_EqualTop.svg" type="image/svg+xml"/></p>
442#/
def mergeArrays(a, b, rootIsWildcard, mergeCache):
    """Merge two ArrayPredictionContext instances by sorted return state.

    Entries with equal return states are combined; their parents are merged
    recursively unless both are None ($+$) or already equal. The result may be
    a singleton (one entry left), ``a`` or ``b`` itself (nothing new), or a
    new array whose equal parents are then shared via combineCommonParents.
    Results are memoized in ``mergeCache`` (when not None) under ``(a, b)``.
    """
    def remember(result):
        # Record under (a, b) only; the lookup below also checks (b, a).
        if mergeCache is not None:
            mergeCache[(a, b)] = result
        return result

    if mergeCache is not None:
        for key in ((a, b), (b, a)):
            cached = mergeCache.get(key, None)
            if cached is not None:
                return cached

    aStates, aParents = a.returnStates, a.parents
    bStates, bParents = b.returnStates, b.parents
    aCount, bCount = len(aStates), len(bStates)

    mergedParents = []
    mergedReturnStates = []
    i = j = 0
    # Classic sorted merge of the two payload lists.
    while i < aCount and j < bCount:
        aParent, bParent = aParents[i], bParents[j]
        aState, bState = aStates[i], bStates[j]
        if aState == bState:
            # Same stack top: emit a single entry.
            bothDollars = (aState == PredictionContext.EMPTY_RETURN_STATE
                           and aParent is None and bParent is None)  # $+$ = $
            sameParent = (aParent is not None and bParent is not None
                          and aParent == bParent)  # ax+ax -> ax
            if bothDollars or sameParent:
                mergedParents.append(aParent)  # choose left
            else:  # ax+ay -> a'[x,y]
                mergedParents.append(merge(aParent, bParent, rootIsWildcard, mergeCache))
            mergedReturnStates.append(aState)
            i += 1
            j += 1
        elif aState < bState:
            mergedParents.append(aParent)
            mergedReturnStates.append(aState)
            i += 1
        else:
            mergedParents.append(bParent)
            mergedReturnStates.append(bState)
            j += 1

    # Copy whatever remains of the input that was not exhausted.
    if i < aCount:
        for p in range(i, aCount):
            mergedParents.append(aParents[p])
            mergedReturnStates.append(aStates[p])
    else:
        for p in range(j, bCount):
            mergedParents.append(bParents[p])
            mergedReturnStates.append(bStates[p])

    # Combining equal tops shrank the result to one entry: return a singleton.
    if len(mergedReturnStates) < aCount + bCount and len(mergedReturnStates) == 1:
        return remember(SingletonPredictionContext.create(mergedParents[0], mergedReturnStates[0]))

    merged = ArrayPredictionContext(mergedParents, mergedReturnStates)

    # If we rebuilt one of the inputs, hand back the existing object instead.
    # TODO: track whether this is possible above during merge sort for speed
    for original in (a, b):
        if merged == original:
            return remember(original)

    # merged.parents is this same list, so sharing happens in place.
    combineCommonParents(mergedParents)
    return remember(merged)
528
529
530#
531# Make pass over all <em>M</em> {@code parents}; merge any {@code equals()}
532# ones.
533#/
def combineCommonParents(parents):
    """Make value-equal entries of ``parents`` share a single instance.

    Mutates the list in place. Each element is replaced by the first element
    seen that compares equal to it (by ``__eq__``/``__hash__``).
    """
    canonical = {}
    for parent in parents:
        if canonical.get(parent) is None:
            canonical[parent] = parent
    parents[:] = [canonical[parent] for parent in parents]
544
def getCachedPredictionContext(context, contextCache, visited):
    """Return a canonical, cache-shared version of ``context``.

    Parents are canonicalized recursively first. If none of them changed,
    ``context`` itself is added to ``contextCache`` and returned. Otherwise a
    new node is built over the canonical parents, cached, and returned.
    ``visited`` maps each context seen during this walk to its canonical
    result.

    NOTE(review): a None parent (as stored in full-context arrays) would reach
    ``context.isEmpty()`` on None and raise AttributeError. Confirm callers
    never pass such graphs.
    """
    if context.isEmpty():
        return context

    memo = visited.get(context)
    if memo is not None:
        return memo

    shared = contextCache.get(context)
    if shared is not None:
        visited[context] = shared
        return shared

    originalParents = [context.getParent(i) for i in range(len(context))]
    canonicalParents = [getCachedPredictionContext(p, contextCache, visited)
                        for p in originalParents]

    # Every parent was already canonical: this node can be cached as-is.
    if all(c is o for c, o in zip(canonicalParents, originalParents)):
        contextCache.add(context)
        visited[context] = context
        return context

    count = len(canonicalParents)
    if count == 0:
        updated = PredictionContext.EMPTY
    elif count == 1:
        updated = SingletonPredictionContext.create(canonicalParents[0], context.getReturnState(0))
    else:
        updated = ArrayPredictionContext(canonicalParents, context.returnStates)

    contextCache.add(updated)
    visited[updated] = updated
    visited[context] = updated
    return updated
582
583
584#	# extra structures, but cut/paste/morphed works, so leave it.
585#	# seems to do a breadth-first walk
586#	public static List<PredictionContext> getAllNodes(PredictionContext context) {
587#		Map<PredictionContext, PredictionContext> visited =
588#			new IdentityHashMap<PredictionContext, PredictionContext>();
589#		Deque<PredictionContext> workList = new ArrayDeque<PredictionContext>();
590#		workList.add(context);
591#		visited.put(context, context);
592#		List<PredictionContext> nodes = new ArrayList<PredictionContext>();
593#		while (!workList.isEmpty()) {
594#			PredictionContext current = workList.pop();
595#			nodes.add(current);
596#			for (int i = 0; i < current.size(); i++) {
597#				PredictionContext parent = current.getParent(i);
598#				if ( parent!=null && visited.put(parent, parent) == null) {
599#					workList.push(parent);
600#				}
601#			}
602#		}
603#		return nodes;
604#	}
605
606# ter's recursive version of Sam's getAllNodes()
def getAllContextNodes(context, nodes=None, visited=None):
    """Collect every node reachable from ``context`` through parent links.

    The walk is depth-first and pre-order: a node is appended before its
    parents are visited. None parents are skipped. ``visited`` is keyed by
    the contexts themselves, so value-equal contexts are reported only once.

    :param context: starting context, or None.
    :param nodes: optional list to append to; a new list is created if None.
    :param visited: optional dict used as the seen-set; created if None.
    :return: ``nodes``.
    """
    if nodes is None:
        nodes = list()
    if visited is None:
        visited = dict()
    if context is None or visited.get(context, None) is not None:
        return nodes
    # Plain dict/list operations; dict has no ``put`` and list has no ``add``.
    visited[context] = context
    nodes.append(context)
    for i in range(0, len(context)):
        getAllContextNodes(context.getParent(i), nodes, visited)
    return nodes
622
623