1#     Copyright 2021, Kay Hayen, mailto:kay.hayen@gmail.com
2#
3#     Part of "Nuitka", an optimizing Python compiler that is compatible and
4#     integrates with CPython, but also works on its own.
5#
6#     Licensed under the Apache License, Version 2.0 (the "License");
7#     you may not use this file except in compliance with the License.
8#     You may obtain a copy of the License at
9#
10#        http://www.apache.org/licenses/LICENSE-2.0
11#
12#     Unless required by applicable law or agreed to in writing, software
13#     distributed under the License is distributed on an "AS IS" BASIS,
14#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15#     See the License for the specific language governing permissions and
16#     limitations under the License.
17#
18""" Nodes for unary and binary operations.
19
No short-circuit involved, boolean 'not' is a unary operation like '-' is,
no real difference.
22"""
23
24import copy
25import math
26from abc import abstractmethod
27
28from nuitka import PythonOperators
29from nuitka.Errors import NuitkaAssumptionError
30from nuitka.PythonVersions import python_version
31
32from .ExpressionBases import ExpressionChildrenHavingBase
33from .NodeMakingHelpers import (
34    makeRaiseExceptionReplacementExpressionFromInstance,
35    wrapExpressionWithSideEffects,
36)
37from .shapes.BuiltinTypeShapes import tshape_bool, tshape_int_or_long
38from .shapes.StandardShapes import (
39    ShapeLargeConstantValue,
40    ShapeLargeConstantValuePredictable,
41    vshape_unknown,
42)
43
44
class ExpressionPropertiesFromTypeShapeMixin(object):
    """Derive default expression properties from "self.type_shape".

    Classes using this mixin have to provide the "type_shape" attribute, the
    queries here are answered by delegating to it.
    """

    # Mixins are required to declare slots.
    __slots__ = ()

    def isKnownToBeHashable(self):
        """Tell if the value is hashable, as answered by the type shape."""
        shape = self.type_shape

        return shape.hasShapeSlotHash()
