1#     Copyright 2021, Kay Hayen, mailto:kay.hayen@gmail.com
2#
3#     Part of "Nuitka", an optimizing Python compiler that is compatible and
4#     integrates with CPython, but also works on its own.
5#
6#     Licensed under the Apache License, Version 2.0 (the "License");
7#     you may not use this file except in compliance with the License.
8#     You may obtain a copy of the License at
9#
10#        http://www.apache.org/licenses/LICENSE-2.0
11#
12#     Unless required by applicable law or agreed to in writing, software
13#     distributed under the License is distributed on an "AS IS" BASIS,
14#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15#     See the License for the specific language governing permissions and
16#     limitations under the License.
17#
18""" Value trace objects.
19
Value traces indicate the flow of values and merge their versions for
the SSA (Static Single Assignment) form being used in Nuitka.
22
23Values can be seen as:
24
25* Unknown (maybe initialized, maybe not, we cannot know)
26* Uninit (definitely not initialized, first version)
27* Init (definitely initialized, e.g. parameter variables)
28* Assign (assignment was done)
* Deleted (del was done, now unassigned, uninitialized)
30* Merge (result of diverged code paths, loop potentially)
31* LoopIncomplete (aggregation during loops, not yet fully known)
32* LoopComplete (complete knowledge of loop types)
33"""
34
35from nuitka.nodes.shapes.BuiltinTypeShapes import tshape_dict, tshape_tuple
36from nuitka.nodes.shapes.StandardShapes import (
37    ShapeLoopCompleteAlternative,
38    ShapeLoopInitialAlternative,
39    tshape_uninit,
40    tshape_unknown,
41)
42from nuitka.utils.InstanceCounters import (
43    counted_del,
44    counted_init,
45    isCountingInstances,
46)
47
48
class ValueTraceBase(object):
    """Base class of all value traces.

    A trace represents one version of a variable's value. It remembers the
    trace it replaces in "previous" and counts how it is being used.
    """

    # We are going to have many instance attributes, but should strive to minimize, as
    # there is going to be a lot of fluctuation in these objects.

    __slots__ = (
        "owner",
        "usage_count",
        "name_usage_count",
        "merge_usage_count",
        "closure_usages",
        "previous",
    )

    @counted_init
    def __init__(self, owner, previous):
        # The owner of the trace, asked for its code name in "__repr__".
        self.owner = owner

        # Definite usage indicator.
        self.usage_count = 0

        # If 0, this indicates, the variable name needs to be assigned as name.
        self.name_usage_count = 0

        # If 0, this indicates no value merges happened on the value.
        self.merge_usage_count = 0

        self.closure_usages = False

        # Previous trace this is replacing.
        self.previous = previous

    if isCountingInstances():
        __del__ = counted_del()

    def __repr__(self):
        return "<%s of %s>" % (self.__class__.__name__, self.owner.getCodeName())

    def getOwner(self):
        return self.owner

    @staticmethod
    def isLoopTrace():
        return False

    def addUsage(self):
        """Count a general usage of this trace."""
        self.usage_count += 1

    def addNameUsage(self):
        """Count a usage by name, which also counts as a general usage."""
        self.usage_count += 1
        self.name_usage_count += 1

        # Only the first two name usages are forwarded to the previous trace,
        # presumably later ones cannot change anything for it.
        if self.name_usage_count <= 2 and self.previous is not None:
            self.previous.addNameUsage()

    def addMergeUsage(self):
        """Count a usage by a merge trace, which also counts as a general usage."""
        self.usage_count += 1
        self.merge_usage_count += 1

    def getUsageCount(self):
        return self.usage_count

    def getNameUsageCount(self):
        return self.name_usage_count

    def getMergeUsageCount(self):
        return self.merge_usage_count

    def getMergeOrNameUsageCount(self):
        return self.merge_usage_count + self.name_usage_count

    def getPrevious(self):
        return self.previous

    @staticmethod
    def isAssignTrace():
        return False

    @staticmethod
    def isUnassignedTrace():
        return False

    @staticmethod
    def isDeletedTrace():
        return False

    @staticmethod
    def isUninitTrace():
        return False

    @staticmethod
    def isInitTrace():
        return False

    @staticmethod
    def isUnknownTrace():
        return False

    @staticmethod
    def isEscapeTrace():
        return False

    @staticmethod
    def isEscapeOrUnknownTrace():
        return False

    @staticmethod
    def isMergeTrace():
        return False

    def mustHaveValue(self):
        """Will this definitely have a value.

        Every trace has this overloaded.
        """
        assert False, self

    def mustNotHaveValue(self):
        """Will this definitely not have a value.

        Every trace has this overloaded.
        """
        assert False, self

    def getReplacementNode(self, usage):
        """Node to replace the given usage with, None if no replacement is known."""
        # Virtual method, pylint: disable=no-self-use,unused-argument

        return None

    @staticmethod
    def hasShapeDictionaryExact():
        return False

    @staticmethod
    def getTruthValue():
        """Truth value of the traced value, None means it is not known."""
        return None
