1#     Copyright 2021, Kay Hayen, mailto:kay.hayen@gmail.com
2#
3#     Part of "Nuitka", an optimizing Python compiler that is compatible and
4#     integrates with CPython, but also works on its own.
5#
6#     Licensed under the Apache License, Version 2.0 (the "License");
7#     you may not use this file except in compliance with the License.
8#     You may obtain a copy of the License at
9#
10#        http://www.apache.org/licenses/LICENSE-2.0
11#
12#     Unless required by applicable law or agreed to in writing, software
13#     distributed under the License is distributed on an "AS IS" BASIS,
14#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15#     See the License for the specific language governing permissions and
16#     limitations under the License.
17#
""" Trace collection (also often still referred to as constraint collection).

At the core of value propagation there is the collection of constraints that
decide whether knowledge can be propagated forward or not.

This module is about collecting these constraints and managing them.
"""
25
26import contextlib
27from collections import defaultdict
28from contextlib import contextmanager
29
30from nuitka import Tracing, Variables
31from nuitka.__past__ import iterItems  # Python3 compatibility.
32from nuitka.containers.oset import OrderedSet
33from nuitka.importing.ImportCache import getImportedModuleByNameAndPath
34from nuitka.ModuleRegistry import addUsedModule
35from nuitka.nodes.NodeMakingHelpers import getComputationResult
36from nuitka.nodes.shapes.BuiltinTypeShapes import tshape_dict
37from nuitka.nodes.shapes.StandardShapes import tshape_uninit
38from nuitka.tree.SourceReading import readSourceLine
39from nuitka.utils.FileOperations import relpath
40from nuitka.utils.InstanceCounters import (
41    counted_del,
42    counted_init,
43    isCountingInstances,
44)
45from nuitka.utils.ModuleNames import ModuleName
46from nuitka.utils.Timing import TimerReport
47
48from .ValueTraces import (
49    ValueTraceAssign,
50    ValueTraceDeleted,
51    ValueTraceEscaped,
52    ValueTraceInit,
53    ValueTraceInitStarArgs,
54    ValueTraceInitStarDict,
55    ValueTraceLoopComplete,
56    ValueTraceLoopIncomplete,
57    ValueTraceMerge,
58    ValueTraceUninit,
59    ValueTraceUnknown,
60)
61
# Callback receiving "(tags, source_ref, message)" change indications, it is
# installed for a scope via "withChangeIndicationsTo".
signalChange = None


@contextmanager
def withChangeIndicationsTo(signal_change):
    """Decide where change indications should go to.

    Installs "signal_change" as the module level "signalChange" callback for
    the duration of the "with" block. The previous value is restored when the
    block is left, including when that happens via an exception.
    """

    global signalChange  # Singleton, pylint: disable=global-statement

    old = signalChange
    signalChange = signal_change

    # Without "finally", an exception inside the "with" body would leave the
    # temporary callback installed for all later users.
    try:
        yield
    finally:
        signalChange = old
75
76
class CollectionUpdateMixin(object):
    """Mixin giving access to the variable trace storage of a collection.

    The using class must provide "variable_traces", a dictionary that maps
    "(variable, version)" keys to value traces.
    """

    # Mixins are not allowed to specify slots.
    __slots__ = ()

    def hasVariableTrace(self, variable, version):
        """Tell if a trace exists for this variable at this version."""
        trace_key = (variable, version)
        return trace_key in self.variable_traces

    def getVariableTrace(self, variable, version):
        """Return the trace of this variable at this version."""
        trace_key = (variable, version)
        return self.variable_traces[trace_key]

    def getVariableTraces(self, variable):
        """Return a list of all traces of the variable, in storage order."""
        return [
            trace
            for trace_key, trace in iterItems(self.variable_traces)
            if trace_key[0] is variable
        ]

    def getVariableTracesAll(self):
        """Return the full trace storage dictionary itself, not a copy."""
        return self.variable_traces

    def addVariableTrace(self, variable, version, trace):
        """Store a trace for a variable version that must not exist yet."""
        traces = self.variable_traces
        trace_key = (variable, version)

        assert trace_key not in traces, (trace_key, self)
        traces[trace_key] = trace

    def addVariableMergeMultipleTrace(self, variable, traces):
        """Store a merge of "traces" under a freshly allocated version.

        Returns the new version number of the variable.
        """
        merge_version = variable.allocateTargetNumber()

        self.addVariableTrace(variable, merge_version, ValueTraceMerge(traces))

        return merge_version
