#     Copyright 2021, Kay Hayen, mailto:kay.hayen@gmail.com
#
#     Part of "Nuitka", an optimizing Python compiler that is compatible and
#     integrates with CPython, but also works on its own.
#
#     Licensed under the Apache License, Version 2.0 (the "License");
#     you may not use this file except in compliance with the License.
#     You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#     Unless required by applicable law or agreed to in writing, software
#     distributed under the License is distributed on an "AS IS" BASIS,
#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#     See the License for the specific language governing permissions and
#     limitations under the License.
#
""" Nodes for comparisons.

Rich comparisons (<, <=, >, >=, ==, !=), identity tests (is, is not),
containment tests (in, not in) and exception matching, plus the factory
"makeComparisonExpression" that maps a comparator name to its node class.
"""

from nuitka import PythonOperators
from nuitka.Errors import NuitkaAssumptionError
from nuitka.PythonVersions import python_version

from .ExpressionBases import ExpressionChildrenHavingBase
from .NodeMakingHelpers import (
    makeConstantReplacementNode,
    makeRaiseExceptionReplacementExpressionFromInstance,
    wrapExpressionWithSideEffects,
)
from .shapes.BuiltinTypeShapes import tshape_bool, tshape_exception_class


class ExpressionComparisonBase(ExpressionChildrenHavingBase):
    """Common base of all comparison nodes.

    The two operands are the children "left" and "right". Concrete classes
    provide the class attributes "kind" and "comparator", the latter being
    the key used for "PythonOperators.all_comparison_functions".
    """

    named_children = ("left", "right")

    def __init__(self, left, right, source_ref):
        ExpressionChildrenHavingBase.__init__(
            self, values={"left": left, "right": right}, source_ref=source_ref
        )

    @staticmethod
    def copyTraceStateFrom(source):
        # No trace state is kept on this level; rich comparisons override it.
        pass

    def getOperands(self):
        """Return the (left, right) operand nodes."""
        return (self.subnode_left, self.subnode_right)

    def getComparator(self):
        """Return the comparator name, e.g. "Lt" or "IsNot"."""
        return self.comparator

    def getDetails(self):
        return {"comparator": self.comparator}

    @staticmethod
    def isExpressionComparison():
        return True

    def getSimulator(self):
        """Return the Python callable that performs this comparison."""
        return PythonOperators.all_comparison_functions[self.comparator]

    def _computeCompileTimeConstantComparision(self, trace_collection):
        """Fold a comparison of two compile time constants via the simulator."""
        left_value = self.subnode_left.getCompileTimeConstant()
        right_value = self.subnode_right.getCompileTimeConstant()

        return trace_collection.getCompileTimeComputationResult(
            node=self,
            computation=lambda: self.getSimulator()(left_value, right_value),
            description="Comparison of constant arguments.",
        )

    def computeExpression(self, trace_collection):
        """Generic, pessimistic computation for comparisons.

        Constant operands are folded; otherwise arbitrary code execution and
        any exception are assumed.
        """
        left = self.subnode_left
        right = self.subnode_right

        if left.isCompileTimeConstant() and right.isCompileTimeConstant():
            return self._computeCompileTimeConstantComparision(trace_collection)

        # The value of these nodes escaped and could change its contents.
        # TODO: Comparisons don't do much, but add this.
        # trace_collection.onValueEscapeRichComparison(left, right, self.comparator)

        # Any code could be run, note that.
        trace_collection.onControlFlowEscape(self)

        trace_collection.onExceptionRaiseExit(BaseException)

        return self, None, None

    def makeInverseComparision(self):
        """Create a node with the inverted comparator and the same operands.

        The operand nodes are reused, not cloned.
        """
        # Making this accessible for the tree building phase as well.
        return makeComparisonExpression(
            left=self.subnode_left,
            right=self.subnode_right,
            comparator=PythonOperators.comparison_inversions[self.comparator],
            source_ref=self.source_ref,
        )

    def computeExpressionOperationNot(self, not_node, trace_collection):
        """Replace "not (a <op> b)" with "a <inverse op> b" when allowed.

        Only done when this comparison is known to produce a bool; otherwise
        the "not" node is kept unchanged.
        """
        if self.getTypeShape() is tshape_bool:
            result = self.makeInverseComparision()

            result.copyTraceStateFrom(self)

            return (
                result,
                "new_expression",
                """Replaced negated comparison '%s' with inverse comparison '%s'."""
                % (self.comparator, result.comparator),
            )

        return not_node, None, None