184
185
class ValueTraceUnassignedBase(ValueTraceBase):
    """Common base for traces where the variable definitely has no value."""

    __slots__ = ()

    @staticmethod
    def mustHaveValue():
        return False

    @staticmethod
    def mustNotHaveValue():
        return True

    @staticmethod
    def getTypeShape():
        # No value, therefore the uninitialized shape.
        return tshape_uninit

    @staticmethod
    def isUnassignedTrace():
        return True

    def compareValueTrace(self, other):
        # All unassigned traces are alike, so only the kind of the other trace
        # matters, pylint: disable=no-self-use
        return other.isUnassignedTrace()
208
209
class ValueTraceUninit(ValueTraceUnassignedBase):
    """First version of a variable, definitely not initialized yet."""

    __slots__ = ()

    @staticmethod
    def isUninitTrace():
        return True

    def __init__(self, owner, previous):
        ValueTraceUnassignedBase.__init__(self, owner=owner, previous=previous)
219
220
class ValueTraceDeleted(ValueTraceUnassignedBase):
    """Trace caused by a deletion.

    The variable is unassigned after the deletion done by "del_node".
    """

    __slots__ = ("del_node",)

    def __init__(self, owner, previous, del_node):
        ValueTraceUnassignedBase.__init__(self, owner=owner, previous=previous)

        # The node that performed the deletion.
        self.del_node = del_node

    def getDelNode(self):
        return self.del_node

    @staticmethod
    def isDeletedTrace():
        return True
237
238
class ValueTraceInit(ValueTraceBase):
    """Trace of a variable that is initialized from the start, e.g. parameters.

    Such a trace has no previous trace.
    """

    __slots__ = ()

    def __init__(self, owner):
        # There is nothing before an initial value.
        ValueTraceBase.__init__(self, owner=owner, previous=None)

    @staticmethod
    def isInitTrace():
        return True

    @staticmethod
    def mustHaveValue():
        return True

    @staticmethod
    def mustNotHaveValue():
        return False

    @staticmethod
    def getTypeShape():
        # A value is present, but nothing is known about its type.
        return tshape_unknown

    def compareValueTrace(self, other):
        # Initialized traces are all alike, only the kind of the other trace
        # matters, pylint: disable=no-self-use
        return other.isInitTrace()
264
265
class ValueTraceInitStarArgs(ValueTraceInit):
    """Initial trace with a tuple shape, as used for star arguments."""

    # Without this, instances would get a "__dict__", defeating the slots that
    # all the base classes use to keep these numerous objects small.
    __slots__ = ()

    @staticmethod
    def getTypeShape():
        return tshape_tuple
270
271
class ValueTraceInitStarDict(ValueTraceInit):
    """Initial trace with a dict shape, as used for star dict arguments."""

    # Without this, instances would get a "__dict__", defeating the slots that
    # all the base classes use to keep these numerous objects small.
    __slots__ = ()

    @staticmethod
    def getTypeShape():
        return tshape_dict
276
277
class ValueTraceUnknown(ValueTraceBase):
    """Trace of a value that may or may not be present, of unknown type.

    Usages and merge usages are forwarded to the previous trace, if any.
    """

    __slots__ = ()

    def __init__(self, owner, previous):
        ValueTraceBase.__init__(self, owner=owner, previous=previous)

    @staticmethod
    def isUnknownTrace():
        return True

    @staticmethod
    def isEscapeOrUnknownTrace():
        return True

    @staticmethod
    def mustHaveValue():
        return False

    @staticmethod
    def mustNotHaveValue():
        return False

    @staticmethod
    def getTypeShape():
        return tshape_unknown

    def addUsage(self):
        self.usage_count += 1

        previous = self.previous
        if previous is not None:
            previous.addUsage()

    def addMergeUsage(self):
        self.usage_count += 1
        self.merge_usage_count += 1

        previous = self.previous
        if previous is not None:
            previous.addMergeUsage()

    def compareValueTrace(self, other):
        # Unknown traces are all alike, only the kind of the other trace
        # matters, pylint: disable=no-self-use
        return other.isUnknownTrace()