117
118
class CollectionStartpointMixin(CollectionUpdateMixin):
    """Mixin to use in start points of collections.

    These are modules, functions, etc. typically entry points. They own the
    storage of all variable traces, and the lists of branch collections that
    get recorded when control flow is aborted by "break", "continue",
    "return", or a raised exception.
    """

    # Mixins are not allowed to specify slots, pylint: disable=assigning-non-slot
    __slots__ = ()

    # Many things are traced

    def __init__(self):
        # Variable assignments performed in here, last issued number, only used
        # to determine the next number that should be used for a new assignment.
        self.variable_versions = {}

        # The full trace of a variable with a version for the function or module
        # this is.
        self.variable_traces = {}

        # Branch collections recorded for aborting control flow. They start out
        # as "None"; "makeAbortStackContext" installs fresh lists while catching.
        self.break_collections = None
        self.continue_collections = None
        self.return_collections = None
        self.exception_collections = None

        # Outline functions, "None" until "addOutlineFunction" adds the first.
        self.outline_functions = None

    def getLoopBreakCollections(self):
        return self.break_collections

    def onLoopBreak(self, collection=None):
        """Record a branch of "collection" (default self) leaving via "break"."""
        if collection is None:
            collection = self

        self.break_collections.append(
            TraceCollectionBranch(parent=collection, name="loop break")
        )

    def getLoopContinueCollections(self):
        return self.continue_collections

    def onLoopContinue(self, collection=None):
        """Record a branch of "collection" (default self) leaving via "continue"."""
        if collection is None:
            collection = self

        self.continue_collections.append(
            TraceCollectionBranch(parent=collection, name="loop continue")
        )

    def onFunctionReturn(self, collection=None):
        """Record a branch of "collection" (default self) leaving via "return".

        Nothing is recorded unless returns are currently being caught.
        """
        if collection is None:
            collection = self

        if self.return_collections is not None:
            self.return_collections.append(
                TraceCollectionBranch(parent=collection, name="return")
            )

    def onExceptionRaiseExit(self, raisable_exceptions, collection=None):
        """Indicate to the trace collection what exceptions may have occurred.

        Args:
            raisable_exceptions: Currently ignored, one or more exceptions that
            could occur, e.g. "BaseException".
            collection: To pass the collection that will be the parent
        Notes:
            Currently this is unused. Passing "collection" as an argument, so
            we know the original collection to attach the branch to, is maybe
            not the most clever way to do this

            We also might want to specialize functions for specific exceptions,
            there is little point in providing BaseException as an argument in
            so many places.

            The actual storage of the exceptions that can occur is currently
            missing entirely. We just use this to detect "any exception" by
            not being empty.
        """

        # TODO: We might want to track per exception, pylint: disable=unused-argument

        if collection is None:
            collection = self

        if self.exception_collections is not None:
            self.exception_collections.append(
                TraceCollectionBranch(parent=collection, name="exception")
            )

    def getFunctionReturnCollections(self):
        return self.return_collections

    def getExceptionRaiseCollections(self):
        return self.exception_collections

    def hasEmptyTraces(self, variable):
        # TODO: Combine these steps into one for performance gains.
        traces = self.getVariableTraces(variable)
        return areEmptyTraces(traces)

    def hasReadOnlyTraces(self, variable):
        # TODO: Combine these steps into one for performance gains.
        traces = self.getVariableTraces(variable)
        return areReadOnlyTraces(traces)

    def initVariableUnknown(self, variable):
        trace = ValueTraceUnknown(owner=self.owner, previous=None)

        self.addVariableTrace(variable, 0, trace)

        return trace

    def _initVariableInit(self, variable):
        trace = ValueTraceInit(self.owner)

        self.addVariableTrace(variable, 0, trace)

        return trace

    def _initVariableInitStarArgs(self, variable):
        trace = ValueTraceInitStarArgs(self.owner)

        self.addVariableTrace(variable, 0, trace)

        return trace

    def _initVariableInitStarDict(self, variable):
        trace = ValueTraceInitStarDict(self.owner)

        self.addVariableTrace(variable, 0, trace)

        return trace

    def _initVariableUninit(self, variable):
        trace = ValueTraceUninit(owner=self.owner, previous=None)

        self.addVariableTrace(variable, 0, trace)

        return trace

    def updateVariablesFromCollection(self, old_collection, source_ref):
        Variables.updateVariablesFromCollection(old_collection, self, source_ref)

    @contextlib.contextmanager
    def makeAbortStackContext(
        self, catch_breaks, catch_continues, catch_returns, catch_exceptions
    ):
        """Catch the selected kinds of aborts in fresh lists for the duration.

        For each enabled kind, the current list of branch collections is put
        aside and replaced by an empty one. The previous lists are restored
        when the context is left, also when that happens through an exception.
        """
        if catch_breaks:
            old_break_collections = self.break_collections
            self.break_collections = []
        if catch_continues:
            old_continue_collections = self.continue_collections
            self.continue_collections = []
        if catch_returns:
            old_return_collections = self.return_collections
            self.return_collections = []
        if catch_exceptions:
            old_exception_collections = self.exception_collections
            self.exception_collections = []

        # Without "finally", an exception in the "with" body would leave the
        # temporary lists installed, corrupting outer abort handling.
        try:
            yield
        finally:
            if catch_breaks:
                self.break_collections = old_break_collections
            if catch_continues:
                self.continue_collections = old_continue_collections
            if catch_returns:
                self.return_collections = old_return_collections
            if catch_exceptions:
                self.exception_collections = old_exception_collections

    def initVariable(self, variable):
        """Create and store the version 0 trace of a variable, by its kind."""
        if variable.isParameterVariable():
            # TODO: That's not happening, maybe just assert against it.
            result = self._initVariableInit(variable)
        elif variable.isLocalVariable():
            result = self._initVariableUninit(variable)
        elif variable.isModuleVariable():
            result = self.initVariableUnknown(variable)
        elif variable.isTempVariable():
            result = self._initVariableUninit(variable)
        elif variable.isLocalsDictVariable():
            if variable.getOwner().getTypeShape() is tshape_dict:
                result = self._initVariableUninit(variable)
            else:
                result = self.initVariableUnknown(variable)
        else:
            assert False, variable

        return result

    def addOutlineFunction(self, outline):
        if self.outline_functions is None:
            self.outline_functions = [outline]
        else:
            self.outline_functions.append(outline)

    def getOutlineFunctions(self):
        return self.outline_functions

    def onLocalsDictEscaped(self, locals_scope):
        """Mark variables as escaped, because a locals dict escaped."""
        if locals_scope is not None:
            for variable in locals_scope.variables.values():
                self.markActiveVariableAsEscaped(variable)

        # TODO: Limit to the scope.
        for variable in self.getActiveVariables():
            if variable.isTempVariable() or variable.isModuleVariable():
                continue

            self.markActiveVariableAsEscaped(variable)