53
54
class ExpressionOperationBinaryBase(
    ExpressionPropertiesFromTypeShapeMixin, ExpressionChildrenHavingBase
):
    """Base class for all binary operation expressions.

    Concrete classes are expected to provide the class attributes "operator"
    (name of the operator) and "simulator" (callable performing the operation
    on two Python values), and to implement "_getOperationShape".
    """

    __slots__ = ("type_shape", "escape_desc", "inplace_suspect", "shape")

    named_children = ("left", "right")
    nice_children = tuple(child_name + " operand" for child_name in named_children)

    def __init__(self, left, right, source_ref):
        ExpressionChildrenHavingBase.__init__(
            self, values={"left": left, "right": right}, source_ref=source_ref
        )

        # Result type shape and escape description, both are only known once
        # "_getOperationShape" was asked, see "getTypeShape".
        self.type_shape = None
        self.escape_desc = None

        self.inplace_suspect = False

        # Value shape, overloads of "_isTooLarge" may replace it.
        self.shape = vshape_unknown

    @staticmethod
    def isExpressionOperationBinary():
        return True

    def getOperator(self):
        """Give the operator name, a class attribute of the concrete node."""
        return self.operator

    def markAsInplaceSuspect(self):
        self.inplace_suspect = True

    def unmarkAsInplaceSuspect(self):
        self.inplace_suspect = False

    def isInplaceSuspect(self):
        return self.inplace_suspect

    def getOperands(self):
        """Give the left and right operand nodes as a tuple."""
        return (self.subnode_left, self.subnode_right)

    def mayRaiseExceptionOperation(self):
        """Tell if the operation itself, not its operands, may raise.

        As long as no escape description was computed yet, the conservative
        answer is given, consistent with "mayRaiseException".
        """
        return (
            self.escape_desc is None
            or self.escape_desc.getExceptionExit() is not None
        )

    def mayRaiseException(self, exception_type):
        """Tell if the operation or one of its operands may raise."""
        # TODO: Match getExceptionExit() more precisely against exception type given
        return (
            self.escape_desc is None
            or self.escape_desc.getExceptionExit() is not None
            or self.subnode_left.mayRaiseException(exception_type)
            or self.subnode_right.mayRaiseException(exception_type)
        )

    def getTypeShape(self):
        """Give the type shape of the result, computing it on first use."""
        # Question might be asked early on, later this is cached from last computation.
        if self.type_shape is None:
            self.type_shape, self.escape_desc = self._getOperationShape(
                self.subnode_left.getTypeShape(), self.subnode_right.getTypeShape()
            )

        return self.type_shape

    @abstractmethod
    def _getOperationShape(self, left_shape, right_shape):
        """Give the pair of result type shape and escape description."""

    @staticmethod
    def canCreateUnsupportedException(left_shape, right_shape):
        """Tell if both shapes offer a "typical_value" to simulate with."""
        return hasattr(left_shape, "typical_value") and hasattr(
            right_shape, "typical_value"
        )

    def createUnsupportedException(self, left_shape, right_shape):
        """Produce the "TypeError" the operation gives for these shapes.

        The operation is simulated with the typical values of both shapes,
        and the "TypeError" instance raised by that is returned. Anything
        else, another exception type or no exception at all, contradicts
        the assumption and raises "NuitkaAssumptionError".
        """
        left = left_shape.typical_value
        right = right_shape.typical_value

        try:
            self.simulator(left, right)
        except TypeError as e:
            return e
        except Exception as e:
            raise NuitkaAssumptionError(
                "Unexpected exception type doing operation simulation",
                self.operator,
                self.simulator,
                left_shape,
                right_shape,
                repr(left),
                repr(right),
                e,
                "!=",
            )
        else:
            raise NuitkaAssumptionError(
                "Unexpected no-exception doing operation simulation",
                self.operator,
                self.simulator,
                left_shape,
                right_shape,
                repr(left),
                repr(right),
            )

    @staticmethod
    def _isTooLarge():
        """Tell if the result is too large to compute, never by default."""
        return False

    def _simulateOperation(self, trace_collection):
        """Compute the operation on the constant operand values.

        The computation itself is handed to the trace collection, which
        gives the result of the compile time computation.
        """
        left_value = self.subnode_left.getCompileTimeConstant()
        right_value = self.subnode_right.getCompileTimeConstant()

        # Avoid mutating owned by nodes values and potentially shared.
        if self.subnode_left.isMutable():
            left_value = copy.copy(left_value)

        return trace_collection.getCompileTimeComputationResult(
            node=self,
            computation=lambda: self.simulator(left_value, right_value),
            description="Operator '%s' with constant arguments." % self.operator,
        )

    def computeExpression(self, trace_collection):
        """Update shapes and replace the node where the result is known.

        Returns a triple of node, change tag and change description, with
        the latter two being "None" when the node is kept as it is.
        """
        # Nothing to do anymore for large constants.
        if self.shape is not None and self.shape.isConstant():
            return self, None, None

        left = self.subnode_left
        left_shape = left.getTypeShape()
        right = self.subnode_right
        right_shape = right.getTypeShape()

        self.type_shape, self.escape_desc = self._getOperationShape(
            left_shape, right_shape
        )

        # Constant operands allow to compute the result right away, unless
        # it is considered too large.
        if left.isCompileTimeConstant() and right.isCompileTimeConstant():
            if not self._isTooLarge():
                return self._simulateOperation(trace_collection)

        exception_raise_exit = self.escape_desc.getExceptionExit()
        if exception_raise_exit is not None:
            trace_collection.onExceptionRaiseExit(exception_raise_exit)

            # Operations known to be unsupported become a raise of the
            # exception, with the operands kept as side effects.
            if self.escape_desc.isUnsupported() and self.canCreateUnsupportedException(
                left_shape, right_shape
            ):
                result = wrapExpressionWithSideEffects(
                    new_node=makeRaiseExceptionReplacementExpressionFromInstance(
                        expression=self,
                        exception=self.createUnsupportedException(
                            left_shape,
                            right_shape,
                        ),
                    ),
                    old_node=self,
                    side_effects=(left, right),
                )

                return (
                    result,
                    "new_raise",
                    "Replaced operator '%s' with %s %s arguments that cannot work."
                    % (self.operator, left_shape, right_shape),
                )

        if self.escape_desc.isValueEscaping():
            # The value of these nodes escaped and could change its contents.
            trace_collection.removeKnowledge(left)
            trace_collection.removeKnowledge(right)

        if self.escape_desc.isControlFlowEscape():
            # Any code could be run, note that.
            trace_collection.onControlFlowEscape(self)

        return self, None, None

    def canPredictIterationValues(self):
        # TODO: Actually we could very well, esp. for sequence repeats.
        # pylint: disable=no-self-use
        return False
