1#     Copyright 2021, Kay Hayen, mailto:kay.hayen@gmail.com
2#
3#     Part of "Nuitka", an optimizing Python compiler that is compatible and
4#     integrates with CPython, but also works on its own.
5#
6#     Licensed under the Apache License, Version 2.0 (the "License");
7#     you may not use this file except in compliance with the License.
8#     You may obtain a copy of the License at
9#
10#        http://www.apache.org/licenses/LICENSE-2.0
11#
12#     Unless required by applicable law or agreed to in writing, software
13#     distributed under the License is distributed on an "AS IS" BASIS,
14#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15#     See the License for the specific language governing permissions and
16#     limitations under the License.
17#
18""" Nodes for comparisons.
19
20"""
21
22from nuitka import PythonOperators
23from nuitka.Errors import NuitkaAssumptionError
24from nuitka.PythonVersions import python_version
25
26from .ExpressionBases import ExpressionChildrenHavingBase
27from .NodeMakingHelpers import (
28    makeConstantReplacementNode,
29    makeRaiseExceptionReplacementExpressionFromInstance,
30    wrapExpressionWithSideEffects,
31)
32from .shapes.BuiltinTypeShapes import tshape_bool, tshape_exception_class
33
34
class ExpressionComparisonBase(ExpressionChildrenHavingBase):
    """Common base of all comparison nodes, with the children "left" and "right".

    Concrete subclasses define the class attribute "comparator", which is the
    key used with "PythonOperators.all_comparison_functions" for simulation
    and with "PythonOperators.comparison_inversions" for negation.
    """

    named_children = ("left", "right")

    def __init__(self, left, right, source_ref):
        children = {"left": left, "right": right}

        ExpressionChildrenHavingBase.__init__(
            self, values=children, source_ref=source_ref
        )

    @staticmethod
    def copyTraceStateFrom(source):
        # Nothing is traced at this level, so there is nothing to copy.
        pass

    def getOperands(self):
        left = self.subnode_left
        right = self.subnode_right

        return left, right

    def getComparator(self):
        return self.comparator

    def getDetails(self):
        return dict(comparator=self.comparator)

    @staticmethod
    def isExpressionComparison():
        return True

    def getSimulator(self):
        """Return the Python function that performs this comparison."""
        comparison_functions = PythonOperators.all_comparison_functions

        return comparison_functions[self.comparator]

    def _computeCompileTimeConstantComparision(self, trace_collection):
        """Let the trace collection evaluate the comparison of two constants."""
        constant_left = self.subnode_left.getCompileTimeConstant()
        constant_right = self.subnode_right.getCompileTimeConstant()

        # The simulator is only looked up once the computation is executed.
        def simulate():
            return self.getSimulator()(constant_left, constant_right)

        return trace_collection.getCompileTimeComputationResult(
            node=self,
            computation=simulate,
            description="Comparison of constant arguments.",
        )

    def computeExpression(self, trace_collection):
        both_constant = (
            self.subnode_left.isCompileTimeConstant()
            and self.subnode_right.isCompileTimeConstant()
        )

        if both_constant:
            return self._computeCompileTimeConstantComparision(trace_collection)

        # TODO: The operand values escaped and could change their contents.
        # Comparisons don't do much, but eventually add this:
        # trace_collection.onValueEscapeRichComparison(left, right, self.comparator)

        # Arbitrary code may run as part of the comparison, and it may raise.
        trace_collection.onControlFlowEscape(self)
        trace_collection.onExceptionRaiseExit(BaseException)

        return self, None, None

    def makeInverseComparision(self):
        """Create the comparison node that gives the negated result.

        This does not depend on trace state, so that it is usable from the
        tree building phase as well.
        """
        inverse_comparator = PythonOperators.comparison_inversions[self.comparator]

        return makeComparisonExpression(
            left=self.subnode_left,
            right=self.subnode_right,
            comparator=inverse_comparator,
            source_ref=self.source_ref,
        )

    def computeExpressionOperationNot(self, not_node, trace_collection):
        # Swapping "not (a < b)" for "a >= b" is only equivalent when the
        # comparison is known to produce a bool, as "not" always does.
        if self.getTypeShape() is not tshape_bool:
            return not_node, None, None

        inverse = self.makeInverseComparision()
        inverse.copyTraceStateFrom(self)

        return (
            inverse,
            "new_expression",
            "Replaced negated comparison '%s' with inverse comparison '%s'."
            % (self.comparator, inverse.comparator),
        )