330
331
class TraceCollectionBase(object):
    """This contains the logic for maintaining active traces.

    They are kept for "variable" and versions.
    """

    __slots__ = ("owner", "parent", "name", "value_states", "variable_actives")

    if isCountingInstances():
        __del__ = counted_del()

    @counted_init
    def __init__(self, owner, name, parent):
        self.owner = owner
        self.parent = parent
        self.name = name

        # Value state extra information per node.
        self.value_states = {}

        # Currently active values in the tracing.
        self.variable_actives = {}

    def __repr__(self):
        return "<%s for %s at 0x%x>" % (self.__class__.__name__, self.name, id(self))

    def getOwner(self):
        return self.owner

    def getVariableCurrentTrace(self, variable):
        """Get the current value trace associated to this variable

        It is also created on the fly if necessary. We create them
        lazy so to keep the tracing branches minimal where possible.
        """

        return self.getVariableTrace(
            variable=variable, version=self._getCurrentVariableVersion(variable)
        )

    def markCurrentVariableTrace(self, variable, version):
        """Make "version" the active version of "variable" in this collection."""
        self.variable_actives[variable] = version

    def _getCurrentVariableVersion(self, variable):
        """Return the active version, initializing version 0 when missing."""
        try:
            return self.variable_actives[variable]
        except KeyError:
            # Initialize variables on the fly.
            if not self.hasVariableTrace(variable, 0):
                self.initVariable(variable)

            self.markCurrentVariableTrace(variable, 0)

            return self.variable_actives[variable]

    def getActiveVariables(self):
        return self.variable_actives.keys()

    def markActiveVariableAsEscaped(self, variable):
        """Add an escaped trace for the variable, unless already escaped/unknown."""
        current = self.getVariableCurrentTrace(variable=variable)

        if not current.isEscapeOrUnknownTrace():
            version = variable.allocateTargetNumber()

            self.addVariableTrace(
                variable,
                version,
                ValueTraceEscaped(owner=self.owner, previous=current),
            )

            self.markCurrentVariableTrace(variable, version)

    def markActiveVariableAsUnknown(self, variable):
        """Add an unknown trace for the variable, unless already unknown."""
        current = self.getVariableCurrentTrace(variable=variable)

        if not current.isUnknownTrace():
            version = variable.allocateTargetNumber()

            self.addVariableTrace(
                variable,
                version,
                ValueTraceUnknown(owner=self.owner, previous=current),
            )

            self.markCurrentVariableTrace(variable, version)

    def markActiveVariableAsLoopMerge(
        self, loop_node, current, variable, shapes, incomplete
    ):
        """Add a loop trace (complete or incomplete) as new active version."""
        if incomplete:
            result = ValueTraceLoopIncomplete(loop_node, current, shapes)
        else:
            # TODO: Empty is a missing optimization somewhere, but it also happens that
            # a variable is getting released in a loop.
            # assert shapes, (variable, current)

            if not shapes:
                shapes.add(tshape_uninit)

            result = ValueTraceLoopComplete(loop_node, current, shapes)

        version = variable.allocateTargetNumber()
        self.addVariableTrace(variable, version, result)

        self.markCurrentVariableTrace(variable, version)

        return result

    def markActiveVariablesAsEscaped(self):
        for variable in self.getActiveVariables():
            if variable.isTempVariable():
                continue

            self.markActiveVariableAsEscaped(variable)

    def markActiveVariablesAsUnknown(self):
        for variable in self.getActiveVariables():
            if variable.isTempVariable():
                continue

            self.markActiveVariableAsUnknown(variable)

    @staticmethod
    def signalChange(tags, source_ref, message):
        # This is monkey patched from another module. pylint: disable=I0021,not-callable
        signalChange(tags, source_ref, message)

    def onUsedModule(self, module_name, module_relpath):
        return self.parent.onUsedModule(module_name, module_relpath)

    def onUsedFunction(self, function_body):
        owning_module = function_body.getParentModule()

        # Make sure the owning module is added to the used set. This is most
        # important for helper functions, or modules, which otherwise have
        # become unused.
        addUsedModule(owning_module)

        needs_visit = owning_module.addUsedFunction(function_body)

        if needs_visit:
            function_body.computeFunctionRaw(self)

    @staticmethod
    def mustAlias(a, b):
        if a.isExpressionVariableRef() and b.isExpressionVariableRef():
            return a.getVariable() is b.getVariable()

        return False

    @staticmethod
    def mustNotAlias(a, b):
        # TODO: not yet really implemented
        if a.isExpressionConstantRef() and b.isExpressionConstantRef():
            if a.isMutable() or b.isMutable():
                return True

        return False

    def onControlFlowEscape(self, node):
        # TODO: One day, we should trace which nodes exactly cause a variable
        # to be considered escaped, pylint: disable=unused-argument

        for variable in self.getActiveVariables():
            # TODO: Move this to the variable, and prepare and cache it better for
            # compile time savings.
            if variable.isModuleVariable():
                self.markActiveVariableAsUnknown(variable)

            elif variable.isLocalVariable():
                if (
                    str is not bytes
                    and variable.hasWritersOutsideOf(self.owner) is not False
                ):
                    self.markActiveVariableAsUnknown(variable)
                elif variable.hasAccessesOutsideOf(self.owner) is not False:
                    self.markActiveVariableAsEscaped(variable)

    def removeKnowledge(self, node):
        if node.isExpressionVariableRef():
            if node.variable.isModuleVariable():
                self.markActiveVariableAsUnknown(node.variable)
            else:
                self.markActiveVariableAsEscaped(node.variable)

    def onValueEscapeStr(self, node):
        # TODO: We can ignore these for now.
        pass

    def removeAllKnowledge(self):
        self.markActiveVariablesAsUnknown()

    def onVariableSet(self, variable, version, assign_node):
        variable_trace = ValueTraceAssign(
            owner=self.owner,
            assign_node=assign_node,
            previous=self.getVariableCurrentTrace(variable=variable),
        )

        self.addVariableTrace(variable, version, variable_trace)

        # Make references point to it.
        self.markCurrentVariableTrace(variable, version)

        return variable_trace

    def onVariableDel(self, variable, version, del_node):
        # Add a new trace, allocating a new version for the variable, and
        # remember the delete of the current
        old_trace = self.getVariableCurrentTrace(variable)

        # TODO: Annotate value content as escaped.

        variable_trace = ValueTraceDeleted(
            owner=self.owner, del_node=del_node, previous=old_trace
        )

        # Assign to not initialized again.
        self.addVariableTrace(variable, version, variable_trace)

        # Make references point to it.
        self.markCurrentVariableTrace(variable, version)

        return variable_trace

    def onLocalsUsage(self, locals_scope):
        """Return (variable, trace) pairs of active locals relevant to the scope.

        Each returned trace gets a name usage added to it.
        """
        self.onLocalsDictEscaped(locals_scope)

        result = []

        scope_locals_variables = locals_scope.getLocalsRelevantVariables()

        for variable in self.getActiveVariables():
            if variable.isLocalVariable() and variable in scope_locals_variables:
                variable_trace = self.getVariableCurrentTrace(variable)

                variable_trace.addNameUsage()
                result.append((variable, variable_trace))

        return result

    def onVariableContentEscapes(self, variable):
        if variable.isModuleVariable():
            self.markActiveVariableAsUnknown(variable)
        else:
            self.markActiveVariableAsEscaped(variable)

    def onExpression(self, expression, allow_none=False):
        """Compute an expression, replacing it in its parent if it changed.

        Returns the resulting node, or None if "expression" is None and that
        is allowed.
        """
        if expression is None and allow_none:
            return None

        assert expression.isExpression(), expression

        parent = expression.parent
        assert parent, expression

        # Now compute this expression, allowing it to replace itself with
        # something else as part of a local peep hole optimization.
        r = expression.computeExpressionRaw(trace_collection=self)
        assert type(r) is tuple, (expression, expression.getVisitableNodes(), r)

        new_node, change_tags, change_desc = r

        if change_tags is not None:
            # This is mostly for tracing and indication that a change occurred
            # and it may be interesting to look again.
            self.signalChange(change_tags, expression.getSourceReference(), change_desc)

        if new_node is not expression:
            parent.replaceChild(expression, new_node)

        return new_node

    def onStatement(self, statement):
        """Compute a statement, returning its replacement (possibly itself)."""
        try:
            assert statement.isStatement(), statement

            new_statement, change_tags, change_desc = statement.computeStatement(self)

            # print new_statement, change_tags, change_desc
            if new_statement is not statement:
                self.signalChange(
                    change_tags, statement.getSourceReference(), change_desc
                )

            return new_statement
        except Exception:
            Tracing.printError(
                "Problem with statement at %s:\n-> %s"
                % (
                    statement.source_ref.getAsString(),
                    readSourceLine(statement.source_ref),
                )
            )
            raise

    def computedStatementResult(self, statement, change_tags, change_desc):
        """Make sure the replacement statement is computed.

        Use this when a replacement statement needs to be seen by the trace
        collection and be computed, without causing any duplication, but where
        otherwise there would be loss of annotated effects.

        This may e.g. be true for nodes that need an initial run to know their
        exception result and type shape.
        """
        # Need to compute the replacement still.
        new_statement = statement.computeStatement(self)

        if new_statement[0] is not statement:
            # Signal intermediate result as well.
            self.signalChange(change_tags, statement.getSourceReference(), change_desc)

            return new_statement
        else:
            return statement, change_tags, change_desc

    def computedExpressionResult(self, expression, change_tags, change_desc):
        """Make sure the replacement expression is computed.

        Use this when a replacement expression needs to be seen by the trace
        collection and be computed, without causing any duplication, but where
        otherwise there would be loss of annotated effects.

        This may e.g. be true for nodes that need an initial run to know their
        exception result and type shape.
        """

        # Need to compute the replacement still.
        new_expression = expression.computeExpression(self)

        if new_expression[0] is not expression:
            # Signal intermediate result as well.
            self.signalChange(change_tags, expression.getSourceReference(), change_desc)

            return new_expression
        else:
            return expression, change_tags, change_desc

    def computedExpressionResultRaw(self, expression, change_tags, change_desc):
        """Make sure the replacement expression is computed.

        Use this when a replacement expression needs to be seen by the trace
        collection and be computed, without causing any duplication, but where
        otherwise there would be loss of annotated effects.

        This may e.g. be true for nodes that need an initial run to know their
        exception result and type shape.

        This is for raw, i.e. subnodes are not yet computed automatically.
        """

        # Need to compute the replacement still.
        new_expression = expression.computeExpressionRaw(self)

        if new_expression[0] is not expression:
            # Signal intermediate result as well.
            self.signalChange(change_tags, expression.getSourceReference(), change_desc)

            return new_expression
        else:
            return expression, change_tags, change_desc

    def mergeBranches(self, collection_yes, collection_no):
        """Merge two alternative branches into this trace.

        This is mostly for merging conditional branches, or other ways
        of having alternative control flow. This deals with up to two
        alternative branches to both change this collection.
        """

        # Many branches due to inlining the actual merge and preparing it
        # pylint: disable=too-many-branches

        if collection_yes is None:
            if collection_no is not None:
                # Handle one branch case, we need to merge versions backwards as
                # they may make themselves obsolete.
                collection1 = self
                collection2 = collection_no
            else:
                # Refuse to do stupid work
                return
        elif collection_no is None:
            # Handle one branch case, we need to merge versions backwards as
            # they may make themselves obsolete.
            collection1 = self
            collection2 = collection_yes
        else:
            # Handle two branch case, they may or may not do the same things.
            collection1 = collection_yes
            collection2 = collection_no

        # Maps variable to a single version when both sides agree, or to a
        # "(version1, version2)" tuple when a merge trace is needed. A missing
        # variable on one side is treated as version 0.
        variable_versions = {}

        for variable, version in iterItems(collection1.variable_actives):
            variable_versions[variable] = version

        for variable, version in iterItems(collection2.variable_actives):
            if variable not in variable_versions:
                variable_versions[variable] = 0, version
            else:
                other = variable_versions[variable]

                if other != version:
                    variable_versions[variable] = other, version
                else:
                    variable_versions[variable] = other

        for variable in variable_versions:
            if variable not in collection2.variable_actives:
                variable_versions[variable] = variable_versions[variable], 0

        self.variable_actives = {}

        for variable, versions in iterItems(variable_versions):
            if type(versions) is tuple:
                version = self.addVariableMergeMultipleTrace(
                    variable=variable,
                    traces=(
                        self.getVariableTrace(variable, versions[0]),
                        self.getVariableTrace(variable, versions[1]),
                    ),
                )
            else:
                version = versions

            self.markCurrentVariableTrace(variable, version)

    def mergeMultipleBranches(self, collections):
        """Merge any number of alternative branches into this trace.

        A single collection is simply taken over. Otherwise, variables with
        differing versions across collections (missing means version 0) get a
        merge trace as the new active version.
        """
        assert collections

        # Optimize for length 1, which is trivial merge and needs not a
        # lot of work.
        if len(collections) == 1:
            self.replaceBranch(collections[0])
            return None

        # print("Enter mergeMultipleBranches", len(collections))
        # Wrap in a tuple, so a tuple of collections is not taken as the
        # formatting arguments themselves, which would raise TypeError.
        with TimerReport(
            message="Running merge for %s took %%.2f seconds" % (collections,),
            decider=lambda: 0,
        ):
            variable_versions = defaultdict(OrderedSet)

            for collection in collections:
                for variable, version in iterItems(collection.variable_actives):
                    variable_versions[variable].add(version)

            for collection in collections:
                for variable, versions in iterItems(variable_versions):
                    if variable not in collection.variable_actives:
                        versions.add(0)

            self.variable_actives = {}

            for variable, versions in iterItems(variable_versions):
                if len(versions) == 1:
                    (version,) = versions
                else:
                    version = self.addVariableMergeMultipleTrace(
                        variable=variable,
                        traces=tuple(
                            self.getVariableTrace(variable, version)
                            for version in versions
                        ),
                    )

                self.markCurrentVariableTrace(variable, version)

            # print("Leave mergeMultipleBranches", len(collections))

    def replaceBranch(self, collection_replace):
        """Take over the active versions of a branch, which becomes unusable."""
        self.variable_actives.update(collection_replace.variable_actives)
        collection_replace.variable_actives = None

    def onLoopBreak(self, collection=None):
        if collection is None:
            collection = self

        return self.parent.onLoopBreak(collection)

    def onLoopContinue(self, collection=None):
        if collection is None:
            collection = self

        return self.parent.onLoopContinue(collection)

    def onFunctionReturn(self, collection=None):
        if collection is None:
            collection = self

        return self.parent.onFunctionReturn(collection)

    def onExceptionRaiseExit(self, raisable_exceptions, collection=None):
        if collection is None:
            collection = self

        return self.parent.onExceptionRaiseExit(raisable_exceptions, collection)

    def getLoopBreakCollections(self):
        return self.parent.getLoopBreakCollections()

    def getLoopContinueCollections(self):
        return self.parent.getLoopContinueCollections()

    def getFunctionReturnCollections(self):
        return self.parent.getFunctionReturnCollections()

    def getExceptionRaiseCollections(self):
        return self.parent.getExceptionRaiseCollections()

    def makeAbortStackContext(
        self, catch_breaks, catch_continues, catch_returns, catch_exceptions
    ):
        return self.parent.makeAbortStackContext(
            catch_breaks=catch_breaks,
            catch_continues=catch_continues,
            catch_returns=catch_returns,
            catch_exceptions=catch_exceptions,
        )

    def onLocalsDictEscaped(self, locals_scope):
        self.parent.onLocalsDictEscaped(locals_scope)

    def getCompileTimeComputationResult(self, node, computation, description):
        new_node, change_tags, message = getComputationResult(
            node=node,
            computation=computation,
            description=description,
            user_provided=False,
        )

        if change_tags == "new_raise":
            self.onExceptionRaiseExit(BaseException)

        return new_node, change_tags, message

    def getIteratorNextCount(self, iter_node):
        """Return the tracked "next" count, or None when unknown/untracked."""
        return self.value_states.get(iter_node)

    def initIteratorValue(self, iter_node):
        # TODO: More complex state information will be needed eventually.
        self.value_states[iter_node] = 0

    def onIteratorNext(self, iter_node):
        """Count one "next" call on a tracked iterator.

        Untracked iterators are ignored. A count that "resetValueStates" made
        unknown (None) stays unknown, instead of failing on "None + 1".
        """
        count = self.value_states.get(iter_node)

        if count is not None:
            self.value_states[iter_node] = count + 1

    def resetValueStates(self):
        """Make all tracked value states unknown (None)."""
        for key in self.value_states:
            self.value_states[key] = None

    def addOutlineFunction(self, outline):
        self.parent.addOutlineFunction(outline)