235
236
class ExpressionOperationAddMixin(object):
    """Mixin providing value shape and size limit checks for additions."""

    # Mixins are not allowed to specify slots, pylint: disable=assigning-non-slot
    __slots__ = ()

    def getValueShape(self):
        """Give the value shape currently stored for the node."""
        return self.shape

    def _isTooLarge(self):
        """Tell if adding two iterable operands exceeds 256 elements.

        For two operands known to be iterable, the value shape of the node
        is updated with the summed up size as a side effect, no matter what
        the answer is.
        """
        left_operand = self.subnode_left
        if not left_operand.isKnownToBeIterable(None):
            return False

        right_operand = self.subnode_right
        if not right_operand.isKnownToBeIterable(None):
            return False

        total_length = (
            left_operand.getIterationLength() + right_operand.getIterationLength()
        )

        # TODO: Actually could make a predictor, but we don't use it yet.
        self.shape = ShapeLargeConstantValuePredictable(
            size=total_length,
            predictor=None,  # predictValuesFromRightAndLeftValue,
            shape=left_operand.getTypeShape(),
        )

        return total_length > 256
263
264
class ExpressionOperationBinaryAdd(
    ExpressionOperationAddMixin, ExpressionOperationBinaryBase
):
    """Binary operation node for the "Add" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_ADD"

    # Operator name, also the key to look up the compile time simulator.
    operator = "Add"
    simulator = PythonOperators.binary_operator_functions[operator]

    def __init__(self, left, right, source_ref):
        ExpressionOperationBinaryBase.__init__(self, left, right, source_ref)

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result of adding the right shape."""
        shape_query = left_shape.getOperationBinaryAddShape

        return shape_query(right_shape)
281
282
class ExpressionOperationBinarySub(ExpressionOperationBinaryBase):
    """Binary operation node for the "Sub" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_SUB"

    # Operator name, also the key to look up the compile time simulator.
    operator = "Sub"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinarySubShape

        return shape_query(right_shape)
292
293
class ExpressionOperationMultMixin(object):
    """Mixin providing value shape and size limit checks for multiplications."""

    # Mixins are not allowed to specify slots, pylint: disable=assigning-non-slot
    __slots__ = ()

    def getValueShape(self):
        """Give the value shape currently stored for the node."""
        return self.shape

    def _isTooLarge(self):
        """Tell if the product is too large to compute at compile time.

        Repeats of something with known iteration length are too large with
        more than 256 elements, products of two index constants when their
        decimal logarithms sum up to more than 20. In these cases the value
        shape of the node is updated as a side effect.
        """
        left_node = self.subnode_left
        right_node = self.subnode_right

        if right_node.isNumberConstant():
            # Left operand repeated by the right number.
            sequence_length = left_node.getIterationLength()

            if sequence_length is not None:
                repeated_size = sequence_length * right_node.getCompileTimeConstant()

                if repeated_size > 256:
                    self.shape = ShapeLargeConstantValuePredictable(
                        size=repeated_size,
                        predictor=None,  # predictValuesFromRightAndLeftValue,
                        shape=left_node.getTypeShape(),
                    )

                    return True

            if not left_node.isNumberConstant():
                return False

            if not (left_node.isIndexConstant() and right_node.isIndexConstant()):
                return False

            # Estimate with logarithm, if the result of number calculations
            # is computable with acceptable effort, otherwise, we will have
            # to do it at runtime.
            left_value = left_node.getCompileTimeConstant()
            if left_value == 0:
                return False

            right_value = right_node.getCompileTimeConstant()
            if right_value == 0:
                return False

            # TODO: Is this really useful, can this be really slow.
            magnitude = math.log10(abs(left_value)) + math.log10(abs(right_value))

            if magnitude > 20:
                self.shape = ShapeLargeConstantValue(
                    size=None, shape=tshape_int_or_long
                )

                return True

            return False

        if left_node.isNumberConstant():
            # Right operand repeated by the left number.
            sequence_length = right_node.getIterationLength()

            if sequence_length is not None:
                repeated_size = sequence_length * left_node.getCompileTimeConstant()

                if repeated_size > 256:
                    self.shape = ShapeLargeConstantValuePredictable(
                        size=repeated_size,
                        predictor=None,  # predictValuesFromRightAndLeftValue,
                        shape=right_node.getTypeShape(),
                    )

                    return True

        return False
359
360
class ExpressionOperationBinaryMult(
    ExpressionOperationMultMixin, ExpressionOperationBinaryBase
):
    """Binary operation node for the "Mult" operator.

    Besides the shape query, this predicts iteration length and side effects
    for the repeat of an operand with known iteration length by an integer.
    """

    kind = "EXPRESSION_OPERATION_BINARY_MULT"

    operator = "Mult"
    simulator = PythonOperators.binary_operator_functions[operator]

    def __init__(self, left, right, source_ref):
        ExpressionOperationBinaryBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        return left_shape.getOperationBinaryMultShape(right_shape)

    def getIterationLength(self):
        """Give the iteration length of a repeat, if it can be known.

        One operand has to have a known iteration length and the other a
        known integer value, otherwise the base class is asked.
        """
        left_length = self.subnode_left.getIterationLength()

        if left_length is not None:
            right_value = self.subnode_right.getIntegerValue()

            if right_value is not None:
                # Negative repeat counts give an empty result in Python, a
                # length can never be negative.
                return left_length * max(0, right_value)

        right_length = self.subnode_right.getIterationLength()

        if right_length is not None:
            left_value = self.subnode_left.getIntegerValue()

            if left_value is not None:
                # Same as above, with the operand roles swapped.
                return right_length * max(0, left_value)

        return ExpressionOperationBinaryBase.getIterationLength(self)

    def extractSideEffects(self):
        """Give the side effects to preserve if the result is not used.

        For a repeat with known length and integer count, only the side
        effects of both operands remain, otherwise the base class is asked.
        """
        left_length = self.subnode_left.getIterationLength()

        if left_length is not None:
            right_value = self.subnode_right.getIntegerValue()

            if right_value is not None:
                return (
                    self.subnode_left.extractSideEffects()
                    + self.subnode_right.extractSideEffects()
                )

        right_length = self.subnode_right.getIterationLength()

        if right_length is not None:
            left_value = self.subnode_left.getIntegerValue()

            if left_value is not None:
                return (
                    self.subnode_left.extractSideEffects()
                    + self.subnode_right.extractSideEffects()
                )

        return ExpressionOperationBinaryBase.extractSideEffects(self)
421
422
class ExpressionOperationBinaryFloorDiv(ExpressionOperationBinaryBase):
    """Binary operation node for the "FloorDiv" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_FLOOR_DIV"

    # Operator name, also the key to look up the compile time simulator.
    operator = "FloorDiv"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryFloorDivShape

        return shape_query(right_shape)