114
115
class ExpressionComparisonRichBase(ExpressionComparisonBase):
    """Base of the rich comparisons "Lt", "LtE", "Gt", "GtE", "Eq" and "NotEq".

    Subclasses provide "comparator" and a static "getComparisonShape" that maps
    the operand type shapes to a "(type_shape, escape_desc)" pair. Both values
    are stored on the node by "computeExpression" and are None before that.
    """

    __slots__ = "type_shape", "escape_desc"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

        # Unknown until the first "computeExpression" pass.
        self.type_shape = None
        self.escape_desc = None

    def getTypeShape(self):
        return self.type_shape

    @staticmethod
    def getDetails():
        return {}

    def copyTraceStateFrom(self, source):
        """Take over the result shape and escape description of "source"."""
        self.type_shape = source.type_shape
        self.escape_desc = source.escape_desc

    def canCreateUnsupportedException(self):
        """Tell if both operand shapes offer a "typical_value" to simulate with."""
        return hasattr(self.subnode_left.getTypeShape(), "typical_value") and hasattr(
            self.subnode_right.getTypeShape(), "typical_value"
        )

    def createUnsupportedException(self):
        """Simulate the comparison with typical operand values to get its error.

        Returns:
            The TypeError instance raised by the simulated comparison.

        Raises:
            NuitkaAssumptionError: If the simulation unexpectedly succeeded.
            Any non-TypeError raised by the simulation is propagated.
        """
        left = self.subnode_left.getTypeShape().typical_value
        right = self.subnode_right.getTypeShape().typical_value

        try:
            self.getSimulator()(left, right)
        except TypeError as e:
            return e
        else:
            # Comparison nodes have "comparator" and "getSimulator()", there are
            # no "operator" or "simulator" attributes to report here.
            raise NuitkaAssumptionError(
                "Unexpected no-exception doing comparison simulation",
                self.comparator,
                self.getSimulator(),
                self.subnode_left.getTypeShape(),
                self.subnode_right.getTypeShape(),
                repr(left),
                repr(right),
            )

    def computeExpression(self, trace_collection):
        left = self.subnode_left
        right = self.subnode_right

        if left.isCompileTimeConstant() and right.isCompileTimeConstant():
            return self._computeCompileTimeConstantComparision(trace_collection)

        left_shape = left.getTypeShape()
        right_shape = right.getTypeShape()

        # Remembered on the node, e.g. for "getTypeShape" and the
        # "mayRaiseException*" queries.
        self.type_shape, self.escape_desc = self.getComparisonShape(
            left_shape, right_shape
        )

        exception_raise_exit = self.escape_desc.getExceptionExit()
        if exception_raise_exit is not None:
            trace_collection.onExceptionRaiseExit(exception_raise_exit)

            # Shapes known to be incompatible, replace with the raise if the
            # exact exception can be obtained by simulation.
            if (
                self.escape_desc.isUnsupported()
                and self.canCreateUnsupportedException()
            ):
                result = wrapExpressionWithSideEffects(
                    new_node=makeRaiseExceptionReplacementExpressionFromInstance(
                        expression=self, exception=self.createUnsupportedException()
                    ),
                    old_node=self,
                    side_effects=(self.subnode_left, self.subnode_right),
                )

                return (
                    result,
                    "new_raise",
                    """Replaced comparator '%s' with %s %s arguments that cannot work."""
                    % (
                        self.comparator,
                        self.subnode_left.getTypeShape(),
                        self.subnode_right.getTypeShape(),
                    ),
                )

            # The value of these nodes escaped and could change its contents.

            # TODO: Comparisons don't do much, but add this.
            # if self.escape_desc.isValueEscaping():
            #    trace_collection.onValueEscapeRichComparison(left, right, self.comparator)

        if self.escape_desc.isControlFlowEscape():
            # Any code could be run, note that.
            trace_collection.onControlFlowEscape(self)

        return self, None, None

    def mayRaiseException(self, exception_type):
        # TODO: Match more precisely
        return (
            self.escape_desc is None
            or self.escape_desc.getExceptionExit() is not None
            or self.subnode_left.mayRaiseException(exception_type)
            or self.subnode_right.mayRaiseException(exception_type)
        )

    def mayRaiseExceptionBool(self, exception_type):
        # Before "computeExpression" ran, the result shape is unknown, so be
        # conservative, like "mayRaiseException" is, instead of failing on None.
        type_shape = self.type_shape

        return type_shape is None or type_shape.hasShapeSlotBool() is not True

    def mayRaiseExceptionComparison(self):
        return (
            self.escape_desc is None or self.escape_desc.getExceptionExit() is not None
        )
231
232
class ExpressionComparisonLt(ExpressionComparisonRichBase):
    """Rich comparison "left < right"."""

    comparator = "Lt"
    kind = "EXPRESSION_COMPARISON_LT"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(self, left, right, source_ref)

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        # The left operand shape decides what "<" with the right shape gives.
        return left_shape.getComparisonLtShape(right_shape)