class ExpressionComparisonRichBase(ExpressionComparisonBase):
    """Base for rich comparisons, driven by operand type shapes.

    "type_shape" and "escape_desc" are filled in by "computeExpression" from
    the subclass' "getComparisonShape"; both are None before that.
    """

    __slots__ = "type_shape", "escape_desc"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

        self.type_shape = None
        self.escape_desc = None

    def getTypeShape(self):
        return self.type_shape

    @staticmethod
    def getDetails():
        return {}

    def copyTraceStateFrom(self, source):
        """Take over the computed shape and escape description of "source"."""
        self.type_shape = source.type_shape
        self.escape_desc = source.escape_desc

    def canCreateUnsupportedException(self):
        """Check if both operand shapes provide a value for simulation."""
        return hasattr(self.subnode_left.getTypeShape(), "typical_value") and hasattr(
            self.subnode_right.getTypeShape(), "typical_value"
        )

    def createUnsupportedException(self):
        """Produce the TypeError the comparison raises for these operand shapes.

        The comparison is simulated on the shapes' typical values.

        Raises:
            NuitkaAssumptionError: the simulation unexpectedly did not raise.
        """
        left = self.subnode_left.getTypeShape().typical_value
        right = self.subnode_right.getTypeShape().typical_value

        try:
            self.getSimulator()(left, right)
        except TypeError as e:
            return e
        else:
            # Comparison nodes have "comparator" and "getSimulator()", not the
            # "operator"/"simulator" attributes of binary operation nodes.
            raise NuitkaAssumptionError(
                "Unexpected no-exception doing comparison simulation",
                self.comparator,
                self.getSimulator(),
                self.subnode_left.getTypeShape(),
                self.subnode_right.getTypeShape(),
                repr(left),
                repr(right),
            )

    def computeExpression(self, trace_collection):
        """Compute the comparison using the operand type shapes.

        Folds constant operands. Otherwise it records the result shape and
        escape description, and the exception exit if any. It replaces
        comparisons that are known to be unsupported with a raise of the
        simulated exception, keeping operand side effects.
        """
        left = self.subnode_left
        right = self.subnode_right

        if left.isCompileTimeConstant() and right.isCompileTimeConstant():
            return self._computeCompileTimeConstantComparision(trace_collection)

        left_shape = left.getTypeShape()
        right_shape = right.getTypeShape()

        self.type_shape, self.escape_desc = self.getComparisonShape(
            left_shape, right_shape
        )

        exception_raise_exit = self.escape_desc.getExceptionExit()
        if exception_raise_exit is not None:
            trace_collection.onExceptionRaiseExit(exception_raise_exit)

            if (
                self.escape_desc.isUnsupported()
                and self.canCreateUnsupportedException()
            ):
                result = wrapExpressionWithSideEffects(
                    new_node=makeRaiseExceptionReplacementExpressionFromInstance(
                        expression=self, exception=self.createUnsupportedException()
                    ),
                    old_node=self,
                    side_effects=(self.subnode_left, self.subnode_right),
                )

                return (
                    result,
                    "new_raise",
                    """Replaced comparator '%s' with %s %s arguments that cannot work."""
                    % (
                        self.comparator,
                        self.subnode_left.getTypeShape(),
                        self.subnode_right.getTypeShape(),
                    ),
                )

        # The value of these nodes escaped and could change its contents.

        # TODO: Comparisons don't do much, but add this.
        # if self.escape_desc.isValueEscaping():
        #    trace_collection.onValueEscapeRichComparison(left, right, self.comparator)

        if self.escape_desc.isControlFlowEscape():
            # Any code could be run, note that.
            trace_collection.onControlFlowEscape(self)

        return self, None, None

    def mayRaiseException(self, exception_type):
        # TODO: Match more precisely
        return (
            self.escape_desc is None
            or self.escape_desc.getExceptionExit() is not None
            or self.subnode_left.mayRaiseException(exception_type)
            or self.subnode_right.mayRaiseException(exception_type)
        )

    def mayRaiseExceptionBool(self, exception_type):
        """Check if taking the truth value of the result may raise.

        Conservatively answers True while the result shape is not yet known.
        """
        # Before "computeExpression" ran, "type_shape" is still None.
        if self.type_shape is None:
            return True

        return self.type_shape.hasShapeSlotBool() is not True

    def mayRaiseExceptionComparison(self):
        return (
            self.escape_desc is None or self.escape_desc.getExceptionExit() is not None
        )