432
433
if python_version < 0x300:

    class ExpressionOperationBinaryOldDiv(ExpressionOperationBinaryBase):
        """Binary operation node for the "OldDiv" operator, Python2 only."""

        kind = "EXPRESSION_OPERATION_BINARY_OLD_DIV"

        # Operator name, also the key to look up the compile time simulator.
        operator = "OldDiv"
        simulator = PythonOperators.binary_operator_functions[operator]

        @staticmethod
        def _getOperationShape(left_shape, right_shape):
            """Ask the left shape about the result with the right shape."""
            shape_query = left_shape.getOperationBinaryOldDivShape

            return shape_query(right_shape)
445
446
class ExpressionOperationBinaryTrueDiv(ExpressionOperationBinaryBase):
    """Binary operation node for the "TrueDiv" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_TRUE_DIV"

    # Operator name, also the key to look up the compile time simulator.
    operator = "TrueDiv"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryTrueDivShape

        return shape_query(right_shape)
456
457
class ExpressionOperationBinaryMod(ExpressionOperationBinaryBase):
    """Binary operation node for the "Mod" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_MOD"

    # Operator name, also the key to look up the compile time simulator.
    operator = "Mod"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryModShape

        return shape_query(right_shape)
467
468
class ExpressionOperationBinaryDivmod(ExpressionOperationBinaryBase):
    """Binary operation node for the "Divmod" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_DIVMOD"

    # Operator name, also the key to look up the compile time simulator.
    operator = "Divmod"
    simulator = PythonOperators.binary_operator_functions[operator]

    def __init__(self, left, right, source_ref):
        ExpressionOperationBinaryBase.__init__(self, left, right, source_ref)

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryDivmodShape

        return shape_query(right_shape)
483
484
class ExpressionOperationBinaryPow(ExpressionOperationBinaryBase):
    """Binary operation node for the "Pow" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_POW"

    # Operator name, also the key to look up the compile time simulator.
    operator = "Pow"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryPowShape

        return shape_query(right_shape)