246
247
class ExpressionComparisonLte(ExpressionComparisonRichBase):
    """Rich comparison "left <= right"."""

    comparator = "LtE"
    kind = "EXPRESSION_COMPARISON_LTE"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(self, left, right, source_ref)

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        # The left operand shape decides what "<=" with the right shape gives.
        return left_shape.getComparisonLteShape(right_shape)
261
262
class ExpressionComparisonGt(ExpressionComparisonRichBase):
    """Rich comparison "left > right"."""

    comparator = "Gt"
    kind = "EXPRESSION_COMPARISON_GT"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(self, left, right, source_ref)

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        # The left operand shape decides what ">" with the right shape gives.
        return left_shape.getComparisonGtShape(right_shape)
276
277
class ExpressionComparisonGte(ExpressionComparisonRichBase):
    """Rich comparison "left >= right"."""

    comparator = "GtE"
    kind = "EXPRESSION_COMPARISON_GTE"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(self, left, right, source_ref)

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        # The left operand shape decides what ">=" with the right shape gives.
        return left_shape.getComparisonGteShape(right_shape)
291
292
class ExpressionComparisonEq(ExpressionComparisonRichBase):
    """Rich comparison "left == right"."""

    comparator = "Eq"
    kind = "EXPRESSION_COMPARISON_EQ"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(self, left, right, source_ref)

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        # The left operand shape decides what "==" with the right shape gives.
        return left_shape.getComparisonEqShape(right_shape)
306
307
class ExpressionComparisonNeq(ExpressionComparisonRichBase):
    """Rich comparison "left != right"."""

    comparator = "NotEq"
    kind = "EXPRESSION_COMPARISON_NEQ"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(self, left, right, source_ref)

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        # The left operand shape decides what "!=" with the right shape gives.
        return left_shape.getComparisonNeqShape(right_shape)
321
322
class ExpressionComparisonIsIsNotBase(ExpressionComparisonBase):
    """Shared implementation of the identity comparisons "Is" and "IsNot"."""

    __slots__ = ("match_value",)

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

        assert self.comparator in ("Is", "IsNot")

        # Result when both operands are known to be the same object, i.e.
        # True for "Is" and False for "IsNot".
        # TODO: Forward propagate this one.
        self.match_value = self.comparator == "Is"

    @staticmethod
    def getDetails():
        return {}

    @staticmethod
    def getTypeShape():
        return tshape_bool

    def mayRaiseException(self, exception_type):
        # Only the operands can raise, the identity test itself does not.
        left, right = self.getOperands()

        return left.mayRaiseException(exception_type) or right.mayRaiseException(
            exception_type
        )

    def mayRaiseExceptionBool(self, exception_type):
        # The result is a bool, taking its truth value cannot fail.
        return False

    def _makeConstantResult(self, left, right, constant):
        """Create a constant node for "constant", preserving operand side effects."""
        replacement = makeConstantReplacementNode(
            constant=constant, node=self, user_provided=False
        )

        if not (left.mayHaveSideEffects() or right.mayHaveSideEffects()):
            return replacement

        return wrapExpressionWithSideEffects(
            side_effects=self.extractSideEffects(),
            old_node=self,
            new_node=replacement,
        )

    def computeExpression(self, trace_collection):
        left, right = self.getOperands()

        if trace_collection.mustAlias(left, right):
            return (
                self._makeConstantResult(left, right, self.match_value),
                "new_constant",
                """\
Determined values to alias and therefore result of %s comparison."""
                % (self.comparator),
            )

        if trace_collection.mustNotAlias(left, right):
            return (
                self._makeConstantResult(left, right, not self.match_value),
                "new_constant",
                """\
Determined values to not alias and therefore result of '%s' comparison."""
                % (self.comparator),
            )

        # Aliasing is undecided, fall back to the generic comparison handling.
        return ExpressionComparisonBase.computeExpression(
            self, trace_collection=trace_collection
        )

    def extractSideEffects(self):
        left_effects = self.subnode_left.extractSideEffects()
        right_effects = self.subnode_right.extractSideEffects()

        return left_effects + right_effects

    def computeExpressionDrop(self, statement, trace_collection):
        from .NodeMakingHelpers import makeStatementOnlyNodesFromExpressions

        # With the result unused, only the evaluation of the operands remains.
        statements = makeStatementOnlyNodesFromExpressions(
            expressions=self.getOperands()
        )

        # This node gets replaced, release its parent link.
        del self.parent

        return (
            statements,
            "new_statements",
            "Removed %s comparison for unused result." % self.comparator,
        )
418
419
class ExpressionComparisonIs(ExpressionComparisonIsIsNotBase):
    """Identity comparison "left is right"."""

    comparator = "Is"
    kind = "EXPRESSION_COMPARISON_IS"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonIsIsNotBase.__init__(self, left, right, source_ref)