887
888
class TraceCollectionBranch(CollectionUpdateMixin, TraceCollectionBase):
    """Trace collection for one branch of control flow.

    It starts from a copy of the parent's active variable versions, so the
    branch can diverge from it, but shares the parent's ``variable_traces``.
    """

    __slots__ = ("variable_traces",)

    def __init__(self, name, parent):
        TraceCollectionBase.__init__(self, owner=parent.owner, name=name, parent=parent)

        # Detach from others
        self.variable_actives = dict(parent.variable_actives)

        # For quick access without going to parent.
        self.variable_traces = parent.variable_traces

    def computeBranch(self, branch):
        """Compute ``branch`` within this collection.

        A statements sequence is computed as such, and replaced in its parent
        node if the computation produced a different node. Anything else is
        handled as an expression.
        """
        if branch.isStatementsSequence():
            result = branch.computeStatementsSequence(trace_collection=self)

            if result is not branch:
                branch.parent.replaceChild(branch, result)
        else:
            self.onExpression(expression=branch)

    def initVariable(self, variable):
        """Initialize ``variable`` through the parent, then activate version 0 here.

        Returns:
            The variable trace created by the parent.
        """
        variable_trace = self.parent.initVariable(variable)

        self.variable_actives[variable] = 0

        return variable_trace

    def dumpTraces(self):
        """Debug output: delegates to the parent between separators."""
        Tracing.printSeparator()
        self.parent.dumpTraces()
        Tracing.printSeparator()

    def dumpActiveTraces(self):
        """Debug output of the current trace of every active variable."""
        Tracing.printSeparator()
        Tracing.printLine("Active are:")

        # Only the variables are needed. Sort them by name for stable output,
        # because Variable objects may not support ordering comparisons on
        # Python3, and "dict.iteritems" does not exist there either.
        for variable in sorted(
            self.variable_actives, key=lambda variable: variable.getName()
        ):
            self.getVariableCurrentTrace(variable).dump()

        Tracing.printSeparator()