494
495
class ExpressionOperationBinaryLshift(ExpressionOperationBinaryBase):
    """Binary operation node for the "LShift" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_LSHIFT"

    # Operator name, also the key to look up the compile time simulator.
    operator = "LShift"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryLShiftShape

        return shape_query(right_shape)
505
506
class ExpressionOperationBinaryRshift(ExpressionOperationBinaryBase):
    """Binary operation node for the "RShift" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_RSHIFT"

    # Operator name, also the key to look up the compile time simulator.
    operator = "RShift"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryRShiftShape

        return shape_query(right_shape)
516
517
class ExpressionOperationBinaryBitOr(ExpressionOperationBinaryBase):
    """Binary operation node for the "BitOr" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_BIT_OR"

    # Operator name, also the key to look up the compile time simulator.
    operator = "BitOr"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryBitOrShape

        return shape_query(right_shape)
527
528
class ExpressionOperationBinaryBitAnd(ExpressionOperationBinaryBase):
    """Binary operation node for the "BitAnd" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_BIT_AND"

    # Operator name, also the key to look up the compile time simulator.
    operator = "BitAnd"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryBitAndShape

        return shape_query(right_shape)
538
539
class ExpressionOperationBinaryBitXor(ExpressionOperationBinaryBase):
    """Binary operation node for the "BitXor" operator."""

    kind = "EXPRESSION_OPERATION_BINARY_BIT_XOR"

    # Operator name, also the key to look up the compile time simulator.
    operator = "BitXor"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the result with the right shape."""
        shape_query = left_shape.getOperationBinaryBitXorShape

        return shape_query(right_shape)
549
550
if python_version >= 0x350:

    class ExpressionOperationBinaryMatMult(ExpressionOperationBinaryBase):
        """Binary operation node for the "MatMult" operator, Python3.5+ only."""

        kind = "EXPRESSION_OPERATION_BINARY_MAT_MULT"

        # Operator name, also the key to look up the compile time simulator.
        operator = "MatMult"
        simulator = PythonOperators.binary_operator_functions[operator]

        @staticmethod
        def _getOperationShape(left_shape, right_shape):
            """Ask the left shape about the result with the right shape."""
            shape_query = left_shape.getOperationBinaryMatMultShape

            return shape_query(right_shape)
562
563
# Node classes for binary operations, looked up by the operator name each
# class carries in its "operator" attribute.
_operator2binary_operation_nodeclass = dict(
    (node_class.operator, node_class)
    for node_class in (
        ExpressionOperationBinaryAdd,
        ExpressionOperationBinarySub,
        ExpressionOperationBinaryMult,
        ExpressionOperationBinaryFloorDiv,
        ExpressionOperationBinaryTrueDiv,
        ExpressionOperationBinaryMod,
        # Divmod only from built-in call.
        ExpressionOperationBinaryPow,
        ExpressionOperationBinaryLshift,
        ExpressionOperationBinaryRshift,
        ExpressionOperationBinaryBitOr,
        ExpressionOperationBinaryBitAnd,
        ExpressionOperationBinaryBitXor,
    )
)

# These classes are only defined for the matching Python versions.
if python_version < 0x300:
    _operator2binary_operation_nodeclass["OldDiv"] = ExpressionOperationBinaryOldDiv

if python_version >= 0x350:
    _operator2binary_operation_nodeclass["MatMult"] = ExpressionOperationBinaryMatMult
585
586
def makeBinaryOperationNode(operator, left, right, source_ref):
    """Create the binary operation node for an operator name.

    The node class is looked up by the operator name, which raises
    "KeyError" for names that have no node class registered.
    """
    return _operator2binary_operation_nodeclass[operator](
        left=left, right=right, source_ref=source_ref
    )
591
592
class ExpressionOperationBinaryInplaceBase(ExpressionOperationBinaryBase):
    # Base classes can be abstract, pylint: disable=abstract-method
    """Base class for all inplace operations."""

    def __init__(self, left, right, source_ref):
        ExpressionOperationBinaryBase.__init__(self, left, right, source_ref)

        # Inplace operations start out as suspects, unlike the binary ones.
        self.inplace_suspect = True

    @staticmethod
    def isExpressionOperationInplace():
        return True

    def computeExpression(self, trace_collection):
        """Update shapes and replace the node where the result is known.

        Returns a triple of node, change tag and change description, with
        the latter two being "None" when the node is kept as it is. With
        a left operand of bool shape, the node is lowered to the binary
        operation of the same operator.
        """
        # Nothing to do anymore for large constants.
        value_shape = self.shape
        if value_shape is not None and value_shape.isConstant():
            return self, None, None

        left_node = self.subnode_left
        left_type = left_node.getTypeShape()
        right_node = self.subnode_right
        right_type = right_node.getTypeShape()

        self.type_shape, self.escape_desc = self._getOperationShape(
            left_type, right_type
        )

        # Constant operands allow to compute the result right away, unless
        # it is considered too large.
        both_constant = (
            left_node.isCompileTimeConstant() and right_node.isCompileTimeConstant()
        )
        if both_constant and not self._isTooLarge():
            return self._simulateOperation(trace_collection)

        escape_desc = self.escape_desc

        raise_exit = escape_desc.getExceptionExit()
        if raise_exit is not None:
            trace_collection.onExceptionRaiseExit(raise_exit)

            if escape_desc.isUnsupported() and self.canCreateUnsupportedException(
                left_type, right_type
            ):
                # Replace with raising the exception, but keep the operands
                # as side effects.
                exception = self.createUnsupportedException(left_type, right_type)

                raise_node = makeRaiseExceptionReplacementExpressionFromInstance(
                    expression=self, exception=exception
                )

                result = wrapExpressionWithSideEffects(
                    new_node=raise_node,
                    old_node=self,
                    side_effects=(left_node, right_node),
                )

                return (
                    result,
                    "new_raise",
                    "Replaced inplace-operator '%s' with %s %s arguments that cannot work."
                    % (self.operator, left_type, right_type),
                )

        if escape_desc.isValueEscaping():
            # The value of these nodes escaped and could change its contents.
            trace_collection.removeKnowledge(left_node)
            trace_collection.removeKnowledge(right_node)

        if escape_desc.isControlFlowEscape():
            # Any code could be run, note that.
            trace_collection.onControlFlowEscape(self)

        if left_type is not tshape_bool:
            return self, None, None

        # The binary operator name is the inplace one without the leading "I".
        lowered = makeBinaryOperationNode(
            self.operator[1:], left_node, right_node, self.source_ref
        )

        return trace_collection.computedExpressionResult(
            lowered,
            "new_expression",
            "Lowered inplace-operator '%s' to binary operation." % self.operator,
        )
673
674
class ExpressionOperationInplaceAdd(
    ExpressionOperationAddMixin, ExpressionOperationBinaryInplaceBase
):
    """Inplace operation node for the "IAdd" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_ADD"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IAdd"
    simulator = PythonOperators.binary_operator_functions[operator]

    def __init__(self, left, right, source_ref):
        ExpressionOperationBinaryInplaceBase.__init__(self, left, right, source_ref)

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape about the inplace result with the right shape."""
        shape_query = left_shape.getOperationInplaceAddShape

        return shape_query(right_shape)
691
692
class ExpressionOperationInplaceSub(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "ISub" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_SUB"

    # Operator name, also the key to look up the compile time simulator.
    operator = "ISub"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinarySubShape

        return shape_query(right_shape)
702
703
class ExpressionOperationInplaceMult(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IMult" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_MULT"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IMult"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryMultShape

        return shape_query(right_shape)
713
714
class ExpressionOperationInplaceFloorDiv(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IFloorDiv" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_FLOOR_DIV"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IFloorDiv"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryFloorDivShape

        return shape_query(right_shape)
724
725
if python_version < 0x300:

    class ExpressionOperationInplaceOldDiv(ExpressionOperationBinaryInplaceBase):
        """Inplace operation node for the "IOldDiv" operator, Python2 only."""

        kind = "EXPRESSION_OPERATION_INPLACE_OLD_DIV"

        # Operator name, also the key to look up the compile time simulator.
        operator = "IOldDiv"
        simulator = PythonOperators.binary_operator_functions[operator]

        @staticmethod
        def _getOperationShape(left_shape, right_shape):
            """Ask the left shape, using the query of the binary operation."""
            shape_query = left_shape.getOperationBinaryOldDivShape

            return shape_query(right_shape)
737
738
class ExpressionOperationInplaceTrueDiv(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "ITrueDiv" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_TRUE_DIV"

    # Operator name, also the key to look up the compile time simulator.
    operator = "ITrueDiv"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryTrueDivShape

        return shape_query(right_shape)
748
749
class ExpressionOperationInplaceMod(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IMod" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_MOD"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IMod"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryModShape

        return shape_query(right_shape)
759
760
class ExpressionOperationInplacePow(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IPow" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_POW"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IPow"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryPowShape

        return shape_query(right_shape)
770
771
class ExpressionOperationInplaceLshift(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "ILShift" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_LSHIFT"

    # Operator name, also the key to look up the compile time simulator.
    operator = "ILShift"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryLShiftShape

        return shape_query(right_shape)
781
782
class ExpressionOperationInplaceRshift(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IRShift" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_RSHIFT"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IRShift"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryRShiftShape

        return shape_query(right_shape)
792
793
class ExpressionOperationInplaceBitOr(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IBitOr" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_BIT_OR"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IBitOr"
    simulator = PythonOperators.binary_operator_functions[operator]

    if python_version >= 0x390:

        @staticmethod
        def _getOperationShape(left_shape, right_shape):
            """Ask the left shape with the dedicated inplace shape query."""
            return left_shape.getOperationInplaceBitOrShape(right_shape)

    else:
        # No inplace bitor special handling before 3.9

        @staticmethod
        def _getOperationShape(left_shape, right_shape):
            """Ask the left shape with the query of the binary operation."""
            return left_shape.getOperationBinaryBitOrShape(right_shape)
812
813
class ExpressionOperationInplaceBitAnd(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IBitAnd" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_BIT_AND"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IBitAnd"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryBitAndShape

        return shape_query(right_shape)
823
824
class ExpressionOperationInplaceBitXor(ExpressionOperationBinaryInplaceBase):
    """Inplace operation node for the "IBitXor" operator."""

    kind = "EXPRESSION_OPERATION_INPLACE_BIT_XOR"

    # Operator name, also the key to look up the compile time simulator.
    operator = "IBitXor"
    simulator = PythonOperators.binary_operator_functions[operator]

    @staticmethod
    def _getOperationShape(left_shape, right_shape):
        """Ask the left shape, using the shape query of the binary operation."""
        shape_query = left_shape.getOperationBinaryBitXorShape

        return shape_query(right_shape)
834
835
if python_version >= 0x350:

    class ExpressionOperationInplaceMatMult(ExpressionOperationBinaryInplaceBase):
        """Inplace operation node for the "IMatMult" operator, Python3.5+ only."""

        kind = "EXPRESSION_OPERATION_INPLACE_MAT_MULT"

        # Operator name, also the key to look up the compile time simulator.
        operator = "IMatMult"
        simulator = PythonOperators.binary_operator_functions[operator]

        @staticmethod
        def _getOperationShape(left_shape, right_shape):
            """Ask the left shape, using the query of the binary operation."""
            shape_query = left_shape.getOperationBinaryMatMultShape

            return shape_query(right_shape)
847
848
# Node classes for inplace operations, looked up by the operator name each
# class carries in its "operator" attribute.
_operator2binary_inplace_nodeclass = dict(
    (node_class.operator, node_class)
    for node_class in (
        ExpressionOperationInplaceAdd,
        ExpressionOperationInplaceSub,
        ExpressionOperationInplaceMult,
        ExpressionOperationInplaceFloorDiv,
        ExpressionOperationInplaceTrueDiv,
        ExpressionOperationInplaceMod,
        ExpressionOperationInplacePow,
        ExpressionOperationInplaceLshift,
        ExpressionOperationInplaceRshift,
        ExpressionOperationInplaceBitOr,
        ExpressionOperationInplaceBitAnd,
        ExpressionOperationInplaceBitXor,
    )
)

# These classes are only defined for the matching Python versions.
if python_version < 0x300:
    _operator2binary_inplace_nodeclass["IOldDiv"] = ExpressionOperationInplaceOldDiv

if python_version >= 0x350:
    _operator2binary_inplace_nodeclass["IMatMult"] = ExpressionOperationInplaceMatMult
869
870
def makeExpressionOperationBinaryInplace(operator, left, right, source_ref):
    """Create the inplace operation node for an operator name.

    The node class is looked up by the operator name, which raises
    "KeyError" for names that have no node class registered.
    """
    return _operator2binary_inplace_nodeclass[operator](
        left=left, right=right, source_ref=source_ref
    )
875