320
321
class ValueTraceEscaped(ValueTraceUnknown):
    """Trace of a value that escaped.

    Whether a value is present, and any replacement node, is still decided by
    the previous trace; the type shape is unknown.
    """

    __slots__ = ()

    @staticmethod
    def isUnknownTrace():
        return False

    @staticmethod
    def isEscapeTrace():
        return True

    @staticmethod
    def isEscapeOrUnknownTrace():
        return True

    def addUsage(self):
        self.usage_count += 1

        # The previous must be prevented from optimization if still used
        # afterwards, forwarding the first two usages is enough for that.
        if self.usage_count < 3:
            self.previous.addNameUsage()

    def addMergeUsage(self):
        # Same as a usage, the previous trace sees it as a name usage.
        self.usage_count += 1
        if self.usage_count < 3:
            self.previous.addNameUsage()

        # Additionally forward the first two merge usages as such.
        self.merge_usage_count += 1
        if self.merge_usage_count < 3:
            self.previous.addMergeUsage()

    def mustHaveValue(self):
        # Escaping does not change if there is a value, ask the previous trace.
        return self.previous.mustHaveValue()

    def mustNotHaveValue(self):
        return self.previous.mustNotHaveValue()

    def getReplacementNode(self, usage):
        return self.previous.getReplacementNode(usage)
361
362
class ValueTraceAssign(ValueTraceBase):
    """Trace created by an assignment, the value is given by "assign_node"."""

    __slots__ = ("assign_node", "replace_it")

    def __init__(self, owner, assign_node, previous):
        ValueTraceBase.__init__(self, owner=owner, previous=previous)

        self.assign_node = assign_node

        # Optional callable, given a usage it produces a replacement node.
        self.replace_it = None

    def __repr__(self):
        source_ref = self.assign_node.getSourceReference()

        return "<ValueTraceAssign at {source_ref} of {value}>".format(
            source_ref=source_ref.getAsString(),
            value=self.assign_node.subnode_source,
        )

    @staticmethod
    def isAssignTrace():
        return True

    @staticmethod
    def mustHaveValue():
        return True

    @staticmethod
    def mustNotHaveValue():
        return False

    def compareValueTrace(self, other):
        # Only equal when produced by the very same assignment node.
        if not other.isAssignTrace():
            return False

        return self.assign_node is other.assign_node

    def getAssignNode(self):
        return self.assign_node

    def getTypeShape(self):
        return self.assign_node.getTypeShape()

    def setReplacementNode(self, replacement):
        self.replace_it = replacement

    def getReplacementNode(self, usage):
        replacer = self.replace_it

        if replacer is None:
            return None

        return replacer(usage)

    def hasShapeDictionaryExact(self):
        source = self.assign_node.subnode_source
        return source.hasShapeDictionaryExact()

    def getTruthValue(self):
        source = self.assign_node.subnode_source
        return source.getTruthValue()
413
414
class ValueTraceMergeBase(ValueTraceBase):
    """Merge of two or more traces or start of loops.

    For these traces, "previous" is a sequence of traces, not a single one.
    """

    __slots__ = ()

    def addNameUsage(self):
        self.usage_count += 1
        self.name_usage_count += 1

        # Like for single traces, only the first two name usages are forwarded,
        # but here to every one of the merged traces.
        if self.name_usage_count > 2 or self.previous is None:
            return

        for previous in self.previous:
            previous.addNameUsage()
427
428
class ValueTraceMerge(ValueTraceMergeBase):
    """Merge of two or more traces.

    Happens at the end of conditional blocks. This is "phi" in
    SSA theory. Also used for merging multiple "return", "break" or
    "continue" exits.
    """

    __slots__ = ()

    def __init__(self, traces):
        # All merged traces share an owner, take it from the first one.
        ValueTraceMergeBase.__init__(self, owner=traces[0].owner, previous=traces)

        for merged in traces:
            merged.addMergeUsage()

    def __repr__(self):
        return "<ValueTraceMerge of {previous}>".format(previous=self.previous)

    @staticmethod
    def isMergeTrace():
        return True

    def addUsage(self):
        # Plain usages are counted here only, not forwarded to merged traces.
        self.usage_count += 1

    def getTypeShape(self):
        seen_shapes = set()

        for merged in self.previous:
            shape = merged.getTypeShape()

            # A single unknown shape makes the whole merge unknown.
            if shape is tshape_unknown:
                return tshape_unknown

            seen_shapes.add(shape)

        # TODO: Find the lowest common denominator.
        return seen_shapes.pop() if len(seen_shapes) == 1 else tshape_unknown

    def compareValueTrace(self, other):
        if not other.isMergeTrace() or len(self.previous) != len(other.previous):
            return False

        # Merged traces must agree pairwise, in order.
        return all(
            mine.compareValueTrace(theirs)
            for mine, theirs in zip(self.previous, other.previous)
        )

    def mustHaveValue(self):
        # Only directly initialized or assigned traces are trusted here.
        return all(
            merged.isInitTrace() or merged.isAssignTrace() for merged in self.previous
        )

    def mustNotHaveValue(self):
        return all(merged.mustNotHaveValue() for merged in self.previous)

    def hasShapeDictionaryExact(self):
        return all(previous.hasShapeDictionaryExact() for previous in self.previous)

    def getTruthValue(self):
        seen_true = False
        seen_false = False

        for merged in self.previous:
            value = merged.getTruthValue()

            # One unknown kills it.
            if value is None:
                return None

            if value is True:
                seen_true = True
            else:
                seen_false = True

            # True and false values together resemble unknown.
            if seen_true and seen_false:
                return None

        # All agreed and none was unknown, so it is all true or all false.
        return seen_true