929
930
class TraceCollectionFunction(CollectionStartpointMixin, TraceCollectionBase):
    """Starting point trace collection for a function-like body.

    Accepts plain function, generator, coroutine and asyncgen bodies. For plain
    function bodies, the parameter variables get init traces. Closure variables
    get unknown traces. Version 0 is made active for all of these.
    """

    __slots__ = (
        "variable_versions",
        "variable_traces",
        "break_collections",
        "continue_collections",
        "return_collections",
        "exception_collections",
        "outline_functions",
    )

    def __init__(self, parent, function_body):
        assert (
            function_body.isExpressionFunctionBody()
            or function_body.isExpressionGeneratorObjectBody()
            or function_body.isExpressionCoroutineObjectBody()
            or function_body.isExpressionAsyncgenObjectBody()
        ), function_body

        CollectionStartpointMixin.__init__(self)

        collection_name = "collection_" + function_body.getCodeName()
        TraceCollectionBase.__init__(
            self,
            owner=function_body,
            name=collection_name,
            parent=parent,
        )

        if function_body.isExpressionFunctionBody():
            self._initFunctionParameterTraces(function_body.getParameters())

        for closure_variable in function_body.getClosureVariables():
            self.initVariableUnknown(closure_variable)
            self.variable_actives[closure_variable] = 0

        # TODO: Have special function type for exec functions stuff.
        self._initFunctionLocalsScopeTraces(function_body)

    def _initFunctionParameterTraces(self, parameters):
        """Create init traces for all parameter variables and activate version 0."""
        for parameter_variable in parameters.getTopLevelVariables():
            self._initVariableInit(parameter_variable)
            self.variable_actives[parameter_variable] = 0

        star_list_variable = parameters.getListStarArgVariable()
        if star_list_variable is not None:
            self._initVariableInitStarArgs(star_list_variable)
            self.variable_actives[star_list_variable] = 0

        star_dict_variable = parameters.getDictStarArgVariable()
        if star_dict_variable is not None:
            self._initVariableInitStarDict(star_dict_variable)
            self.variable_actives[star_dict_variable] = 0

    def _initFunctionLocalsScopeTraces(self, function_body):
        """Create uninit traces for locals dict variables of the body, if any."""
        locals_scope = function_body.getLocalsScope()

        if locals_scope is None:
            return

        if locals_scope.isMarkedForPropagation():
            # Scopes marked for propagation are detached from the body, and
            # get no traces here.
            function_body.locals_scope = None
        else:
            for locals_dict_variable in locals_scope.variables.values():
                self._initVariableUninit(locals_dict_variable)