class ExpressionComparisonLt(ExpressionComparisonRichBase):
    kind = "EXPRESSION_COMPARISON_LT"

    comparator = "Lt"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        return left_shape.getComparisonLtShape(right_shape)


class ExpressionComparisonLte(ExpressionComparisonRichBase):
    kind = "EXPRESSION_COMPARISON_LTE"

    comparator = "LtE"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        return left_shape.getComparisonLteShape(right_shape)


class ExpressionComparisonGt(ExpressionComparisonRichBase):
    kind = "EXPRESSION_COMPARISON_GT"

    comparator = "Gt"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        return left_shape.getComparisonGtShape(right_shape)


class ExpressionComparisonGte(ExpressionComparisonRichBase):
    kind = "EXPRESSION_COMPARISON_GTE"

    comparator = "GtE"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        return left_shape.getComparisonGteShape(right_shape)


class ExpressionComparisonEq(ExpressionComparisonRichBase):
    kind = "EXPRESSION_COMPARISON_EQ"

    comparator = "Eq"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        return left_shape.getComparisonEqShape(right_shape)


class ExpressionComparisonNeq(ExpressionComparisonRichBase):
    kind = "EXPRESSION_COMPARISON_NEQ"

    comparator = "NotEq"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonRichBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def getComparisonShape(left_shape, right_shape):
        return left_shape.getComparisonNeqShape(right_shape)


class ExpressionComparisonIsIsNotBase(ExpressionComparisonBase):
    """Base for identity tests "is" and "is not".

    The result is always a bool. "match_value" is the result when both
    operands are the same object: True for "Is", False for "IsNot".
    """

    __slots__ = ("match_value",)

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

        assert self.comparator in ("Is", "IsNot")

        # TODO: Forward propagate this one.
        self.match_value = self.comparator == "Is"

    @staticmethod
    def getDetails():
        return {}

    @staticmethod
    def getTypeShape():
        return tshape_bool

    def mayRaiseException(self, exception_type):
        # The identity test itself cannot raise, only the operands can.
        return self.subnode_left.mayRaiseException(
            exception_type
        ) or self.subnode_right.mayRaiseException(exception_type)

    def mayRaiseExceptionBool(self, exception_type):
        return False

    def computeExpression(self, trace_collection):
        """Fold to a constant when aliasing of the operands is known.

        Operand side effects are preserved when present; otherwise the
        generic comparison computation is used.
        """
        left, right = self.getOperands()

        if trace_collection.mustAlias(left, right):
            result = makeConstantReplacementNode(
                constant=self.match_value, node=self, user_provided=False
            )

            if left.mayHaveSideEffects() or right.mayHaveSideEffects():
                result = wrapExpressionWithSideEffects(
                    side_effects=self.extractSideEffects(),
                    old_node=self,
                    new_node=result,
                )

            return (
                result,
                "new_constant",
                """\
Determined values to alias and therefore result of %s comparison."""
                % (self.comparator),
            )

        if trace_collection.mustNotAlias(left, right):
            result = makeConstantReplacementNode(
                constant=not self.match_value, node=self, user_provided=False
            )

            if left.mayHaveSideEffects() or right.mayHaveSideEffects():
                result = wrapExpressionWithSideEffects(
                    side_effects=self.extractSideEffects(),
                    old_node=self,
                    new_node=result,
                )

            return (
                result,
                "new_constant",
                """\
Determined values to not alias and therefore result of '%s' comparison."""
                % (self.comparator),
            )

        return ExpressionComparisonBase.computeExpression(
            self, trace_collection=trace_collection
        )

    def extractSideEffects(self):
        """Return the side effects of both operands, left first."""
        left, right = self.getOperands()

        return left.extractSideEffects() + right.extractSideEffects()

    def computeExpressionDrop(self, statement, trace_collection):
        """Replace an unused identity test by statements for its operands."""
        from .NodeMakingHelpers import makeStatementOnlyNodesFromExpressions

        result = makeStatementOnlyNodesFromExpressions(expressions=self.getOperands())

        del self.parent

        return (
            result,
            "new_statements",
            """\
Removed %s comparison for unused result."""
            % self.comparator,
        )


class ExpressionComparisonIs(ExpressionComparisonIsIsNotBase):
    kind = "EXPRESSION_COMPARISON_IS"

    comparator = "Is"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonIsIsNotBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )


class ExpressionComparisonIsNot(ExpressionComparisonIsIsNotBase):
    kind = "EXPRESSION_COMPARISON_IS_NOT"

    comparator = "IsNot"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonIsIsNotBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )


class ExpressionComparisonExceptionMatchBase(ExpressionComparisonBase):
    """Base for exception match tests; the result is always a bool."""

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

    @staticmethod
    def getDetails():
        return {}

    @staticmethod
    def getTypeShape():
        return tshape_bool

    def getSimulator(self):
        # TODO: Doesn't happen yet, but will once we trace exceptions.
        assert False

        return PythonOperators.all_comparison_functions[self.comparator]

    def mayRaiseException(self, exception_type):
        # TODO: Match errors that exception comparisons might raise more accurately.
        return (
            self.subnode_left.mayRaiseException(exception_type)
            or self.subnode_right.mayRaiseException(exception_type)
            or self.mayRaiseExceptionComparison()
        )

    def mayRaiseExceptionComparison(self):
        """Check if the match itself may raise.

        Never on Python2. On Python3, only a right operand known to have the
        exception class shape is considered safe.
        """
        if python_version < 0x300:
            return False

        # TODO: Add shape for exceptions.
        type_shape = self.subnode_right.getTypeShape()

        if type_shape is tshape_exception_class:
            return False

        return True

    @staticmethod
    def mayRaiseExceptionBool(exception_type):
        return False


class ExpressionComparisonExceptionMatch(ExpressionComparisonExceptionMatchBase):
    kind = "EXPRESSION_COMPARISON_EXCEPTION_MATCH"

    comparator = "exception_match"


class ExpressionComparisonExceptionMismatch(ExpressionComparisonExceptionMatchBase):
    kind = "EXPRESSION_COMPARISON_EXCEPTION_MISMATCH"

    comparator = "exception_mismatch"


class ExpressionComparisonInNotInBase(ExpressionComparisonBase):
    """Base for containment tests "in" and "not in"; the result is a bool.

    The container is the right operand and the searched value is the left.
    """

    def __init__(self, left, right, source_ref):
        ExpressionComparisonBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )

        assert self.comparator in ("In", "NotIn")

    @staticmethod
    def getDetails():
        return {}

    @staticmethod
    def getTypeShape():
        return tshape_bool

    def mayRaiseException(self, exception_type):
        left = self.subnode_left

        if left.mayRaiseException(exception_type):
            return True

        right = self.subnode_right

        if right.mayRaiseException(exception_type):
            return True

        # Whether the containment test raises is decided by the container.
        return right.mayRaiseExceptionIn(exception_type, left)

    @staticmethod
    def mayRaiseExceptionBool(exception_type):
        return False

    def computeExpression(self, trace_collection):
        # Delegate to the container node, which knows its own contents.
        return self.subnode_right.computeExpressionComparisonIn(
            in_node=self,
            value_node=self.subnode_left,
            trace_collection=trace_collection,
        )


class ExpressionComparisonIn(ExpressionComparisonInNotInBase):
    kind = "EXPRESSION_COMPARISON_IN"

    comparator = "In"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonInNotInBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )


class ExpressionComparisonNotIn(ExpressionComparisonInNotInBase):
    kind = "EXPRESSION_COMPARISON_NOT_IN"

    comparator = "NotIn"

    def __init__(self, left, right, source_ref):
        ExpressionComparisonInNotInBase.__init__(
            self, left=left, right=right, source_ref=source_ref
        )


# Maps each comparator name to the node class implementing it.
_comparator_to_nodeclass = {
    "Is": ExpressionComparisonIs,
    "IsNot": ExpressionComparisonIsNot,
    "In": ExpressionComparisonIn,
    "NotIn": ExpressionComparisonNotIn,
    "Lt": ExpressionComparisonLt,
    "LtE": ExpressionComparisonLte,
    "Gt": ExpressionComparisonGt,
    "GtE": ExpressionComparisonGte,
    "Eq": ExpressionComparisonEq,
    "NotEq": ExpressionComparisonNeq,
    "exception_match": ExpressionComparisonExceptionMatch,
    "exception_mismatch": ExpressionComparisonExceptionMismatch,
}


def makeComparisonExpression(left, right, comparator, source_ref):
    """Create the comparison node for "comparator" over "left" and "right".

    Raises:
        KeyError: "comparator" is not a known comparator name.
    """
    return _comparator_to_nodeclass[comparator](
        left=left, right=right, source_ref=source_ref
    )