525
526
class ValueTraceLoopBase(ValueTraceMergeBase):
    """Common base for traces at the start of a loop.

    Initially the only previous trace is the one from before the loop, the
    traces of "continue" exits are added later via "addLoopContinueTraces".
    """

    __slots__ = ("loop_node", "type_shapes", "type_shape", "recursion")

    def __init__(self, loop_node, previous, type_shapes):
        # Note: That previous is being added to later.
        ValueTraceMergeBase.__init__(self, owner=previous.owner, previous=(previous,))

        previous.addMergeUsage()

        self.loop_node = loop_node

        # Collection of possible type shapes, presumably a set - TODO confirm
        # with the callers.
        self.type_shapes = type_shapes

        # Computed lazily by "getTypeShape".
        self.type_shape = None

        # Set while "mustHaveValue" is running, to stop it from recursing
        # into itself through traces that lead back to this one.
        self.recursion = False

    def __repr__(self):
        return "<%s shapes %s of %s>" % (
            self.__class__.__name__,
            self.type_shapes,
            self.owner.getCodeName(),
        )

    @staticmethod
    def isLoopTrace():
        return True

    def getTypeShape(self):
        if self.type_shape is None:
            if len(self.type_shapes) > 1:
                self.type_shape = ShapeLoopCompleteAlternative(self.type_shapes)
            else:
                self.type_shape = next(iter(self.type_shapes))

        return self.type_shape

    def addLoopContinueTraces(self, continue_traces):
        """Add the traces of "continue" exits as further previous traces.

        Args:
            continue_traces: iterable of traces, each gets a merge usage.
        """
        # Materialize once, so that iterators are not exhausted before the
        # merge usages are added below.
        continue_traces = tuple(continue_traces)

        self.previous += continue_traces

        for previous in continue_traces:
            previous.addMergeUsage()

    def mustHaveValue(self):
        # To handle recursion, we lie to ourselves: while already being asked,
        # assume there is a value and let the other previous traces decide.
        if self.recursion:
            return True

        self.recursion = True

        try:
            return all(previous.mustHaveValue() for previous in self.previous)
        finally:
            # Reset even if a previous trace raised, otherwise all later
            # queries would wrongly answer True.
            self.recursion = False
582
583
class ValueTraceLoopComplete(ValueTraceLoopBase):
    """Loop start trace with complete knowledge of the loop types."""

    __slots__ = ()

    def compareValueTrace(self, other):
        # Complete loop value traces behave the same, if they are for the same
        # loop and have the same type shapes.
        return (
            self.__class__ is other.__class__
            and self.loop_node == other.loop_node
            and self.type_shapes == other.type_shapes
        )

    # TODO: These could be better
    @staticmethod
    def mustHaveValue():
        return False

    @staticmethod
    def mustNotHaveValue():
        return False

    @staticmethod
    def getTruthValue():
        return None
607
608
class ValueTraceLoopIncomplete(ValueTraceLoopBase):
    """Loop start trace during aggregation, the loop types are not fully known."""

    __slots__ = ()

    @staticmethod
    def getTruthValue():
        return None

    @staticmethod
    def mustHaveValue():
        return False

    @staticmethod
    def mustNotHaveValue():
        return False

    def compareValueTrace(self, other):
        # Incomplete loop value traces behave the same, shapes are ignored.
        if self.__class__ is not other.__class__:
            return False

        return self.loop_node == other.loop_node

    def getTypeShape(self):
        shape = self.type_shape

        if shape is None:
            # Computed once, then reused.
            shape = self.type_shape = ShapeLoopInitialAlternative(self.type_shapes)

        return shape
633