989
990
class TraceCollectionPureFunction(TraceCollectionFunction):
    """Pure functions don't feed their parent.

    Such a collection is created without a parent collection. The function
    bodies reported as used are kept locally, in insertion order.
    """

    __slots__ = ("used_functions",)

    def __init__(self, function_body):
        TraceCollectionFunction.__init__(self, parent=None, function_body=function_body)

        # There is no parent to report to, so remember used functions here.
        self.used_functions = OrderedSet()

    def getUsedFunctions(self):
        """Return the ordered set of function bodies reported as used."""
        used = self.used_functions
        return used

    def onUsedFunction(self, function_body):
        """Record ``function_body`` as used by this pure function."""
        used = self.used_functions
        used.add(function_body)
1006
1007
class TraceCollectionModule(CollectionStartpointMixin, TraceCollectionBase):
    """Starting point trace collection for a compiled Python module."""

    __slots__ = (
        "variable_versions",
        "variable_traces",
        "break_collections",
        "continue_collections",
        "return_collections",
        "exception_collections",
        "outline_functions",
    )

    def __init__(self, module):
        assert module.isCompiledPythonModule(), module

        CollectionStartpointMixin.__init__(self)

        collection_name = "module:" + module.getFullName()
        TraceCollectionBase.__init__(
            self,
            owner=module,
            name=collection_name,
            parent=None,
        )

    def onUsedModule(self, module_name, module_relpath):
        """Record that the owning module uses ``module_name``.

        The path is normalized with ``relpath``. The pair is noted on the
        owning module. The module found by name and path is then registered
        through ``addUsedModule``.
        """
        assert type(module_name) is ModuleName, module_name

        # TODO: Make users provide this through a method that has already
        # done this.
        normalized_relpath = relpath(module_relpath)

        self.owner.addUsedModule((module_name, normalized_relpath))

        used_module = getImportedModuleByNameAndPath(module_name, normalized_relpath)
        addUsedModule(used_module)