429
430
class ExpressionComparisonIsNot(ExpressionComparisonIsIsNotBase):
    """Identity comparison "left is not right"."""

    comparator = "IsNot"
    kind = "EXPRESSION_COMPARISON_IS_NOT"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonIsIsNotBase.__init__(self, left, right, source_ref)
440
441
class ExpressionComparisonExceptionMatchBase(ExpressionComparisonBase):
    """Shared base of the "exception_match" and "exception_mismatch" comparisons."""

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(self, left, right, source_ref)

    @staticmethod
    def getDetails():
        return {}

    @staticmethod
    def getTypeShape():
        return tshape_bool

    def getSimulator(self):
        # TODO: Doesn't happen yet, but will once we trace exceptions.
        assert False

        return PythonOperators.all_comparison_functions[self.comparator]

    def mayRaiseException(self, exception_type):
        # TODO: Match errors that exception comparisons might raise more accurately.
        operands_may_raise = self.subnode_left.mayRaiseException(
            exception_type
        ) or self.subnode_right.mayRaiseException(exception_type)

        return operands_may_raise or self.mayRaiseExceptionComparison()

    def mayRaiseExceptionComparison(self):
        # Assumed never to raise on Python2. On Python3 it is only considered
        # safe when the right side is known to be an exception class.
        # TODO: Add shape for exceptions.
        return (
            python_version >= 0x300
            and self.subnode_right.getTypeShape() is not tshape_exception_class
        )

    @staticmethod
    def mayRaiseExceptionBool(exception_type):
        # The result is a bool, taking its truth value cannot fail.
        return False
485
486
class ExpressionComparisonExceptionMatch(ExpressionComparisonExceptionMatchBase):
    """Exception comparison with the "exception_match" comparator."""

    comparator = "exception_match"
    kind = "EXPRESSION_COMPARISON_EXCEPTION_MATCH"
491
492
class ExpressionComparisonExceptionMismatch(ExpressionComparisonExceptionMatchBase):
    """Exception comparison with the "exception_mismatch" comparator."""

    comparator = "exception_mismatch"
    kind = "EXPRESSION_COMPARISON_EXCEPTION_MISMATCH"
497
498
class ExpressionComparisonInNotInBase(ExpressionComparisonBase):
    """Shared implementation of the membership tests "In" and "NotIn"."""

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(self, left, right, source_ref)

        assert self.comparator in ("In", "NotIn")

    @staticmethod
    def getDetails():
        return {}

    @staticmethod
    def getTypeShape():
        return tshape_bool

    def mayRaiseException(self, exception_type):
        element = self.subnode_left
        container = self.subnode_right

        # Evaluating either operand may already raise.
        if element.mayRaiseException(exception_type) or container.mayRaiseException(
            exception_type
        ):
            return True

        # Otherwise the right side judges whether the membership test raises.
        return container.mayRaiseExceptionIn(exception_type, element)

    @staticmethod
    def mayRaiseExceptionBool(exception_type):
        # The result is a bool, taking its truth value cannot fail.
        return False

    def computeExpression(self, trace_collection):
        # Membership is decided by the right operand, so delegate to it.
        container = self.subnode_right

        return container.computeExpressionComparisonIn(
            in_node=self,
            value_node=self.subnode_left,
            trace_collection=trace_collection,
        )
538
539
class ExpressionComparisonIn(ExpressionComparisonInNotInBase):
    """Membership test "left in right"."""

    comparator = "In"
    kind = "EXPRESSION_COMPARISON_IN"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonInNotInBase.__init__(self, left, right, source_ref)
549
550
class ExpressionComparisonNotIn(ExpressionComparisonInNotInBase):
    """Membership test "left not in right"."""

    comparator = "NotIn"
    kind = "EXPRESSION_COMPARISON_NOT_IN"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonInNotInBase.__init__(self, left, right, source_ref)
560
561
# Comparator name to node class, keyed by each class's own "comparator"
# attribute so the two cannot disagree.
_comparator_to_nodeclass = dict(
    (node_class.comparator, node_class)
    for node_class in (
        ExpressionComparisonIs,
        ExpressionComparisonIsNot,
        ExpressionComparisonIn,
        ExpressionComparisonNotIn,
        ExpressionComparisonLt,
        ExpressionComparisonLte,
        ExpressionComparisonGt,
        ExpressionComparisonGte,
        ExpressionComparisonEq,
        ExpressionComparisonNeq,
        ExpressionComparisonExceptionMatch,
        ExpressionComparisonExceptionMismatch,
    )
)
576
577
def makeComparisonExpression(left, right, comparator, source_ref):
    """Create the comparison node implementing "comparator" for the operands.

    Unknown comparator names raise KeyError from the class lookup.
    """
    node_class = _comparator_to_nodeclass[comparator]

    return node_class(left=left, right=right, source_ref=source_ref)
582