1039
1040
1041# TODO: This should not exist, but be part of decision at the time these are collected.
def areEmptyTraces(variable_traces):
    """Do these traces contain any writes or accesses.

    Returns:
        False as soon as a trace is an assign, init, deleted or loop trace, or
        is an uninit, unknown, escape or merge trace with a non-zero usage
        count. True otherwise, including for no traces at all.
    """
    # The predicates are asked in a fixed order. A trace may answer True to
    # more than one, so the first match decides.
    for variable_trace in variable_traces:
        # Any write counts. A "del" statement also needs to prevent the
        # variable from being removed.
        if (
            variable_trace.isAssignTrace()
            or variable_trace.isInitTrace()
            or variable_trace.isDeletedTrace()
        ):
            return False

        if (
            variable_trace.isUninitTrace()
            or variable_trace.isUnknownTrace()
            or variable_trace.isEscapeTrace()
            or variable_trace.isMergeTrace()
        ):
            # Checking definite is enough, the merges, we shall see them as
            # well.
            if variable_trace.getUsageCount():
                return False
        elif variable_trace.isLoopTrace():
            return False
        else:
            assert False, variable_trace

    return True
1083
1084
def areReadOnlyTraces(variable_traces):
    """Do these traces contain any writes.

    Returns:
        False as soon as an assign, deleted or unknown trace is found (checked
        before escape, merge and loop traces). True otherwise, including for
        no traces at all.
    """
    # The predicates are asked in a fixed order. A trace may answer True to
    # more than one, so the first match decides.
    for variable_trace in variable_traces:
        if variable_trace.isAssignTrace():
            return False

        if variable_trace.isInitTrace():
            continue

        # A "del" statement can do this, and needs to prevent variable
        # from being not released.
        if variable_trace.isDeletedTrace():
            return False

        if variable_trace.isUninitTrace():
            continue

        if variable_trace.isUnknownTrace():
            return False

        if not (
            variable_trace.isEscapeTrace()
            or variable_trace.isMergeTrace()
            or variable_trace.isLoopTrace()
        ):
            assert False, variable_trace

    return True
1113