1from __future__ import absolute_import 2 3import re 4import sys 5import copy 6import codecs 7import itertools 8 9from . import TypeSlots 10from .ExprNodes import not_a_constant 11import cython 12cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object, 13 Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object, 14 UtilNodes=object, _py_int_types=object) 15 16if sys.version_info[0] >= 3: 17 _py_int_types = int 18 _py_string_types = (bytes, str) 19else: 20 _py_int_types = (int, long) 21 _py_string_types = (bytes, unicode) 22 23from . import Nodes 24from . import ExprNodes 25from . import PyrexTypes 26from . import Visitor 27from . import Builtin 28from . import UtilNodes 29from . import Options 30 31from .Code import UtilityCode, TempitaUtilityCode 32from .StringEncoding import EncodedString, bytes_literal, encoded_string 33from .Errors import error, warning 34from .ParseTreeTransforms import SkipDeclarations 35 36try: 37 from __builtin__ import reduce 38except ImportError: 39 from functools import reduce 40 41try: 42 from __builtin__ import basestring 43except ImportError: 44 basestring = str # Python 3 45 46 47def load_c_utility(name): 48 return UtilityCode.load_cached(name, "Optimize.c") 49 50 51def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)): 52 if isinstance(node, coercion_nodes): 53 return node.arg 54 return node 55 56 57def unwrap_node(node): 58 while isinstance(node, UtilNodes.ResultRefNode): 59 node = node.expression 60 return node 61 62 63def is_common_value(a, b): 64 a = unwrap_node(a) 65 b = unwrap_node(b) 66 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): 67 return a.name == b.name 68 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): 69 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute 70 return False 71 72 73def filter_none_node(node): 74 
if node is not None and node.constant_result is None: 75 return None 76 return node 77 78 79class _YieldNodeCollector(Visitor.TreeVisitor): 80 """ 81 YieldExprNode finder for generator expressions. 82 """ 83 def __init__(self): 84 Visitor.TreeVisitor.__init__(self) 85 self.yield_stat_nodes = {} 86 self.yield_nodes = [] 87 88 visit_Node = Visitor.TreeVisitor.visitchildren 89 90 def visit_YieldExprNode(self, node): 91 self.yield_nodes.append(node) 92 self.visitchildren(node) 93 94 def visit_ExprStatNode(self, node): 95 self.visitchildren(node) 96 if node.expr in self.yield_nodes: 97 self.yield_stat_nodes[node.expr] = node 98 99 # everything below these nodes is out of scope: 100 101 def visit_GeneratorExpressionNode(self, node): 102 pass 103 104 def visit_LambdaNode(self, node): 105 pass 106 107 def visit_FuncDefNode(self, node): 108 pass 109 110 111def _find_single_yield_expression(node): 112 yield_statements = _find_yield_statements(node) 113 if len(yield_statements) != 1: 114 return None, None 115 return yield_statements[0] 116 117 118def _find_yield_statements(node): 119 collector = _YieldNodeCollector() 120 collector.visitchildren(node) 121 try: 122 yield_statements = [ 123 (yield_node.arg, collector.yield_stat_nodes[yield_node]) 124 for yield_node in collector.yield_nodes 125 ] 126 except KeyError: 127 # found YieldExprNode without ExprStatNode (i.e. 
a non-statement usage of 'yield') 128 yield_statements = [] 129 return yield_statements 130 131 132class IterationTransform(Visitor.EnvTransform): 133 """Transform some common for-in loop patterns into efficient C loops: 134 135 - for-in-dict loop becomes a while loop calling PyDict_Next() 136 - for-in-enumerate is replaced by an external counter variable 137 - for-in-range loop becomes a plain C for loop 138 """ 139 def visit_PrimaryCmpNode(self, node): 140 if node.is_ptr_contains(): 141 142 # for t in operand2: 143 # if operand1 == t: 144 # res = True 145 # break 146 # else: 147 # res = False 148 149 pos = node.pos 150 result_ref = UtilNodes.ResultRefNode(node) 151 if node.operand2.is_subscript: 152 base_type = node.operand2.base.type.base_type 153 else: 154 base_type = node.operand2.type.base_type 155 target_handle = UtilNodes.TempHandle(base_type) 156 target = target_handle.ref(pos) 157 cmp_node = ExprNodes.PrimaryCmpNode( 158 pos, operator=u'==', operand1=node.operand1, operand2=target) 159 if_body = Nodes.StatListNode( 160 pos, 161 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)), 162 Nodes.BreakStatNode(pos)]) 163 if_node = Nodes.IfStatNode( 164 pos, 165 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)], 166 else_clause=None) 167 for_loop = UtilNodes.TempsBlockNode( 168 pos, 169 temps = [target_handle], 170 body = Nodes.ForInStatNode( 171 pos, 172 target=target, 173 iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), 174 body=if_node, 175 else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) 176 for_loop = for_loop.analyse_expressions(self.current_env()) 177 for_loop = self.visit(for_loop) 178 new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) 179 180 if node.operator == 'not_in': 181 new_node = ExprNodes.NotNode(pos, operand=new_node) 182 return new_node 183 184 else: 185 self.visitchildren(node) 186 return 
node 187 188 def visit_ForInStatNode(self, node): 189 self.visitchildren(node) 190 return self._optimise_for_loop(node, node.iterator.sequence) 191 192 def _optimise_for_loop(self, node, iterable, reversed=False): 193 annotation_type = None 194 if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation: 195 annotation = iterable.entry.annotation 196 if annotation.is_subscript: 197 annotation = annotation.base # container base type 198 # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names? 199 if annotation.is_name: 200 if annotation.entry and annotation.entry.qualified_name == 'typing.Dict': 201 annotation_type = Builtin.dict_type 202 elif annotation.name == 'Dict': 203 annotation_type = Builtin.dict_type 204 if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'): 205 annotation_type = Builtin.set_type 206 elif annotation.name in ('Set', 'FrozenSet'): 207 annotation_type = Builtin.set_type 208 209 if Builtin.dict_type in (iterable.type, annotation_type): 210 # like iterating over dict.keys() 211 if reversed: 212 # CPython raises an error here: not a sequence 213 return node 214 return self._transform_dict_iteration( 215 node, dict_obj=iterable, method=None, keys=True, values=False) 216 217 if (Builtin.set_type in (iterable.type, annotation_type) or 218 Builtin.frozenset_type in (iterable.type, annotation_type)): 219 if reversed: 220 # CPython raises an error here: not a sequence 221 return node 222 return self._transform_set_iteration(node, iterable) 223 224 # C array (slice) iteration? 
225 if iterable.type.is_ptr or iterable.type.is_array: 226 return self._transform_carray_iteration(node, iterable, reversed=reversed) 227 if iterable.type is Builtin.bytes_type: 228 return self._transform_bytes_iteration(node, iterable, reversed=reversed) 229 if iterable.type is Builtin.unicode_type: 230 return self._transform_unicode_iteration(node, iterable, reversed=reversed) 231 232 # the rest is based on function calls 233 if not isinstance(iterable, ExprNodes.SimpleCallNode): 234 return node 235 236 if iterable.args is None: 237 arg_count = iterable.arg_tuple and len(iterable.arg_tuple.args) or 0 238 else: 239 arg_count = len(iterable.args) 240 if arg_count and iterable.self is not None: 241 arg_count -= 1 242 243 function = iterable.function 244 # dict iteration? 245 if function.is_attribute and not reversed and not arg_count: 246 base_obj = iterable.self or function.obj 247 method = function.attribute 248 # in Py3, items() is equivalent to Py2's iteritems() 249 is_safe_iter = self.global_scope().context.language_level >= 3 250 251 if not is_safe_iter and method in ('keys', 'values', 'items'): 252 # try to reduce this to the corresponding .iter*() methods 253 if isinstance(base_obj, ExprNodes.CallNode): 254 inner_function = base_obj.function 255 if (inner_function.is_name and inner_function.name == 'dict' 256 and inner_function.entry 257 and inner_function.entry.is_builtin): 258 # e.g. dict(something).items() => safe to use .iter*() 259 is_safe_iter = True 260 261 keys = values = False 262 if method == 'iterkeys' or (is_safe_iter and method == 'keys'): 263 keys = True 264 elif method == 'itervalues' or (is_safe_iter and method == 'values'): 265 values = True 266 elif method == 'iteritems' or (is_safe_iter and method == 'items'): 267 keys = values = True 268 269 if keys or values: 270 return self._transform_dict_iteration( 271 node, base_obj, method, keys, values) 272 273 # enumerate/reversed ? 
274 if iterable.self is None and function.is_name and \ 275 function.entry and function.entry.is_builtin: 276 if function.name == 'enumerate': 277 if reversed: 278 # CPython raises an error here: not a sequence 279 return node 280 return self._transform_enumerate_iteration(node, iterable) 281 elif function.name == 'reversed': 282 if reversed: 283 # CPython raises an error here: not a sequence 284 return node 285 return self._transform_reversed_iteration(node, iterable) 286 287 # range() iteration? 288 if Options.convert_range and arg_count >= 1 and ( 289 iterable.self is None and 290 function.is_name and function.name in ('range', 'xrange') and 291 function.entry and function.entry.is_builtin): 292 if node.target.type.is_int or node.target.type.is_enum: 293 return self._transform_range_iteration(node, iterable, reversed=reversed) 294 if node.target.type.is_pyobject: 295 # Assume that small integer ranges (C long >= 32bit) are best handled in C as well. 296 for arg in (iterable.arg_tuple.args if iterable.args is None else iterable.args): 297 if isinstance(arg, ExprNodes.IntNode): 298 if arg.has_constant_result() and -2**30 <= arg.constant_result < 2**30: 299 continue 300 break 301 else: 302 return self._transform_range_iteration(node, iterable, reversed=reversed) 303 304 return node 305 306 def _transform_reversed_iteration(self, node, reversed_function): 307 args = reversed_function.arg_tuple.args 308 if len(args) == 0: 309 error(reversed_function.pos, 310 "reversed() requires an iterable argument") 311 return node 312 elif len(args) > 1: 313 error(reversed_function.pos, 314 "reversed() takes exactly 1 argument") 315 return node 316 arg = args[0] 317 318 # reversed(list/tuple) ? 
319 if arg.type in (Builtin.tuple_type, Builtin.list_type): 320 node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable") 321 node.iterator.reversed = True 322 return node 323 324 return self._optimise_for_loop(node, arg, reversed=True) 325 326 PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType( 327 PyrexTypes.c_char_ptr_type, [ 328 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) 329 ]) 330 331 PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType( 332 PyrexTypes.c_py_ssize_t_type, [ 333 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) 334 ]) 335 336 def _transform_bytes_iteration(self, node, slice_node, reversed=False): 337 target_type = node.target.type 338 if not target_type.is_int and target_type is not Builtin.bytes_type: 339 # bytes iteration returns bytes objects in Py2, but 340 # integers in Py3 341 return node 342 343 unpack_temp_node = UtilNodes.LetRefNode( 344 slice_node.as_none_safe_node("'NoneType' is not iterable")) 345 346 slice_base_node = ExprNodes.PythonCapiCallNode( 347 slice_node.pos, "PyBytes_AS_STRING", 348 self.PyBytes_AS_STRING_func_type, 349 args = [unpack_temp_node], 350 is_temp = 0, 351 ) 352 len_node = ExprNodes.PythonCapiCallNode( 353 slice_node.pos, "PyBytes_GET_SIZE", 354 self.PyBytes_GET_SIZE_func_type, 355 args = [unpack_temp_node], 356 is_temp = 0, 357 ) 358 359 return UtilNodes.LetNode( 360 unpack_temp_node, 361 self._transform_carray_iteration( 362 node, 363 ExprNodes.SliceIndexNode( 364 slice_node.pos, 365 base = slice_base_node, 366 start = None, 367 step = None, 368 stop = len_node, 369 type = slice_base_node.type, 370 is_temp = 1, 371 ), 372 reversed = reversed)) 373 374 PyUnicode_READ_func_type = PyrexTypes.CFuncType( 375 PyrexTypes.c_py_ucs4_type, [ 376 PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None), 377 PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None), 378 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None) 379 ]) 380 381 
init_unicode_iteration_func_type = PyrexTypes.CFuncType( 382 PyrexTypes.c_int_type, [ 383 PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None), 384 PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None), 385 PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None), 386 PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None) 387 ], 388 exception_value = '-1') 389 390 def _transform_unicode_iteration(self, node, slice_node, reversed=False): 391 if slice_node.is_literal: 392 # try to reduce to byte iteration for plain Latin-1 strings 393 try: 394 bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1') 395 except UnicodeEncodeError: 396 pass 397 else: 398 bytes_slice = ExprNodes.SliceIndexNode( 399 slice_node.pos, 400 base=ExprNodes.BytesNode( 401 slice_node.pos, value=bytes_value, 402 constant_result=bytes_value, 403 type=PyrexTypes.c_const_char_ptr_type).coerce_to( 404 PyrexTypes.c_const_uchar_ptr_type, self.current_env()), 405 start=None, 406 stop=ExprNodes.IntNode( 407 slice_node.pos, value=str(len(bytes_value)), 408 constant_result=len(bytes_value), 409 type=PyrexTypes.c_py_ssize_t_type), 410 type=Builtin.unicode_type, # hint for Python conversion 411 ) 412 return self._transform_carray_iteration(node, bytes_slice, reversed) 413 414 unpack_temp_node = UtilNodes.LetRefNode( 415 slice_node.as_none_safe_node("'NoneType' is not iterable")) 416 417 start_node = ExprNodes.IntNode( 418 node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type) 419 length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 420 end_node = length_temp.ref(node.pos) 421 if reversed: 422 relation1, relation2 = '>', '>=' 423 start_node, end_node = end_node, start_node 424 else: 425 relation1, relation2 = '<=', '<' 426 427 kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) 428 data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type) 429 counter_temp = 
UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 430 431 target_value = ExprNodes.PythonCapiCallNode( 432 slice_node.pos, "__Pyx_PyUnicode_READ", 433 self.PyUnicode_READ_func_type, 434 args = [kind_temp.ref(slice_node.pos), 435 data_temp.ref(slice_node.pos), 436 counter_temp.ref(node.target.pos)], 437 is_temp = False, 438 ) 439 if target_value.type != node.target.type: 440 target_value = target_value.coerce_to(node.target.type, 441 self.current_env()) 442 target_assign = Nodes.SingleAssignmentNode( 443 pos = node.target.pos, 444 lhs = node.target, 445 rhs = target_value) 446 body = Nodes.StatListNode( 447 node.pos, 448 stats = [target_assign, node.body]) 449 450 loop_node = Nodes.ForFromStatNode( 451 node.pos, 452 bound1=start_node, relation1=relation1, 453 target=counter_temp.ref(node.target.pos), 454 relation2=relation2, bound2=end_node, 455 step=None, body=body, 456 else_clause=node.else_clause, 457 from_range=True) 458 459 setup_node = Nodes.ExprStatNode( 460 node.pos, 461 expr = ExprNodes.PythonCapiCallNode( 462 slice_node.pos, "__Pyx_init_unicode_iteration", 463 self.init_unicode_iteration_func_type, 464 args = [unpack_temp_node, 465 ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos), 466 type=PyrexTypes.c_py_ssize_t_ptr_type), 467 ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos), 468 type=PyrexTypes.c_void_ptr_ptr_type), 469 ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos), 470 type=PyrexTypes.c_int_ptr_type), 471 ], 472 is_temp = True, 473 result_is_used = False, 474 utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"), 475 )) 476 return UtilNodes.LetNode( 477 unpack_temp_node, 478 UtilNodes.TempsBlockNode( 479 node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp], 480 body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node]))) 481 482 def _transform_carray_iteration(self, node, slice_node, reversed=False): 483 neg_step = False 484 if 
isinstance(slice_node, ExprNodes.SliceIndexNode): 485 slice_base = slice_node.base 486 start = filter_none_node(slice_node.start) 487 stop = filter_none_node(slice_node.stop) 488 step = None 489 if not stop: 490 if not slice_base.type.is_pyobject: 491 error(slice_node.pos, "C array iteration requires known end index") 492 return node 493 494 elif slice_node.is_subscript: 495 assert isinstance(slice_node.index, ExprNodes.SliceNode) 496 slice_base = slice_node.base 497 index = slice_node.index 498 start = filter_none_node(index.start) 499 stop = filter_none_node(index.stop) 500 step = filter_none_node(index.step) 501 if step: 502 if not isinstance(step.constant_result, _py_int_types) \ 503 or step.constant_result == 0 \ 504 or step.constant_result > 0 and not stop \ 505 or step.constant_result < 0 and not start: 506 if not slice_base.type.is_pyobject: 507 error(step.pos, "C array iteration requires known step size and end index") 508 return node 509 else: 510 # step sign is handled internally by ForFromStatNode 511 step_value = step.constant_result 512 if reversed: 513 step_value = -step_value 514 neg_step = step_value < 0 515 step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type, 516 value=str(abs(step_value)), 517 constant_result=abs(step_value)) 518 519 elif slice_node.type.is_array: 520 if slice_node.type.size is None: 521 error(slice_node.pos, "C array iteration requires known end index") 522 return node 523 slice_base = slice_node 524 start = None 525 stop = ExprNodes.IntNode( 526 slice_node.pos, value=str(slice_node.type.size), 527 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size) 528 step = None 529 530 else: 531 if not slice_node.type.is_pyobject: 532 error(slice_node.pos, "C array iteration requires known end index") 533 return node 534 535 if start: 536 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 537 if stop: 538 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 539 
if stop is None: 540 if neg_step: 541 stop = ExprNodes.IntNode( 542 slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1) 543 else: 544 error(slice_node.pos, "C array iteration requires known step size and end index") 545 return node 546 547 if reversed: 548 if not start: 549 start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0, 550 type=PyrexTypes.c_py_ssize_t_type) 551 # if step was provided, it was already negated above 552 start, stop = stop, start 553 554 ptr_type = slice_base.type 555 if ptr_type.is_array: 556 ptr_type = ptr_type.element_ptr_type() 557 carray_ptr = slice_base.coerce_to_simple(self.current_env()) 558 559 if start and start.constant_result != 0: 560 start_ptr_node = ExprNodes.AddNode( 561 start.pos, 562 operand1=carray_ptr, 563 operator='+', 564 operand2=start, 565 type=ptr_type) 566 else: 567 start_ptr_node = carray_ptr 568 569 if stop and stop.constant_result != 0: 570 stop_ptr_node = ExprNodes.AddNode( 571 stop.pos, 572 operand1=ExprNodes.CloneNode(carray_ptr), 573 operator='+', 574 operand2=stop, 575 type=ptr_type 576 ).coerce_to_simple(self.current_env()) 577 else: 578 stop_ptr_node = ExprNodes.CloneNode(carray_ptr) 579 580 counter = UtilNodes.TempHandle(ptr_type) 581 counter_temp = counter.ref(node.target.pos) 582 583 if slice_base.type.is_string and node.target.type.is_pyobject: 584 # special case: char* -> bytes/unicode 585 if slice_node.type is Builtin.unicode_type: 586 target_value = ExprNodes.CastNode( 587 ExprNodes.DereferenceNode( 588 node.target.pos, operand=counter_temp, 589 type=ptr_type.base_type), 590 PyrexTypes.c_py_ucs4_type).coerce_to( 591 node.target.type, self.current_env()) 592 else: 593 # char* -> bytes coercion requires slicing, not indexing 594 target_value = ExprNodes.SliceIndexNode( 595 node.target.pos, 596 start=ExprNodes.IntNode(node.target.pos, value='0', 597 constant_result=0, 598 type=PyrexTypes.c_int_type), 599 stop=ExprNodes.IntNode(node.target.pos, value='1', 
600 constant_result=1, 601 type=PyrexTypes.c_int_type), 602 base=counter_temp, 603 type=Builtin.bytes_type, 604 is_temp=1) 605 elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type): 606 # Allow iteration with pointer target to avoid copy. 607 target_value = counter_temp 608 else: 609 # TODO: can this safely be replaced with DereferenceNode() as above? 610 target_value = ExprNodes.IndexNode( 611 node.target.pos, 612 index=ExprNodes.IntNode(node.target.pos, value='0', 613 constant_result=0, 614 type=PyrexTypes.c_int_type), 615 base=counter_temp, 616 type=ptr_type.base_type) 617 618 if target_value.type != node.target.type: 619 target_value = target_value.coerce_to(node.target.type, 620 self.current_env()) 621 622 target_assign = Nodes.SingleAssignmentNode( 623 pos = node.target.pos, 624 lhs = node.target, 625 rhs = target_value) 626 627 body = Nodes.StatListNode( 628 node.pos, 629 stats = [target_assign, node.body]) 630 631 relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed) 632 633 for_node = Nodes.ForFromStatNode( 634 node.pos, 635 bound1=start_ptr_node, relation1=relation1, 636 target=counter_temp, 637 relation2=relation2, bound2=stop_ptr_node, 638 step=step, body=body, 639 else_clause=node.else_clause, 640 from_range=True) 641 642 return UtilNodes.TempsBlockNode( 643 node.pos, temps=[counter], 644 body=for_node) 645 646 def _transform_enumerate_iteration(self, node, enumerate_function): 647 args = enumerate_function.arg_tuple.args 648 if len(args) == 0: 649 error(enumerate_function.pos, 650 "enumerate() requires an iterable argument") 651 return node 652 elif len(args) > 2: 653 error(enumerate_function.pos, 654 "enumerate() takes at most 2 arguments") 655 return node 656 657 if not node.target.is_sequence_constructor: 658 # leave this untouched for now 659 return node 660 targets = node.target.args 661 if len(targets) != 2: 662 # leave this untouched for now 663 return node 664 665 enumerate_target, 
iterable_target = targets 666 counter_type = enumerate_target.type 667 668 if not counter_type.is_pyobject and not counter_type.is_int: 669 # nothing we can do here, I guess 670 return node 671 672 if len(args) == 2: 673 start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env()) 674 else: 675 start = ExprNodes.IntNode(enumerate_function.pos, 676 value='0', 677 type=counter_type, 678 constant_result=0) 679 temp = UtilNodes.LetRefNode(start) 680 681 inc_expression = ExprNodes.AddNode( 682 enumerate_function.pos, 683 operand1 = temp, 684 operand2 = ExprNodes.IntNode(node.pos, value='1', 685 type=counter_type, 686 constant_result=1), 687 operator = '+', 688 type = counter_type, 689 #inplace = True, # not worth using in-place operation for Py ints 690 is_temp = counter_type.is_pyobject 691 ) 692 693 loop_body = [ 694 Nodes.SingleAssignmentNode( 695 pos = enumerate_target.pos, 696 lhs = enumerate_target, 697 rhs = temp), 698 Nodes.SingleAssignmentNode( 699 pos = enumerate_target.pos, 700 lhs = temp, 701 rhs = inc_expression) 702 ] 703 704 if isinstance(node.body, Nodes.StatListNode): 705 node.body.stats = loop_body + node.body.stats 706 else: 707 loop_body.append(node.body) 708 node.body = Nodes.StatListNode( 709 node.body.pos, 710 stats = loop_body) 711 712 node.target = iterable_target 713 node.item = node.item.coerce_to(iterable_target.type, self.current_env()) 714 node.iterator.sequence = args[0] 715 716 # recurse into loop to check for further optimisations 717 return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence)) 718 719 def _find_for_from_node_relations(self, neg_step_value, reversed): 720 if reversed: 721 if neg_step_value: 722 return '<', '<=' 723 else: 724 return '>', '>=' 725 else: 726 if neg_step_value: 727 return '>=', '>' 728 else: 729 return '<=', '<' 730 731 def _transform_range_iteration(self, node, range_function, reversed=False): 732 args = range_function.arg_tuple.args 733 if len(args) < 3: 734 
step_pos = range_function.pos 735 step_value = 1 736 step = ExprNodes.IntNode(step_pos, value='1', constant_result=1) 737 else: 738 step = args[2] 739 step_pos = step.pos 740 if not isinstance(step.constant_result, _py_int_types): 741 # cannot determine step direction 742 return node 743 step_value = step.constant_result 744 if step_value == 0: 745 # will lead to an error elsewhere 746 return node 747 step = ExprNodes.IntNode(step_pos, value=str(step_value), 748 constant_result=step_value) 749 750 if len(args) == 1: 751 bound1 = ExprNodes.IntNode(range_function.pos, value='0', 752 constant_result=0) 753 bound2 = args[0].coerce_to_integer(self.current_env()) 754 else: 755 bound1 = args[0].coerce_to_integer(self.current_env()) 756 bound2 = args[1].coerce_to_integer(self.current_env()) 757 758 relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed) 759 760 bound2_ref_node = None 761 if reversed: 762 bound1, bound2 = bound2, bound1 763 abs_step = abs(step_value) 764 if abs_step != 1: 765 if (isinstance(bound1.constant_result, _py_int_types) and 766 isinstance(bound2.constant_result, _py_int_types)): 767 # calculate final bounds now 768 if step_value < 0: 769 begin_value = bound2.constant_result 770 end_value = bound1.constant_result 771 bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1 772 else: 773 begin_value = bound1.constant_result 774 end_value = bound2.constant_result 775 bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1 776 777 bound1 = ExprNodes.IntNode( 778 bound1.pos, value=str(bound1_value), constant_result=bound1_value, 779 type=PyrexTypes.spanning_type(bound1.type, bound2.type)) 780 else: 781 # evaluate the same expression as above at runtime 782 bound2_ref_node = UtilNodes.LetRefNode(bound2) 783 bound1 = self._build_range_step_calculation( 784 bound1, bound2_ref_node, step, step_value) 785 786 if step_value < 0: 787 step_value = -step_value 788 step.value 
= str(step_value) 789 step.constant_result = step_value 790 step = step.coerce_to_integer(self.current_env()) 791 792 if not bound2.is_literal: 793 # stop bound must be immutable => keep it in a temp var 794 bound2_is_temp = True 795 bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2) 796 else: 797 bound2_is_temp = False 798 799 for_node = Nodes.ForFromStatNode( 800 node.pos, 801 target=node.target, 802 bound1=bound1, relation1=relation1, 803 relation2=relation2, bound2=bound2, 804 step=step, body=node.body, 805 else_clause=node.else_clause, 806 from_range=True) 807 for_node.set_up_loop(self.current_env()) 808 809 if bound2_is_temp: 810 for_node = UtilNodes.LetNode(bound2, for_node) 811 812 return for_node 813 814 def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value): 815 abs_step = abs(step_value) 816 spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type) 817 if step.type.is_int and abs_step < 0x7FFF: 818 # Avoid loss of integer precision warnings. 
819 spanning_step_type = PyrexTypes.spanning_type(spanning_type, PyrexTypes.c_int_type) 820 else: 821 spanning_step_type = PyrexTypes.spanning_type(spanning_type, step.type) 822 if step_value < 0: 823 begin_value = bound2_ref_node 824 end_value = bound1 825 final_op = '-' 826 else: 827 begin_value = bound1 828 end_value = bound2_ref_node 829 final_op = '+' 830 831 step_calculation_node = ExprNodes.binop_node( 832 bound1.pos, 833 operand1=ExprNodes.binop_node( 834 bound1.pos, 835 operand1=bound2_ref_node, 836 operator=final_op, # +/- 837 operand2=ExprNodes.MulNode( 838 bound1.pos, 839 operand1=ExprNodes.IntNode( 840 bound1.pos, 841 value=str(abs_step), 842 constant_result=abs_step, 843 type=spanning_step_type), 844 operator='*', 845 operand2=ExprNodes.DivNode( 846 bound1.pos, 847 operand1=ExprNodes.SubNode( 848 bound1.pos, 849 operand1=ExprNodes.SubNode( 850 bound1.pos, 851 operand1=begin_value, 852 operator='-', 853 operand2=end_value, 854 type=spanning_type), 855 operator='-', 856 operand2=ExprNodes.IntNode( 857 bound1.pos, 858 value='1', 859 constant_result=1), 860 type=spanning_step_type), 861 operator='//', 862 operand2=ExprNodes.IntNode( 863 bound1.pos, 864 value=str(abs_step), 865 constant_result=abs_step, 866 type=spanning_step_type), 867 type=spanning_step_type), 868 type=spanning_step_type), 869 type=spanning_step_type), 870 operator=final_op, # +/- 871 operand2=ExprNodes.IntNode( 872 bound1.pos, 873 value='1', 874 constant_result=1), 875 type=spanning_type) 876 return step_calculation_node 877 878 def _transform_dict_iteration(self, node, dict_obj, method, keys, values): 879 temps = [] 880 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) 881 temps.append(temp) 882 dict_temp = temp.ref(dict_obj.pos) 883 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 884 temps.append(temp) 885 pos_temp = temp.ref(node.pos) 886 887 key_target = value_target = tuple_target = None 888 if keys and values: 889 if node.target.is_sequence_constructor: 890 if 
len(node.target.args) == 2: 891 key_target, value_target = node.target.args 892 else: 893 # unusual case that may or may not lead to an error 894 return node 895 else: 896 tuple_target = node.target 897 elif keys: 898 key_target = node.target 899 else: 900 value_target = node.target 901 902 if isinstance(node.body, Nodes.StatListNode): 903 body = node.body 904 else: 905 body = Nodes.StatListNode(pos = node.body.pos, 906 stats = [node.body]) 907 908 # keep original length to guard against dict modification 909 dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 910 temps.append(dict_len_temp) 911 dict_len_temp_addr = ExprNodes.AmpersandNode( 912 node.pos, operand=dict_len_temp.ref(dict_obj.pos), 913 type=PyrexTypes.c_ptr_type(dict_len_temp.type)) 914 temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) 915 temps.append(temp) 916 is_dict_temp = temp.ref(node.pos) 917 is_dict_temp_addr = ExprNodes.AmpersandNode( 918 node.pos, operand=is_dict_temp, 919 type=PyrexTypes.c_ptr_type(temp.type)) 920 921 iter_next_node = Nodes.DictIterationNextNode( 922 dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, 923 key_target, value_target, tuple_target, 924 is_dict_temp) 925 iter_next_node = iter_next_node.analyse_expressions(self.current_env()) 926 body.stats[0:0] = [iter_next_node] 927 928 if method: 929 method_node = ExprNodes.StringNode( 930 dict_obj.pos, is_identifier=True, value=method) 931 dict_obj = dict_obj.as_none_safe_node( 932 "'NoneType' object has no attribute '%{0}s'".format('.30' if len(method) <= 30 else ''), 933 error = "PyExc_AttributeError", 934 format_args = [method]) 935 else: 936 method_node = ExprNodes.NullNode(dict_obj.pos) 937 dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable") 938 939 def flag_node(value): 940 value = value and 1 or 0 941 return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value) 942 943 result_code = [ 944 Nodes.SingleAssignmentNode( 945 node.pos, 946 lhs = pos_temp, 947 rhs = 
# C signature of the __Pyx_dict_iterator() helper:
# (dict, is_dict, method_name, &orig_length, &is_dict) -> object
PyDict_Iterator_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None),
        PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
        PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None),
    ])

# C signature of the __Pyx_set_iterator() helper:
# (set, is_set, &orig_length, &is_set) -> object
PySet_Iterator_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("set", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("is_set", PyrexTypes.c_int_type, None),
        PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
        PyrexTypes.CFuncTypeArg("p_is_set", PyrexTypes.c_int_ptr_type, None),
    ])

def _transform_set_iteration(self, node, set_obj):
    """Rewrite a for-loop over 'set_obj' into a while-loop driven by
    __Pyx_set_iterator() and a SetIterationNextNode.

    'node' is the loop statement being replaced; its target, body and
    else clause are reused.  Returns a TempsBlockNode that declares the
    four temps used by the loop and wraps the generated statements.
    """
    loop_pos = node.pos
    set_pos = set_obj.pos

    # Temps are registered in this order: iterated object, position,
    # original length, "is really a set" flag.
    set_handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
    pos_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
    # keep original length to guard against set modification
    len_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
    flag_handle = UtilNodes.TempHandle(PyrexTypes.c_int_type)
    temps = [set_handle, pos_handle, len_handle, flag_handle]

    set_temp = set_handle.ref(set_pos)
    pos_temp = pos_handle.ref(loop_pos)

    body = node.body
    if not isinstance(body, Nodes.StatListNode):
        body = Nodes.StatListNode(pos=node.body.pos, stats=[node.body])

    set_len_temp_addr = ExprNodes.AmpersandNode(
        loop_pos, operand=len_handle.ref(set_pos),
        type=PyrexTypes.c_ptr_type(len_handle.type))
    is_set_temp = flag_handle.ref(loop_pos)
    is_set_temp_addr = ExprNodes.AmpersandNode(
        loop_pos, operand=is_set_temp,
        type=PyrexTypes.c_ptr_type(flag_handle.type))

    # The "next item" step becomes the first statement of the loop body.
    next_step = Nodes.SetIterationNextNode(
        set_temp, len_handle.ref(set_pos), pos_temp, node.target, is_set_temp)
    next_step = next_step.analyse_expressions(self.current_env())
    body.stats.insert(0, next_step)

    # Compile-time knowledge about the argument type, passed as a C int.
    known_set = 1 if set_obj.type is Builtin.set_type else 0
    is_set_flag = ExprNodes.IntNode(
        loop_pos, value=str(known_set), constant_result=known_set)

    init_position = Nodes.SingleAssignmentNode(
        loop_pos,
        lhs=pos_temp,
        rhs=ExprNodes.IntNode(loop_pos, value='0', constant_result=0))
    init_iterator = Nodes.SingleAssignmentNode(
        set_pos,
        lhs=set_temp,
        rhs=ExprNodes.PythonCapiCallNode(
            set_pos,
            "__Pyx_set_iterator",
            self.PySet_Iterator_func_type,
            utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
            args=[set_obj, is_set_flag, set_len_temp_addr, is_set_temp_addr],
            is_temp=True,
        ))
    loop = Nodes.WhileStatNode(
        loop_pos,
        condition=None,
        body=body,
        else_clause=node.else_clause,
    )

    return UtilNodes.TempsBlockNode(
        loop_pos, temps=temps,
        body=Nodes.StatListNode(
            loop_pos,
            stats=[init_position, init_iterator, loop]))
class SwitchTransform(Visitor.EnvTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    # Result of the extract_*() methods when a condition cannot be used.
    NO_MATCH = (None, None, None)

    def extract_conditions(self, cond, allow_not_in):
        """Split a condition into (not_in, tested_variable, [case values]).

        Returns NO_MATCH if 'cond' is not a comparison (or an 'or'/'and'
        chain of comparisons) of one common value against literals or
        const entries.  'allow_not_in' enables the negated forms
        ('!=', 'not in', and-chains).
        """
        # Strip wrapper nodes until we reach the actual comparison.
        while True:
            if isinstance(cond, (ExprNodes.CoerceToTempNode,
                                 ExprNodes.CoerceToBooleanNode)):
                cond = cond.arg
            elif isinstance(cond, ExprNodes.BoolBinopResultNode):
                cond = cond.arg.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                return self.NO_MATCH
            elif cond.is_c_string_contains() and \
                    isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                not_in = cond.operator == 'not_in'
                if not_in and not allow_not_in:
                    return self.NO_MATCH
                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                        cond.operand2.contains_surrogates():
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
            elif not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                            and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                            and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                # Only an 'and' chain may combine negated comparisons.
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH

    def extract_in_string_conditions(self, string_literal):
        """Return one literal node per distinct character of a string
        literal, in sorted order (IntNodes for unicode, CharNodes for bytes).
        """
        if isinstance(string_literal, ExprNodes.UnicodeNode):
            charvals = sorted(map(ord, set(string_literal.value)))
            return [ExprNodes.IntNode(string_literal.pos, value=str(charval),
                                      constant_result=charval)
                    for charval in charvals]
        else:
            # this is a bit tricky as Py3's bytes type returns
            # integers on iteration, whereas Py2 returns 1-char byte
            # strings
            characters = string_literal.value
            characters = sorted(set(
                characters[i:i+1] for i in range(len(characters))))
            return [ExprNodes.CharNode(string_literal.pos, value=charval,
                                       constant_result=charval)
                    for charval in characters]

    def extract_common_conditions(self, common_var, condition, allow_not_in):
        """Like extract_conditions(), but additionally require that the
        tested value matches 'common_var' (if given) and that the tested
        value and all case values have an int or enum type.
        """
        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
        if var is None:
            return self.NO_MATCH
        elif common_var is not None and not is_common_value(var, common_var):
            return self.NO_MATCH
        elif not (var.type.is_int or var.type.is_enum) or any(
                not (cond.type.is_int or cond.type.is_enum) for cond in conditions):
            return self.NO_MATCH
        return not_in, var, conditions

    def has_duplicate_values(self, condition_values):
        """Return True if two case values are (or may be) the same."""
        # duplicated values don't work in a switch statement
        seen = set()
        for value in condition_values:
            if value.has_constant_result():
                if value.constant_result in seen:
                    return True
                seen.add(value.constant_result)
            else:
                # this isn't completely safe as we don't know the
                # final C value, but this is about the best we can do
                try:
                    if value.entry.cname in seen:
                        return True
                except AttributeError:
                    return True  # play safe
                seen.add(value.entry.cname)
        return False

    def visit_IfStatNode(self, node):
        """Replace an if/elif chain by a SwitchStatNode when every clause
        tests the same int/enum value against distinct constants.
        """
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        common_var = None
        cases = []
        for if_clause in node.if_clauses:
            _, common_var, conditions = self.extract_common_conditions(
                common_var, if_clause.condition, False)
            if common_var is None:
                self.visitchildren(node)
                return node
            cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos,
                                              conditions=conditions,
                                              body=if_clause.body))

        condition_values = [
            cond for case in cases for cond in case.conditions]
        if len(condition_values) < 2 or self.has_duplicate_values(condition_values):
            # not worth a switch, or not expressible as one
            self.visitchildren(node)
            return node

        # Recurse into body subtrees that we left untouched so far.
        self.visitchildren(node, 'else_clause')
        for case in cases:
            self.visitchildren(case, 'body')

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos=node.pos,
                                           test=common_var,
                                           cases=cases,
                                           else_clause=node.else_clause)
        return switch_node

    def visit_CondExprNode(self, node):
        """Turn 'a if x in (...) else b' style expressions into a switch."""
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node.test, True)
        if common_var is None \
                or len(conditions) < 2 \
                or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            node.true_val, node.false_val)

    def visit_BoolBinopNode(self, node):
        """Turn an 'or'/'and' chain of comparisons into a switch that
        yields True or False.
        """
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
                or len(conditions) < 2 \
                or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            node.wrap_operands(self.current_env())  # in case we changed the operands
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def visit_PrimaryCmpNode(self, node):
        """Turn a single comparison with several case values (e.g. a C
        string containment test) into a switch that yields True or False.
        """
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
                or len(conditions) < 2 \
                or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def build_simple_switch_statement(self, node, common_var, conditions,
                                      not_in, true_val, false_val):
        """Build a one-case switch that assigns 'true_val' or 'false_val'
        (swapped when 'not_in' is set) to a result reference standing in
        for 'node', and return it wrapped as an expression.
        """
        result_ref = UtilNodes.ResultRefNode(node)
        true_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=result_ref,
            rhs=true_val.coerce_to(node.type, self.current_env()),
            first=True)
        false_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=result_ref,
            rhs=false_val.coerce_to(node.type, self.current_env()),
            first=True)

        if not_in:
            true_body, false_body = false_body, true_body

        cases = [Nodes.SwitchCaseNode(pos=node.pos,
                                      conditions=conditions,
                                      body=true_body)]

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos=node.pos,
                                           test=common_var,
                                           cases=cases,
                                           else_clause=false_body)
        replacement = UtilNodes.TempResultFromStatNode(result_ref, switch_node)
        return replacement

    def visit_EvalWithTempExprNode(self, node):
        """Drop the temp wrapper if the switch rewrite made it unused."""
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        # drop unused expression temp from FlattenInListTransform
        orig_expr = node.subexpression
        temp_ref = node.lazy_temp
        self.visitchildren(node)
        if node.subexpression is not orig_expr:
            # node was restructured => check if temp is still used
            if not Visitor.tree_contains(node.subexpression, temp_ref):
                return node.subexpression
        return node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Rewrites "x in [val1, ..., valn]" (and the "not in" form) as a chain
    of individual comparisons against each value.
    """

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        # 'in' becomes an or-chain of '==', 'not in' an and-chain of '!='.
        if node.operator == 'in':
            joiner, comparison = 'or', '=='
        elif node.operator == 'not_in':
            joiner, comparison = 'and', '!='
        else:
            return node

        container = node.operand2
        if not isinstance(container, (ExprNodes.TupleNode,
                                      ExprNodes.ListNode,
                                      ExprNodes.SetNode)):
            return node

        items = container.args
        if len(items) == 0:
            # note: lhs may have side effects
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        comparisons = []
        let_refs = []
        for item in items:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  However, since is_simple() is meant to
                # be called after type analysis, we ignore any errors
                # and just play safe in that case.
                item_is_simple = item.is_simple()
            except Exception:
                item_is_simple = False
            if not item_is_simple:
                # must evaluate all non-simple RHS before doing the comparisons
                item = UtilNodes.LetRefNode(item)
                let_refs.append(item)
            single_cmp = ExprNodes.PrimaryCmpNode(
                pos=node.pos,
                operand1=lhs,
                operator=comparison,
                operand2=item,
                cascade=None)
            comparisons.append(ExprNodes.TypecastNode(
                pos=node.pos,
                operand=single_cmp,
                type=PyrexTypes.c_bint_type))

        # Left fold of the comparisons into nested boolean operator nodes.
        condition = comparisons[0]
        for next_cmp in comparisons[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos=node.pos,
                operator=joiner,
                operand1=condition,
                operand2=next_cmp)

        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # Wrap innermost-last so that the first temp ends up outermost.
        for let_ref in reversed(let_refs):
            new_node = UtilNodes.EvalWithTempExprNode(let_ref, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        Always returns 'node'; on success the involved name nodes and
        temps get 'use_managed_ref' switched off.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        coerced_temps = []

        for stat in node.stats:
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                # cascaded assignments (FIXME) and anything else: give up
                return node
            if not self._extract_operand(stat.lhs, lhs_names,
                                         lhs_indices, coerced_temps):
                return node
            if not self._extract_operand(stat.rhs, rhs_names,
                                         rhs_indices, coerced_temps):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = [path for path, _ in lhs_names]
            rhs_paths = [path for path, _ in rhs_names]
            if set(lhs_paths) != set(rhs_paths):
                return node
            if len(set(lhs_paths)) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            collected = []
            for group in (lhs_indices, rhs_indices):
                ids = []
                for subscript in group:
                    index_id = self._extract_index_id(subscript)
                    if not index_id:
                        return node
                    ids.append(index_id)
                collected.append(ids)
            lhs_ids, rhs_ids = collected

            if set(lhs_ids) != set(rhs_ids):
                return node
            if len(set(lhs_ids)) != len(rhs_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [temp.arg for temp in coerced_temps]
        for temp in coerced_temps:
            temp.use_managed_ref = False

        for _, name_node in lhs_names + rhs_names:
            # names wrapped in a temp were already handled via the temp
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for subscript in lhs_indices + rhs_indices:
            subscript.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand.

        Appends (dotted_path, node) to 'names' or the subscript node to
        'indices'; a CoerceToTempNode wrapper is recorded in 'temps'.
        Returns False if the operand is not supported.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # Walk down a C attribute chain, collecting names innermost-last.
        reversed_path = []
        current = node
        while current.is_attribute:
            if current.is_py_attr:
                return False
            reversed_path.append(current.member)
            current = current.obj

        if current.is_name:
            reversed_path.append(current.name)
            names.append(('.'.join(reversed(reversed_path)), node))
            return True
        if node.is_subscript:
            if (node.base.type != Builtin.list_type
                    or not node.index.type.is_int
                    or not node.base.is_name):
                return False
            indices.append(node)
            return True
        return False

    def _extract_index_id(self, index_node):
        """Return (base name, index name) for a name-indexed subscript,
        or None for any other kind of index.
        """
        base = index_node.base
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices are not handled yet
            return None
        return (base.name, index.name)
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.

    Handlers are looked up by name: '_handle_simple_function_NAME' for
    calls without keyword arguments, '_handle_general_function_NAME'
    otherwise.  Each handler returns the replacement node, or the
    original call node if it does not apply.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        """Dispatch calls of builtin names without keyword arguments."""
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        return self._dispatch_to_handler(node, function, node.args)

    def visit_GeneralCallNode(self, node):
        """Dispatch calls of builtin names that have a literal positional
        argument tuple plus keyword arguments.
        """
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, node.keyword_args)

    def _function_is_builtin_name(self, function):
        """Return True if 'function' is a name that resolves to the same
        entry as in the builtin scope (i.e. it is not shadowed).
        """
        if not function.is_name:
            return False
        env = self.current_env()
        entry = env.lookup(function.name)
        if entry is not env.builtin_scope().lookup_here(function.name):
            return False
        # if entry is None, it's at least an undeclared name, so likely builtin
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call the handler method matching the function name, if any."""
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the called function of 'node' by a C-API function node."""
        node.function = ExprNodes.PythonCapiFunctionNode(
            node.function.pos, node.function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a compile error about a wrong number of call arguments.

        'expected' may be None, a number or a descriptive string.
        """
        if not expected:  # None or 0
            arg_str = ''
        elif isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        else:
            expected_str = ''
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_float(self, node, pos_args):
        """Replace float() by 0.0 and drop float(x) for known float x."""
        if not pos_args:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
            # Do not rewrite an invalid call: replacing it by its first
            # argument would silently drop the surplus arguments.
            return node
        arg_type = getattr(pos_args[0], 'type', None)
        if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
            return pos_args[0]
        return node

    def _handle_simple_function_slice(self, node, pos_args):
        """Replace slice(...) with 1 to 3 arguments by a SliceNode."""
        arg_count = len(pos_args)
        start = step = None
        if arg_count == 1:
            stop, = pos_args
        elif arg_count == 2:
            start, stop = pos_args
        elif arg_count == 3:
            start, stop, step = pos_args
        else:
            self._error_wrong_arg_count('slice', node, pos_args)
            return node
        return ExprNodes.SliceNode(
            node.pos,
            start=start or ExprNodes.NoneNode(node.pos),
            stop=stop,
            step=step or ExprNodes.NoneNode(node.pos))

    def _handle_simple_function_ord(self, node, pos_args):
        """Unpack ord('X').
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
            if len(arg.value) == 1:
                return ExprNodes.IntNode(
                    arg.pos, type=PyrexTypes.c_long_type,
                    value=str(ord(arg.value)),
                    constant_result=ord(arg.value)
                )
        elif isinstance(arg, ExprNodes.StringNode):
            if arg.unicode_value and len(arg.unicode_value) == 1 \
                    and ord(arg.unicode_value) <= 255:  # Py2/3 portability
                return ExprNodes.IntNode(
                    arg.pos, type=PyrexTypes.c_int_type,
                    value=str(ord(arg.unicode_value)),
                    constant_result=ord(arg.unicode_value)
                )
        return node

    # sequence processing

    def _handle_simple_function_all(self, node, pos_args):
        """Transform

        _result = all(p(x) for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if not p(x):
                    return False
        else:
            return True
        """
        return self._transform_any_all(node, pos_args, False)

    def _handle_simple_function_any(self, node, pos_args):
        """Transform

        _result = any(p(x) for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if p(x):
                    return True
        else:
            return False
        """
        return self._transform_any_all(node, pos_args, True)

    def _transform_any_all(self, node, pos_args, is_any):
        """Shared implementation of the any()/all() genexpr inlining."""
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        generator_body = gen_expr_node.def_node.gbody
        loop_node = generator_body.body
        yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression)

        # Early exit: return the deciding value as soon as it is found.
        test_node = Nodes.IfStatNode(
            yield_expression.pos, else_clause=None, if_clauses=[
                Nodes.IfClauseNode(
                    yield_expression.pos,
                    condition=condition,
                    body=Nodes.ReturnStatNode(
                        node.pos,
                        value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any))
                )]
        )
        # Loop exhausted without early exit: return the opposite value.
        loop_node.else_clause = Nodes.ReturnStatNode(
            node.pos,
            value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any))

        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node)

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all')

    PySequence_List_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) and sorted([listcomp]) into
        [listcomp].sort().  CPython just reads the iterable into a
        list and calls .sort() on it.  Expanding the iterable in a
        listcomp is still faster and the result can be sorted in
        place.
        """
        if len(pos_args) != 1:
            return node

        arg = pos_args[0]
        if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
            list_node = pos_args[0]
            loop_node = list_node.loop

        elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
            gen_expr_node = arg
            loop_node = gen_expr_node.loop
            yield_statements = _find_yield_statements(loop_node)
            if not yield_statements:
                return node

            list_node = ExprNodes.InlinedGeneratorExpressionNode(
                node.pos, gen_expr_node, orig_func='sorted',
                comprehension_type=Builtin.list_type)

            for yield_expression, yield_stat_node in yield_statements:
                append_node = ExprNodes.ComprehensionAppendNode(
                    yield_expression.pos,
                    expr=yield_expression,
                    target=list_node.target)
                Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        elif arg.is_sequence_constructor:
            # sorted([a, b, c]) or sorted((a, b, c)).  The result is always a list,
            # so starting off with a fresh one is more efficient.
            list_node = loop_node = arg.as_list()

        else:
            # Interestingly, PySequence_List works on a lot of non-sequence
            # things as well.
            list_node = loop_node = ExprNodes.PythonCapiCallNode(
                node.pos, "PySequence_List", self.PySequence_List_func_type,
                args=pos_args, is_temp=True)

        result_node = UtilNodes.ResultRefNode(
            pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False)
        list_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs=result_node, rhs=list_node, first=True)

        sort_method = ExprNodes.AttributeNode(
            node.pos, obj=result_node, attribute=EncodedString('sort'),
            # entry ? type ?
            needs_none_check=False)
        sort_node = Nodes.ExprStatNode(
            node.pos, expr=ExprNodes.SimpleCallNode(
                node.pos, function=sort_method, args=[]))

        sort_node.analyse_declarations(self.current_env())

        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node]))

    # NOTE: the double underscore gets name-mangled inside the class, so
    # _dispatch_to_handler() never finds this method.  Together with the
    # FIXME below it is effectively disabled.
    def __handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.
        """
        if len(pos_args) not in (1,2):
            return node
        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                        ExprNodes.ComprehensionNode)):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
            # FIXME: currently nonfunctional
            yield_expression = None
            if yield_expression is None:
                return node
        else:  # ComprehensionNode
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                if not yield_expression.is_literal or not yield_expression.type.is_int:
                    return node
            except AttributeError:
                return node  # in case we don't have a type yet
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr

        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node)

        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
            has_local_scope = gen_expr_node.has_local_scope)

    def _handle_simple_function_min(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '<')

    def _handle_simple_function_max(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '>')

    def _optimise_min_max(self, node, args, operator):
        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
        """
        if len(args) <= 1:
            if len(args) == 1 and args[0].is_sequence_constructor:
                # min([a, b, ...]): compare the literal's items directly
                args = args[0].args
            if len(args) <= 1:
                # leave this to Python
                return node

        cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))

        last_result = args[0]
        for arg_node in cascaded_nodes:
            result_ref = UtilNodes.ResultRefNode(last_result)
            last_result = ExprNodes.CondExprNode(
                arg_node.pos,
                true_val = arg_node,
                false_val = result_ref,
                test = ExprNodes.PrimaryCmpNode(
                    arg_node.pos,
                    operand1 = arg_node,
                    operator = operator,
                    operand2 = result_ref,
                    )
                )
            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)

        for ref_node in cascaded_nodes[::-1]:
            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)

        return last_result

    # builtin type creation

    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
        if not pos_args:
            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
        # This is a bit special - for iterables (including genexps),
        # Python actually overallocates and resizes a newly created
        # tuple incrementally while reading items, which we can't
        # easily do without explicit node support. Instead, we read
        # the items into a list and then copy them into a tuple of the
        # final size.  This takes up to twice as much memory, but will
        # have to do until we have real support for genexps.
        result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
        if result is not node:
            return ExprNodes.AsTupleNode(node.pos, arg=result)
        return node

    def _handle_simple_function_frozenset(self, node, pos_args):
        """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
        """
        if len(pos_args) != 1:
            return node
        # NOTE: both branches modify the argument list in place.
        if pos_args[0].is_sequence_constructor and not pos_args[0].args:
            # frozenset(()) etc. => frozenset()
            del pos_args[0]
        elif isinstance(pos_args[0], ExprNodes.ListNode):
            pos_args[0] = pos_args[0].as_tuple()
        return node

    def _handle_simple_function_list(self, node, pos_args):
        if not pos_args:
            return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
        return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)

    def _handle_simple_function_set(self, node, pos_args):
        if not pos_args:
            return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
        return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)

    def _transform_list_set_genexpr(self, node, pos_args, target_type):
        """Replace set(genexpr) and list(genexpr) by an inlined comprehension.
        """
        # Exactly one argument is required; this also guards the
        # pos_args[0] access below against an empty argument list.
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_statements = _find_yield_statements(loop_node)
        if not yield_statements:
            return node

        result_node = ExprNodes.InlinedGeneratorExpressionNode(
            node.pos, gen_expr_node,
            orig_func='set' if target_type is Builtin.set_type else 'list',
            comprehension_type=target_type)

        for yield_expression, yield_stat_node in yield_statements:
            append_node = ExprNodes.ComprehensionAppendNode(
                yield_expression.pos,
                expr=yield_expression,
                target=result_node.target)
            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        return result_node

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by an inlined { a:b for ... }
        """
        if len(pos_args) == 0:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_statements = _find_yield_statements(loop_node)
        if not yield_statements:
            return node

        # Every yielded item must be a literal 2-tuple (key, value).
        for yield_expression, _ in yield_statements:
            if not isinstance(yield_expression, ExprNodes.TupleNode):
                return node
            if len(yield_expression.args) != 2:
                return node

        result_node = ExprNodes.InlinedGeneratorExpressionNode(
            node.pos, gen_expr_node, orig_func='dict',
            comprehension_type=Builtin.dict_type)

        for yield_expression, yield_stat_node in yield_statements:
            append_node = ExprNodes.DictComprehensionAppendNode(
                yield_expression.pos,
                key_expr=yield_expression.args[0],
                value_expr=yield_expression.args[1],
                target=result_node.target)
            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        return result_node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        return kwargs
class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Inline calls to names that are bound to a single Python function node.

    A SimpleCallNode whose callee is a name with exactly one recorded
    assignment of a PyCFunctionNode is replaced by an
    InlinedDefNodeCallNode, provided that the
    'optimize.inline_defnode_calls' directive is enabled and the new node
    reports that it can be inlined.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the right-hand side of the only assignment to 'name_node'.

        Returns None if the name has no control flow state, if that state
        is flagged as null, if the name is unknown in the current scope,
        or if it does not have exactly one recorded assignment.
        """
        cf_state = name_node.cf_state
        if cf_state is None or cf_state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Replace 'node' by an inlined call if its callee qualifies."""
        self.visitchildren(node)
        if self.current_directives.get('optimize.inline_defnode_calls'):
            callee_name = node.function
            if callee_name.is_name:
                callee = self.get_constant_value_node(callee_name)
                if isinstance(callee, ExprNodes.PyCFunctionNode):
                    candidate = ExprNodes.InlinedDefNodeCallNode(
                        node.pos, function_name=callee_name,
                        function=callee, args=node.args)
                    if candidate.can_be_inlined():
                        return self.replace(node, candidate)
        return node
def visit_ExprStatNode(self, node):
    """
    Drop dead code and useless coercions.

    A coercion to a Python object around the statement's expression is
    stripped first.  The statement is then removed (None is returned)
    if the expression is missing, is None, is a literal, or is a plain
    reference to a local variable or argument.
    """
    self.visitchildren(node)
    expr = node.expr
    if isinstance(expr, ExprNodes.CoerceToPyTypeNode):
        # The result is discarded anyway, so the coercion is useless.
        expr = node.expr = expr.arg
    if expr is None or expr.is_none or expr.is_literal:
        # Expression was removed or is dead code => remove ExprStatNode as well.
        return None
    if expr.is_name:
        entry = expr.entry
        if entry and (entry.is_local or entry.is_arg):
            # Ignore dead references to local variables etc.
            return None
    return node
def visit_CoerceToPyTypeNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    Only C-API calls named 'float' with a single argument are handled:
    a float-typed argument is returned directly (wrapped in a None
    check), any other Python object argument is passed to
    __Pyx_PyNumber_Float.  Everything else is returned unchanged.
    """
    self.visitchildren(node)
    inner = node.arg
    if isinstance(inner, ExprNodes.CoerceFromPyTypeNode):
        inner = inner.arg
    if not isinstance(inner, ExprNodes.PythonCapiCallNode):
        return node
    if inner.function.name != 'float' or len(inner.args) != 1:
        return node

    # undo redundant Py->C->Py coercion
    func_arg = inner.args[0]
    if func_arg.type is Builtin.float_type:
        return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'")
    if func_arg.type.is_pyobject:
        float_call = ExprNodes.PythonCapiCallNode(
            node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
            args=[func_arg],
            py_name='float',
            is_temp=node.is_temp,
            result_is_used=node.result_is_used,
        )
        return float_call.coerce_to(node.type, self.current_env())
    return node
def _optimise_int_indexing(self, coerce_node, arg, index_node):
    """Turn an integer index into a 'bytes' object into a direct C call.

    'coerce_node' is the coercion node being optimised, 'arg' the
    subscript node it wraps and 'index_node' the index expression.

    Returns a call to __Pyx_PyBytes_GetItemInt (coerced to the target
    type if that is not C 'char') when the subscripted base is typed as
    'bytes' and the coercion target is C 'char' or 'unsigned char'.
    In all other cases, 'coerce_node' is returned unchanged.
    """
    env = self.current_env()
    # The C helper takes the 'boundscheck' directive as a plain 0/1 flag.
    bound_check_bool = 1 if env.directives['boundscheck'] else 0
    if arg.base.type is not Builtin.bytes_type:
        return coerce_node
    if coerce_node.type not in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
        return coerce_node

    # bytes[index] -> char
    bound_check_node = ExprNodes.IntNode(
        coerce_node.pos, value=str(bound_check_bool),
        constant_result=bound_check_bool)
    node = ExprNodes.PythonCapiCallNode(
        coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
        self.PyBytes_GetItemInt_func_type,
        args=[
            arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
            index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
            bound_check_node,
        ],
        is_temp=True,
        utility_code=UtilityCode.load_cached(
            'bytes_index', 'StringTools.c'))
    if coerce_node.type is not PyrexTypes.c_char_type:
        # The helper is declared as returning 'char'; adapt for 'unsigned char'.
        node = node.coerce_to(coerce_node.type, env)
    return node
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with the wrong number of arguments at 'node.pos'.

    'expected' may be None, a number or a descriptive string (such as
    '0 or 1').  It selects the argument placeholder shown in the
    message: nothing for a false value, '...' for a string or a value
    above 1, and 'x' for exactly 1.  Unless it is None, it is also
    quoted in the message as the expected count.
    """
    arg_str = ''
    if expected:  # neither None nor 0
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
    expected_str = '' if expected is None else 'expected %s, ' % expected
    error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args)))
def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method):
    """
    Try to inject an unbound method call for a call to a method of a known builtin type.
    This enables caching the underlying C function of the method at runtime.

    The call is left unchanged for unbound method calls, for calls with
    three or more arguments, when 'function' is not a Python attribute
    lookup, or when the object is not typed as a builtin type.
    """
    arg_count = len(arg_list)
    if is_unbound_method or arg_count >= 3:
        return node
    if not function.is_attribute or not function.is_py_attr:
        return node
    obj_type = function.obj.type
    if not obj_type.is_builtin_type:
        return node
    if obj_type.name in ('basestring', 'type'):
        # these allow different actual types => unsafe
        return node
    return ExprNodes.CachedBuiltinMethodCallNode(
        node, function.obj, attr_name, arg_list)
def _handle_simple_function_list(self, node, function, pos_args):
    """Turn list(ob) into PySequence_List(ob).

    Calls with any other number of positional arguments are returned
    unchanged.
    """
    if len(pos_args) != 1:
        return node
    return ExprNodes.PythonCapiCallNode(
        node.pos, "PySequence_List", self.PySequence_List_func_type,
        args=pos_args, is_temp=node.is_temp)
2396 """ 2397 if len(pos_args) != 1 or not node.is_temp: 2398 return node 2399 arg = pos_args[0] 2400 if arg.type is Builtin.tuple_type and not arg.may_be_none(): 2401 return arg 2402 if arg.type is Builtin.list_type: 2403 pos_args[0] = arg.as_none_safe_node( 2404 "'NoneType' object is not iterable") 2405 2406 return ExprNodes.PythonCapiCallNode( 2407 node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type, 2408 args=pos_args, is_temp=node.is_temp) 2409 else: 2410 return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type) 2411 2412 PySet_New_func_type = PyrexTypes.CFuncType( 2413 Builtin.set_type, [ 2414 PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None) 2415 ]) 2416 2417 def _handle_simple_function_set(self, node, function, pos_args): 2418 if len(pos_args) != 1: 2419 return node 2420 if pos_args[0].is_sequence_constructor: 2421 # We can optimise set([x,y,z]) safely into a set literal, 2422 # but only if we create all items before adding them - 2423 # adding an item may raise an exception if it is not 2424 # hashable, but creating the later items may have 2425 # side-effects. 
def _handle_simple_function_float(self, node, function, pos_args):
    """Transform float() into either a C type cast or a faster C
    function call.

    - float() without arguments becomes the constant 0.0.
    - More than one argument is reported as an error and the call is
      returned unchanged.
    - A C double argument is returned as is, other numeric or
      assignable arguments are type cast to the type of the call node.
    - Anything else is converted by __Pyx_PyObject_AsDouble().
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) == 0:
        # Node constructors take the source position first, not the node.
        return ExprNodes.FloatNode(
            node.pos, value="0.0", constant_result=0.0
        ).coerce_to(Builtin.float_type, self.current_env())
    elif len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
        return node

    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # Look through the coercion to see the underlying C value.
        func_arg = func_arg.arg
    if func_arg.type is PyrexTypes.c_double_type:
        return func_arg
    elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    # Generic fallback: pass the original (uncoerced) arguments.
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args=pos_args,
        is_temp=node.is_temp,
        utility_code=load_c_utility('pyobject_as_double'),
        py_name="float")
2516 """ 2517 if len(pos_args) == 0: 2518 return ExprNodes.IntNode(node.pos, value="0", constant_result=0, 2519 type=PyrexTypes.py_object_type) 2520 elif len(pos_args) != 1: 2521 return node # int(x, base) 2522 func_arg = pos_args[0] 2523 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): 2524 if func_arg.arg.type.is_float: 2525 return ExprNodes.PythonCapiCallNode( 2526 node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type, 2527 args=[func_arg.arg], is_temp=True, py_name='int', 2528 utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c")) 2529 else: 2530 return node # handled in visit_CoerceFromPyTypeNode() 2531 if func_arg.type.is_pyobject and node.type.is_pyobject: 2532 return ExprNodes.PythonCapiCallNode( 2533 node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type, 2534 args=pos_args, is_temp=True, py_name='int') 2535 return node 2536 2537 def _handle_simple_function_bool(self, node, function, pos_args): 2538 """Transform bool(x) into a type coercion to a boolean. 
def _handle_simple_function_len(self, node, function, pos_args):
    """Replace len(char*) by the equivalent call to strlen(),
    len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
    len(known_builtin_type) by an equivalent C-API call.

    Memory view slices are mapped to __Pyx_MemoryView_Len() and a single
    unicode character has the constant length 1.  A wrong argument count
    is reported as an error.  Calls that cannot be optimised are
    returned unchanged.
    """
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node
    arg = pos_args[0]
    if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        arg = arg.arg
    arg_type = arg.type

    # Select the C function and its signature, then build the call once below.
    extra_kwargs = {}
    if arg_type.is_string:
        cfunc_name = "strlen"
        func_type = self.Pyx_strlen_func_type
        extra_kwargs['utility_code'] = UtilityCode.load_cached("IncludeStringH", "StringTools.c")
    elif arg_type.is_pyunicode_ptr:
        cfunc_name = "__Pyx_Py_UNICODE_strlen"
        func_type = self.Pyx_Py_UNICODE_strlen_func_type
    elif arg_type.is_memoryviewslice:
        cfunc_name = "__Pyx_MemoryView_Len"
        func_type = PyrexTypes.CFuncType(
            PyrexTypes.c_size_t_type, [
                PyrexTypes.CFuncTypeArg("memoryviewslice", arg_type, None)
            ], nogil=True)
    elif arg_type.is_pyobject:
        cfunc_name = self._map_to_capi_len_function(arg_type)
        if cfunc_name is None:
            has_pysize = (
                (arg_type.is_extension_type or arg_type.is_builtin_type)
                and arg_type.entry.qualified_name in self._ext_types_with_pysize)
            if not has_pysize:
                return node
            cfunc_name = 'Py_SIZE'
        arg = arg.as_none_safe_node(
            "object of type 'NoneType' has no len()")
        func_type = self.PyObject_Size_func_type
    elif arg_type.is_unicode_char:
        return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                 type=node.type)
    else:
        return node

    new_node = ExprNodes.PythonCapiCallNode(
        node.pos, cfunc_name, func_type,
        args=[arg], is_temp=node.is_temp, **extra_kwargs)
    if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
        new_node = new_node.coerce_to(node.type, self.current_env())
    return new_node
def _handle_simple_function_isinstance(self, node, function, pos_args):
    """Replace isinstance() checks against builtin types by the
    corresponding C-API call.

    The second argument may be a single type or a tuple of types.  One
    C-level test is generated per type (duplicate builtin type checks
    are emitted only once) and the tests are joined with 'or'.  Calls
    with another argument count, with an empty tuple, or with a second
    argument that is neither a tuple nor typed as 'type' are returned
    unchanged.
    """
    if len(pos_args) != 2:
        return node
    arg, types = pos_args
    temps = []
    if isinstance(types, ExprNodes.TupleNode):
        types = types.args
        if not types:
            # isinstance(x, ()): there is nothing to test against, and
            # reduce() below would fail on an empty sequence.
            return node
        if len(types) == 1 and types[0].type is not Builtin.type_type:
            return node  # nothing to improve here
        if arg.is_attribute or not arg.is_simple():
            # The object is used by several tests => evaluate it only once.
            arg = UtilNodes.ResultRefNode(arg)
            temps.append(arg)
    elif types.type is Builtin.type_type:
        types = [types]
    else:
        return node

    tests = []
    test_nodes = []
    env = self.current_env()
    for test_type_node in types:
        builtin_type = None
        if test_type_node.is_name:
            if test_type_node.entry:
                entry = env.lookup(test_type_node.entry.name)
                if entry and entry.type and entry.type.is_builtin_type:
                    builtin_type = entry.type
        if builtin_type is Builtin.type_type:
            # all types have type "type", but there's only one 'type'
            if entry.name != 'type' or not (
                    entry.scope and entry.scope.is_builtin_scope):
                builtin_type = None
        if builtin_type is not None:
            type_check_function = entry.type.type_check_function(exact=False)
            if type_check_function in tests:
                # The same check was already generated for an earlier type.
                continue
            tests.append(type_check_function)
            type_check_args = [arg]
        elif test_type_node.type is Builtin.type_type:
            type_check_function = '__Pyx_TypeCheck'
            type_check_args = [arg, test_type_node]
        else:
            if not test_type_node.is_literal:
                test_type_node = UtilNodes.ResultRefNode(test_type_node)
                temps.append(test_type_node)
            type_check_function = 'PyObject_IsInstance'
            type_check_args = [arg, test_type_node]
        test_nodes.append(
            ExprNodes.PythonCapiCallNode(
                test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                args=type_check_args,
                is_temp=True,
            ))

    def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
        or_node = make_binop_node(node.pos, 'or', a, b)
        or_node.type = PyrexTypes.c_bint_type
        or_node.wrap_operands(env)
        return or_node

    test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
    # Wrap in reverse order so that the first temp becomes the outermost node.
    for temp in temps[::-1]:
        test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
    return test_node
def _handle_any_slot__new__(self, node, function, args,
                            is_unbound_method, kwargs=None):
    """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

    Only unbound calls with at least one argument are handled, where both
    the object that '__new__' is looked up on and the first argument are
    plain names typed as 'type' that refer to the same type.  All other
    calls are returned unchanged.

    For an extension type of the current module that has a tp_new slot
    function, that function is called directly.  Otherwise, the generic
    helpers __Pyx_tp_new() / __Pyx_tp_new_kwargs() are used.
    """
    obj = function.obj
    if not is_unbound_method or len(args) < 1:
        return node
    type_arg = args[0]
    if not obj.is_name or not type_arg.is_name:
        # play safe
        return node
    if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node
    if not type_arg.type_entry or not obj.type_entry:
        # At least one type entry is missing => compare by name instead.
        if obj.name != type_arg.name:
            return node
        # otherwise, we know it's a type and we know it's the same
        # type for both - that should do
    elif type_arg.type_entry != obj.type_entry:
        # different types - may or may not lead to an error at runtime
        return node

    # The remaining arguments (after the type) form the argument tuple.
    args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
    args_tuple = args_tuple.analyse_types(
        self.current_env(), skip_children=True)

    if type_arg.type_entry:
        ext_type = type_arg.type_entry.type
        if (ext_type.is_extension_type and ext_type.typeobj_cname and
                ext_type.scope.global_scope() == self.current_env().global_scope()):
            # known type in current module
            tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
            slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
            if slot_func_cname:
                cython_scope = self.context.cython_scope
                PyTypeObjectPtr = PyrexTypes.CPtrType(
                    cython_scope.lookup('PyTypeObject').type)
                # Signature of the slot function, returning the extension type.
                pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                    ext_type, [
                        PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None),
                        PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None),
                        PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                        ])

                type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                if not kwargs:
                    # The slot function always takes a kwargs argument => pass NULL.
                    kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                return ExprNodes.PythonCapiCallNode(
                    node.pos, slot_func_cname,
                    pyx_tp_new_kwargs_func_type,
                    args=[type_arg, args_tuple, kwargs],
                    may_return_none=False,
                    is_temp=True)
        # No direct slot call possible => fall through to the generic helpers below.
    else:
        # arbitrary variable, needs a None check for safety
        type_arg = type_arg.as_none_safe_node(
            "object.__new__(X): X is not a type object (NoneType)")

    # Generic path: both helpers come from the same 'tp_new' utility code.
    utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
    if kwargs:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
            args=[type_arg, args_tuple, kwargs],
            utility_code=utility_code,
            is_temp=node.is_temp
        )
    else:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args=[type_arg, args_tuple],
            utility_code=utility_code,
            is_temp=node.is_temp
        )
2874 """ 2875 if len(args) != 2: 2876 return node 2877 obj, value = args 2878 if not value.is_sequence_constructor: 2879 return node 2880 items = list(value.args) 2881 if value.mult_factor is not None or len(items) > 8: 2882 # Appending wins for short sequences but slows down when multiple resize operations are needed. 2883 # This seems to be a good enough limit that avoids repeated resizing. 2884 if False and isinstance(value, ExprNodes.ListNode): 2885 # One would expect that tuples are more efficient here, but benchmarking with 2886 # Py3.5 and Py3.7 suggests that they are not. Probably worth revisiting at some point. 2887 # Might be related to the usage of PySequence_FAST() in CPython's list.extend(), 2888 # which is probably tuned more towards lists than tuples (and rightly so). 2889 tuple_node = args[1].as_tuple().analyse_types(self.current_env(), skip_children=True) 2890 Visitor.recursively_replace_node(node, args[1], tuple_node) 2891 return node 2892 wrapped_obj = self._wrap_self_arg(obj, function, is_unbound_method, 'extend') 2893 if not items: 2894 # Empty sequences are not likely to occur, but why waste a call to list.extend() for them? 2895 wrapped_obj.result_is_used = node.result_is_used 2896 return wrapped_obj 2897 cloned_obj = obj = wrapped_obj 2898 if len(items) > 1 and not obj.is_simple(): 2899 cloned_obj = UtilNodes.LetRefNode(obj) 2900 # Use ListComp_Append() for all but the last item and finish with PyList_Append() 2901 # to shrink the list storage size at the very end if necessary. 
2902 temps = [] 2903 arg = items[-1] 2904 if not arg.is_simple(): 2905 arg = UtilNodes.LetRefNode(arg) 2906 temps.append(arg) 2907 new_node = ExprNodes.PythonCapiCallNode( 2908 node.pos, "__Pyx_PyList_Append", self.PyObject_Append_func_type, 2909 args=[cloned_obj, arg], 2910 is_temp=True, 2911 utility_code=load_c_utility("ListAppend")) 2912 for arg in items[-2::-1]: 2913 if not arg.is_simple(): 2914 arg = UtilNodes.LetRefNode(arg) 2915 temps.append(arg) 2916 new_node = ExprNodes.binop_node( 2917 node.pos, '|', 2918 ExprNodes.PythonCapiCallNode( 2919 node.pos, "__Pyx_ListComp_Append", self.PyObject_Append_func_type, 2920 args=[cloned_obj, arg], py_name="extend", 2921 is_temp=True, 2922 utility_code=load_c_utility("ListCompAppend")), 2923 new_node, 2924 type=PyrexTypes.c_returncode_type, 2925 ) 2926 new_node.result_is_used = node.result_is_used 2927 if cloned_obj is not obj: 2928 temps.append(cloned_obj) 2929 for temp in temps: 2930 new_node = UtilNodes.EvalWithTempExprNode(temp, new_node) 2931 new_node.result_is_used = node.result_is_used 2932 return new_node 2933 2934 PyByteArray_Append_func_type = PyrexTypes.CFuncType( 2935 PyrexTypes.c_returncode_type, [ 2936 PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), 2937 PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None), 2938 ], 2939 exception_value="-1") 2940 2941 PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType( 2942 PyrexTypes.c_returncode_type, [ 2943 PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), 2944 PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None), 2945 ], 2946 exception_value="-1") 2947 2948 def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method): 2949 if len(args) != 2: 2950 return node 2951 func_name = "__Pyx_PyByteArray_Append" 2952 func_type = self.PyByteArray_Append_func_type 2953 2954 value = unwrap_coerced_node(args[1]) 2955 if value.type.is_int or isinstance(value, ExprNodes.IntNode): 2956 
value = value.coerce_to(PyrexTypes.c_int_type, self.current_env()) 2957 utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") 2958 elif value.is_string_literal: 2959 if not value.can_coerce_to_char_literal(): 2960 return node 2961 value = value.coerce_to(PyrexTypes.c_char_type, self.current_env()) 2962 utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") 2963 elif value.type.is_pyobject: 2964 func_name = "__Pyx_PyByteArray_AppendObject" 2965 func_type = self.PyByteArray_AppendObject_func_type 2966 utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c") 2967 else: 2968 return node 2969 2970 new_node = ExprNodes.PythonCapiCallNode( 2971 node.pos, func_name, func_type, 2972 args=[args[0], value], 2973 may_return_none=False, 2974 is_temp=node.is_temp, 2975 utility_code=utility_code, 2976 ) 2977 if node.result_is_used: 2978 new_node = new_node.coerce_to(node.type, self.current_env()) 2979 return new_node 2980 2981 PyObject_Pop_func_type = PyrexTypes.CFuncType( 2982 PyrexTypes.py_object_type, [ 2983 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), 2984 ]) 2985 2986 PyObject_PopIndex_func_type = PyrexTypes.CFuncType( 2987 PyrexTypes.py_object_type, [ 2988 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), 2989 PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None), 2990 PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None), 2991 PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None), 2992 ], 2993 has_varargs=True) # to fake the additional macro args that lack a proper C type 2994 2995 def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method): 2996 return self._handle_simple_method_object_pop( 2997 node, function, args, is_unbound_method, is_list=True) 2998 2999 def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False): 3000 """Optimistic optimisation as X.pop([n]) is 
almost always 3001 referring to a list. 3002 """ 3003 if not args: 3004 return node 3005 obj = args[0] 3006 if is_list: 3007 type_name = 'List' 3008 obj = obj.as_none_safe_node( 3009 "'NoneType' object has no attribute '%.30s'", 3010 error="PyExc_AttributeError", 3011 format_args=['pop']) 3012 else: 3013 type_name = 'Object' 3014 if len(args) == 1: 3015 return ExprNodes.PythonCapiCallNode( 3016 node.pos, "__Pyx_Py%s_Pop" % type_name, 3017 self.PyObject_Pop_func_type, 3018 args=[obj], 3019 may_return_none=True, 3020 is_temp=node.is_temp, 3021 utility_code=load_c_utility('pop'), 3022 ) 3023 elif len(args) == 2: 3024 index = unwrap_coerced_node(args[1]) 3025 py_index = ExprNodes.NoneNode(index.pos) 3026 orig_index_type = index.type 3027 if not index.type.is_int: 3028 if isinstance(index, ExprNodes.IntNode): 3029 py_index = index.coerce_to_pyobject(self.current_env()) 3030 index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 3031 elif is_list: 3032 if index.type.is_pyobject: 3033 py_index = index.coerce_to_simple(self.current_env()) 3034 index = ExprNodes.CloneNode(py_index) 3035 index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 3036 else: 3037 return node 3038 elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type): 3039 return node 3040 elif isinstance(index, ExprNodes.IntNode): 3041 py_index = index.coerce_to_pyobject(self.current_env()) 3042 # real type might still be larger at runtime 3043 if not orig_index_type.is_int: 3044 orig_index_type = index.type 3045 if not orig_index_type.create_to_py_utility_code(self.current_env()): 3046 return node 3047 convert_func = orig_index_type.to_py_function 3048 conversion_type = PyrexTypes.CFuncType( 3049 PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)]) 3050 return ExprNodes.PythonCapiCallNode( 3051 node.pos, "__Pyx_Py%s_PopIndex" % type_name, 3052 self.PyObject_PopIndex_func_type, 3053 args=[obj, py_index, index, 3054 
ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0), 3055 constant_result=orig_index_type.signed and 1 or 0, 3056 type=PyrexTypes.c_int_type), 3057 ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type, 3058 orig_index_type.empty_declaration_code()), 3059 ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)], 3060 may_return_none=True, 3061 is_temp=node.is_temp, 3062 utility_code=load_c_utility("pop_index"), 3063 ) 3064 3065 return node 3066 3067 single_param_func_type = PyrexTypes.CFuncType( 3068 PyrexTypes.c_returncode_type, [ 3069 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), 3070 ], 3071 exception_value = "-1") 3072 3073 def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method): 3074 """Call PyList_Sort() instead of the 0-argument l.sort(). 3075 """ 3076 if len(args) != 1: 3077 return node 3078 return self._substitute_method_call( 3079 node, function, "PyList_Sort", self.single_param_func_type, 3080 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env) 3081 3082 Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType( 3083 PyrexTypes.py_object_type, [ 3084 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), 3085 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), 3086 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), 3087 ]) 3088 3089 def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method): 3090 """Replace dict.get() by a call to PyDict_GetItem(). 
3091 """ 3092 if len(args) == 2: 3093 args.append(ExprNodes.NoneNode(node.pos)) 3094 elif len(args) != 3: 3095 self._error_wrong_arg_count('dict.get', node, args, "2 or 3") 3096 return node 3097 3098 return self._substitute_method_call( 3099 node, function, 3100 "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type, 3101 'get', is_unbound_method, args, 3102 may_return_none = True, 3103 utility_code = load_c_utility("dict_getitem_default")) 3104 3105 Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType( 3106 PyrexTypes.py_object_type, [ 3107 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), 3108 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), 3109 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), 3110 PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None), 3111 ]) 3112 3113 def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method): 3114 """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem(). 
3115 """ 3116 if len(args) == 2: 3117 args.append(ExprNodes.NoneNode(node.pos)) 3118 elif len(args) != 3: 3119 self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3") 3120 return node 3121 key_type = args[1].type 3122 if key_type.is_builtin_type: 3123 is_safe_type = int(key_type.name in 3124 'str bytes unicode float int long bool') 3125 elif key_type is PyrexTypes.py_object_type: 3126 is_safe_type = -1 # don't know 3127 else: 3128 is_safe_type = 0 # definitely not 3129 args.append(ExprNodes.IntNode( 3130 node.pos, value=str(is_safe_type), constant_result=is_safe_type)) 3131 3132 return self._substitute_method_call( 3133 node, function, 3134 "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type, 3135 'setdefault', is_unbound_method, args, 3136 may_return_none=True, 3137 utility_code=load_c_utility('dict_setdefault')) 3138 3139 PyDict_Pop_func_type = PyrexTypes.CFuncType( 3140 PyrexTypes.py_object_type, [ 3141 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), 3142 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), 3143 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), 3144 ]) 3145 3146 def _handle_simple_method_dict_pop(self, node, function, args, is_unbound_method): 3147 """Replace dict.pop() by a call to _PyDict_Pop(). 
3148 """ 3149 if len(args) == 2: 3150 args.append(ExprNodes.NullNode(node.pos)) 3151 elif len(args) != 3: 3152 self._error_wrong_arg_count('dict.pop', node, args, "2 or 3") 3153 return node 3154 3155 return self._substitute_method_call( 3156 node, function, 3157 "__Pyx_PyDict_Pop", self.PyDict_Pop_func_type, 3158 'pop', is_unbound_method, args, 3159 may_return_none=True, 3160 utility_code=load_c_utility('py_dict_pop')) 3161 3162 Pyx_BinopInt_func_types = dict( 3163 ((ctype, ret_type), PyrexTypes.CFuncType( 3164 ret_type, [ 3165 PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None), 3166 PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None), 3167 PyrexTypes.CFuncTypeArg("cval", ctype, None), 3168 PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None), 3169 PyrexTypes.CFuncTypeArg("zerodiv_check", PyrexTypes.c_bint_type, None), 3170 ], exception_value=None if ret_type.is_pyobject else ret_type.exception_value)) 3171 for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type) 3172 for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type) 3173 ) 3174 3175 def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method): 3176 return self._optimise_num_binop('Add', node, function, args, is_unbound_method) 3177 3178 def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method): 3179 return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method) 3180 3181 def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method): 3182 return self._optimise_num_binop('Eq', node, function, args, is_unbound_method) 3183 3184 def _handle_simple_method_object___ne__(self, node, function, args, is_unbound_method): 3185 return self._optimise_num_binop('Ne', node, function, args, is_unbound_method) 3186 3187 def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method): 3188 return self._optimise_num_binop('And', node, function, 
args, is_unbound_method) 3189 3190 def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method): 3191 return self._optimise_num_binop('Or', node, function, args, is_unbound_method) 3192 3193 def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method): 3194 return self._optimise_num_binop('Xor', node, function, args, is_unbound_method) 3195 3196 def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method): 3197 if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode): 3198 return node 3199 if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63): 3200 return node 3201 return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method) 3202 3203 def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method): 3204 if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode): 3205 return node 3206 if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63): 3207 return node 3208 return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method) 3209 3210 def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method): 3211 return self._optimise_num_div('Remainder', node, function, args, is_unbound_method) 3212 3213 def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method): 3214 return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method) 3215 3216 def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method): 3217 return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method) 3218 3219 def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method): 3220 return self._optimise_num_div('Divide', node, function, args, is_unbound_method) 3221 3222 def _optimise_num_div(self, operator, node, function, args, 
is_unbound_method): 3223 if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0: 3224 return node 3225 if isinstance(args[1], ExprNodes.IntNode): 3226 if not (-2**30 <= args[1].constant_result <= 2**30): 3227 return node 3228 elif isinstance(args[1], ExprNodes.FloatNode): 3229 if not (-2**53 <= args[1].constant_result <= 2**53): 3230 return node 3231 else: 3232 return node 3233 return self._optimise_num_binop(operator, node, function, args, is_unbound_method) 3234 3235 def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method): 3236 return self._optimise_num_binop('Add', node, function, args, is_unbound_method) 3237 3238 def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method): 3239 return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method) 3240 3241 def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method): 3242 return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method) 3243 3244 def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method): 3245 return self._optimise_num_binop('Divide', node, function, args, is_unbound_method) 3246 3247 def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method): 3248 return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method) 3249 3250 def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method): 3251 return self._optimise_num_binop('Eq', node, function, args, is_unbound_method) 3252 3253 def _handle_simple_method_float___ne__(self, node, function, args, is_unbound_method): 3254 return self._optimise_num_binop('Ne', node, function, args, is_unbound_method) 3255 3256 def _optimise_num_binop(self, operator, node, function, args, is_unbound_method): 3257 """ 3258 Optimise math operators for (likely) float or small integer operations. 
3259 """ 3260 if len(args) != 2: 3261 return node 3262 3263 if node.type.is_pyobject: 3264 ret_type = PyrexTypes.py_object_type 3265 elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'): 3266 ret_type = PyrexTypes.c_bint_type 3267 else: 3268 return node 3269 3270 # When adding IntNode/FloatNode to something else, assume other operand is also numeric. 3271 # Prefer constants on RHS as they allows better size control for some operators. 3272 num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode) 3273 if isinstance(args[1], num_nodes): 3274 if args[0].type is not PyrexTypes.py_object_type: 3275 return node 3276 numval = args[1] 3277 arg_order = 'ObjC' 3278 elif isinstance(args[0], num_nodes): 3279 if args[1].type is not PyrexTypes.py_object_type: 3280 return node 3281 numval = args[0] 3282 arg_order = 'CObj' 3283 else: 3284 return node 3285 3286 if not numval.has_constant_result(): 3287 return node 3288 3289 is_float = isinstance(numval, ExprNodes.FloatNode) 3290 num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type 3291 if is_float: 3292 if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'): 3293 return node 3294 elif operator == 'Divide': 3295 # mixed old-/new-style division is not currently optimised for integers 3296 return node 3297 elif abs(numval.constant_result) > 2**30: 3298 # Cut off at an integer border that is still safe for all operations. 3299 return node 3300 3301 if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'): 3302 if args[1].constant_result == 0: 3303 # Don't optimise division by 0. 
:) 3304 return node 3305 3306 args = list(args) 3307 args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)( 3308 numval.pos, value=numval.value, constant_result=numval.constant_result, 3309 type=num_type)) 3310 inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False 3311 args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace)) 3312 if is_float or operator not in ('Eq', 'Ne'): 3313 # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument. 3314 zerodivision_check = arg_order == 'CObj' and ( 3315 not node.cdivision if isinstance(node, ExprNodes.DivNode) else False) 3316 args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check)) 3317 3318 utility_code = TempitaUtilityCode.load_cached( 3319 "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop", 3320 "Optimize.c", 3321 context=dict(op=operator, order=arg_order, ret_type=ret_type)) 3322 3323 call_node = self._substitute_method_call( 3324 node, function, 3325 "__Pyx_Py%s_%s%s%s" % ( 3326 'Float' if is_float else 'Int', 3327 '' if ret_type.is_pyobject else 'Bool', 3328 operator, 3329 arg_order), 3330 self.Pyx_BinopInt_func_types[(num_type, ret_type)], 3331 '__%s__' % operator[:3].lower(), is_unbound_method, args, 3332 may_return_none=True, 3333 with_none_check=False, 3334 utility_code=utility_code) 3335 3336 if node.type.is_pyobject and not ret_type.is_pyobject: 3337 call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type) 3338 return call_node 3339 3340 ### unicode type methods 3341 3342 PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType( 3343 PyrexTypes.c_bint_type, [ 3344 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None), 3345 ]) 3346 3347 def _inject_unicode_predicate(self, node, function, args, is_unbound_method): 3348 if is_unbound_method or len(args) != 1: 3349 return node 3350 ustring = 
args[0] 3351 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \ 3352 not ustring.arg.type.is_unicode_char: 3353 return node 3354 uchar = ustring.arg 3355 method_name = function.attribute 3356 if method_name == 'istitle': 3357 # istitle() doesn't directly map to Py_UNICODE_ISTITLE() 3358 utility_code = UtilityCode.load_cached( 3359 "py_unicode_istitle", "StringTools.c") 3360 function_name = '__Pyx_Py_UNICODE_ISTITLE' 3361 else: 3362 utility_code = None 3363 function_name = 'Py_UNICODE_%s' % method_name.upper() 3364 func_call = self._substitute_method_call( 3365 node, function, 3366 function_name, self.PyUnicode_uchar_predicate_func_type, 3367 method_name, is_unbound_method, [uchar], 3368 utility_code = utility_code) 3369 if node.type.is_pyobject: 3370 func_call = func_call.coerce_to_pyobject(self.current_env) 3371 return func_call 3372 3373 _handle_simple_method_unicode_isalnum = _inject_unicode_predicate 3374 _handle_simple_method_unicode_isalpha = _inject_unicode_predicate 3375 _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate 3376 _handle_simple_method_unicode_isdigit = _inject_unicode_predicate 3377 _handle_simple_method_unicode_islower = _inject_unicode_predicate 3378 _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate 3379 _handle_simple_method_unicode_isspace = _inject_unicode_predicate 3380 _handle_simple_method_unicode_istitle = _inject_unicode_predicate 3381 _handle_simple_method_unicode_isupper = _inject_unicode_predicate 3382 3383 PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType( 3384 PyrexTypes.c_py_ucs4_type, [ 3385 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None), 3386 ]) 3387 3388 def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method): 3389 if is_unbound_method or len(args) != 1: 3390 return node 3391 ustring = args[0] 3392 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \ 3393 not ustring.arg.type.is_unicode_char: 3394 return node 
3395 uchar = ustring.arg 3396 method_name = function.attribute 3397 function_name = 'Py_UNICODE_TO%s' % method_name.upper() 3398 func_call = self._substitute_method_call( 3399 node, function, 3400 function_name, self.PyUnicode_uchar_conversion_func_type, 3401 method_name, is_unbound_method, [uchar]) 3402 if node.type.is_pyobject: 3403 func_call = func_call.coerce_to_pyobject(self.current_env) 3404 return func_call 3405 3406 _handle_simple_method_unicode_lower = _inject_unicode_character_conversion 3407 _handle_simple_method_unicode_upper = _inject_unicode_character_conversion 3408 _handle_simple_method_unicode_title = _inject_unicode_character_conversion 3409 3410 PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType( 3411 Builtin.list_type, [ 3412 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 3413 PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None), 3414 ]) 3415 3416 def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method): 3417 """Replace unicode.splitlines(...) by a direct call to the 3418 corresponding C-API function. 3419 """ 3420 if len(args) not in (1,2): 3421 self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2") 3422 return node 3423 self._inject_bint_default_argument(node, args, 1, False) 3424 3425 return self._substitute_method_call( 3426 node, function, 3427 "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type, 3428 'splitlines', is_unbound_method, args) 3429 3430 PyUnicode_Split_func_type = PyrexTypes.CFuncType( 3431 Builtin.list_type, [ 3432 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 3433 PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None), 3434 PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None), 3435 ] 3436 ) 3437 3438 def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method): 3439 """Replace unicode.split(...) by a direct call to the 3440 corresponding C-API function. 
3441 """ 3442 if len(args) not in (1,2,3): 3443 self._error_wrong_arg_count('unicode.split', node, args, "1-3") 3444 return node 3445 if len(args) < 2: 3446 args.append(ExprNodes.NullNode(node.pos)) 3447 self._inject_int_default_argument( 3448 node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1") 3449 3450 return self._substitute_method_call( 3451 node, function, 3452 "PyUnicode_Split", self.PyUnicode_Split_func_type, 3453 'split', is_unbound_method, args) 3454 3455 PyUnicode_Join_func_type = PyrexTypes.CFuncType( 3456 Builtin.unicode_type, [ 3457 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 3458 PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None), 3459 ]) 3460 3461 def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method): 3462 """ 3463 unicode.join() builds a list first => see if we can do this more efficiently 3464 """ 3465 if len(args) != 2: 3466 self._error_wrong_arg_count('unicode.join', node, args, "2") 3467 return node 3468 if isinstance(args[1], ExprNodes.GeneratorExpressionNode): 3469 gen_expr_node = args[1] 3470 loop_node = gen_expr_node.loop 3471 3472 yield_statements = _find_yield_statements(loop_node) 3473 if yield_statements: 3474 inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode( 3475 node.pos, gen_expr_node, orig_func='list', 3476 comprehension_type=Builtin.list_type) 3477 3478 for yield_expression, yield_stat_node in yield_statements: 3479 append_node = ExprNodes.ComprehensionAppendNode( 3480 yield_expression.pos, 3481 expr=yield_expression, 3482 target=inlined_genexpr.target) 3483 3484 Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) 3485 3486 args[1] = inlined_genexpr 3487 3488 return self._substitute_method_call( 3489 node, function, 3490 "PyUnicode_Join", self.PyUnicode_Join_func_type, 3491 'join', is_unbound_method, args) 3492 3493 PyString_Tailmatch_func_type = PyrexTypes.CFuncType( 3494 PyrexTypes.c_bint_type, [ 3495 PyrexTypes.CFuncTypeArg("str", 
PyrexTypes.py_object_type, None), # bytes/str/unicode 3496 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 3497 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 3498 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), 3499 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), 3500 ], 3501 exception_value = '-1') 3502 3503 def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method): 3504 return self._inject_tailmatch( 3505 node, function, args, is_unbound_method, 'unicode', 'endswith', 3506 unicode_tailmatch_utility_code, +1) 3507 3508 def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method): 3509 return self._inject_tailmatch( 3510 node, function, args, is_unbound_method, 'unicode', 'startswith', 3511 unicode_tailmatch_utility_code, -1) 3512 3513 def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name, 3514 method_name, utility_code, direction): 3515 """Replace unicode.startswith(...) and unicode.endswith(...) 3516 by a direct call to the corresponding C-API function. 
3517 """ 3518 if len(args) not in (2,3,4): 3519 self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4") 3520 return node 3521 self._inject_int_default_argument( 3522 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") 3523 self._inject_int_default_argument( 3524 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") 3525 args.append(ExprNodes.IntNode( 3526 node.pos, value=str(direction), type=PyrexTypes.c_int_type)) 3527 3528 method_call = self._substitute_method_call( 3529 node, function, 3530 "__Pyx_Py%s_Tailmatch" % type_name.capitalize(), 3531 self.PyString_Tailmatch_func_type, 3532 method_name, is_unbound_method, args, 3533 utility_code = utility_code) 3534 return method_call.coerce_to(Builtin.bool_type, self.current_env()) 3535 3536 PyUnicode_Find_func_type = PyrexTypes.CFuncType( 3537 PyrexTypes.c_py_ssize_t_type, [ 3538 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 3539 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 3540 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 3541 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), 3542 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), 3543 ], 3544 exception_value = '-2') 3545 3546 def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method): 3547 return self._inject_unicode_find( 3548 node, function, args, is_unbound_method, 'find', +1) 3549 3550 def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method): 3551 return self._inject_unicode_find( 3552 node, function, args, is_unbound_method, 'rfind', -1) 3553 3554 def _inject_unicode_find(self, node, function, args, is_unbound_method, 3555 method_name, direction): 3556 """Replace unicode.find(...) and unicode.rfind(...) by a 3557 direct call to the corresponding C-API function. 
3558 """ 3559 if len(args) not in (2,3,4): 3560 self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4") 3561 return node 3562 self._inject_int_default_argument( 3563 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") 3564 self._inject_int_default_argument( 3565 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") 3566 args.append(ExprNodes.IntNode( 3567 node.pos, value=str(direction), type=PyrexTypes.c_int_type)) 3568 3569 method_call = self._substitute_method_call( 3570 node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type, 3571 method_name, is_unbound_method, args) 3572 return method_call.coerce_to_pyobject(self.current_env()) 3573 3574 PyUnicode_Count_func_type = PyrexTypes.CFuncType( 3575 PyrexTypes.c_py_ssize_t_type, [ 3576 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 3577 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 3578 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 3579 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), 3580 ], 3581 exception_value = '-1') 3582 3583 def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method): 3584 """Replace unicode.count(...) by a direct call to the 3585 corresponding C-API function. 
3586 """ 3587 if len(args) not in (2,3,4): 3588 self._error_wrong_arg_count('unicode.count', node, args, "2-4") 3589 return node 3590 self._inject_int_default_argument( 3591 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") 3592 self._inject_int_default_argument( 3593 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") 3594 3595 method_call = self._substitute_method_call( 3596 node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type, 3597 'count', is_unbound_method, args) 3598 return method_call.coerce_to_pyobject(self.current_env()) 3599 3600 PyUnicode_Replace_func_type = PyrexTypes.CFuncType( 3601 Builtin.unicode_type, [ 3602 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 3603 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 3604 PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None), 3605 PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None), 3606 ]) 3607 3608 def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method): 3609 """Replace unicode.replace(...) by a direct call to the 3610 corresponding C-API function. 
3611 """ 3612 if len(args) not in (3,4): 3613 self._error_wrong_arg_count('unicode.replace', node, args, "3-4") 3614 return node 3615 self._inject_int_default_argument( 3616 node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1") 3617 3618 return self._substitute_method_call( 3619 node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type, 3620 'replace', is_unbound_method, args) 3621 3622 PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType( 3623 Builtin.bytes_type, [ 3624 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), 3625 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), 3626 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), 3627 ]) 3628 3629 PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType( 3630 Builtin.bytes_type, [ 3631 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), 3632 ]) 3633 3634 _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII', 3635 'unicode_escape', 'raw_unicode_escape'] 3636 3637 _special_codecs = [ (name, codecs.getencoder(name)) 3638 for name in _special_encodings ] 3639 3640 def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method): 3641 """Replace unicode.encode(...) by a direct C-API call to the 3642 corresponding codec. 
3643 """ 3644 if len(args) < 1 or len(args) > 3: 3645 self._error_wrong_arg_count('unicode.encode', node, args, '1-3') 3646 return node 3647 3648 string_node = args[0] 3649 3650 if len(args) == 1: 3651 null_node = ExprNodes.NullNode(node.pos) 3652 return self._substitute_method_call( 3653 node, function, "PyUnicode_AsEncodedString", 3654 self.PyUnicode_AsEncodedString_func_type, 3655 'encode', is_unbound_method, [string_node, null_node, null_node]) 3656 3657 parameters = self._unpack_encoding_and_error_mode(node.pos, args) 3658 if parameters is None: 3659 return node 3660 encoding, encoding_node, error_handling, error_handling_node = parameters 3661 3662 if encoding and isinstance(string_node, ExprNodes.UnicodeNode): 3663 # constant, so try to do the encoding at compile time 3664 try: 3665 value = string_node.value.encode(encoding, error_handling) 3666 except: 3667 # well, looks like we can't 3668 pass 3669 else: 3670 value = bytes_literal(value, encoding) 3671 return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type) 3672 3673 if encoding and error_handling == 'strict': 3674 # try to find a specific encoder function 3675 codec_name = self._find_special_codec_name(encoding) 3676 if codec_name is not None and '-' not in codec_name: 3677 encode_function = "PyUnicode_As%sString" % codec_name 3678 return self._substitute_method_call( 3679 node, function, encode_function, 3680 self.PyUnicode_AsXyzString_func_type, 3681 'encode', is_unbound_method, [string_node]) 3682 3683 return self._substitute_method_call( 3684 node, function, "PyUnicode_AsEncodedString", 3685 self.PyUnicode_AsEncodedString_func_type, 3686 'encode', is_unbound_method, 3687 [string_node, encoding_node, error_handling_node]) 3688 3689 PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType( 3690 Builtin.unicode_type, [ 3691 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None), 3692 PyrexTypes.CFuncTypeArg("size", 
def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
    """Replace char*.decode() by a direct C-API call to the
    corresponding codec, possibly resolving a slice on the char*.

    ``args`` is the argument list with the string object first, optionally
    followed by the encoding and the error handling mode.  Handles Python
    bytes/bytearray objects, C strings and C++ strings; for any other
    type, ``node`` is returned unchanged.

    Returns a PythonCapiCallNode calling one of the ``__Pyx_decode_*``
    helpers from StringTools.c, wrapped in EvalWithTempExprNode for each
    temporary that had to be created.
    """
    if not (1 <= len(args) <= 3):
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node

    # normalise input nodes
    string_node = args[0]
    start = stop = None
    if isinstance(string_node, ExprNodes.SliceIndexNode):
        # decode the slice directly by passing its bounds to the helper
        index_node = string_node
        string_node = index_node.base
        start, stop = index_node.start, index_node.stop
        if not start or start.constant_result == 0:
            start = None
    if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion at the underlying value
        string_node = string_node.arg

    string_type = string_node.type
    if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
        # Python objects may be None => guard with the matching error message
        if is_unbound_method:
            string_node = string_node.as_none_safe_node(
                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                format_args=['decode', string_type.name])
        else:
            string_node = string_node.as_none_safe_node(
                "'NoneType' object has no attribute '%.30s'",
                error="PyExc_AttributeError",
                format_args=['decode'])
    elif not string_type.is_string and not string_type.is_cpp_string:
        # nothing to optimise here
        return node

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    # NOTE: 'error_handling' (the plain string) is not used below, only its node is.
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # make sure that the slice bounds are C integers
    if not start:
        start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
    elif not start.type.is_int:
        start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
    if stop and not stop.type.is_int:
        stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

    # try to find a specific decoder function
    codec_name = None
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)
    if codec_name is not None:
        if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'):
            # the UTF-16 variants use "__Pyx_" prefixed function names
            codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '')
        else:
            codec_cname = "PyUnicode_Decode%s" % codec_name
        decode_function = ExprNodes.RawCNameExprNode(
            node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname)
        # the encoding is passed as NULL when a specific decoder function is given
        encoding_node = ExprNodes.NullNode(node.pos)
    else:
        decode_function = ExprNodes.NullNode(node.pos)

    # build the helper function call
    temps = []
    if string_type.is_string:
        # C string
        if not stop:
            # use strlen() to find the string length, just as CPython would
            if not string_node.is_name:
                string_node = UtilNodes.LetRefNode(string_node) # used twice
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                args=[string_node],
                is_temp=False,
                utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
            ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        helper_func_type = self._decode_c_string_func_type
        utility_code_name = 'decode_c_string'
    elif string_type.is_cpp_string:
        # C++ std::string
        if not stop:
            stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                     constant_result=ExprNodes.not_a_constant)
        if self._decode_cpp_string_func_type is None:
            # lazy init to reuse the C++ string type
            # (stored on the instance, built from the first C++ string type seen)
            self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                Builtin.unicode_type, [
                    PyrexTypes.CFuncTypeArg("string", string_type, None),
                    PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                    PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                    PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
                    PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
                    PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                ])
        helper_func_type = self._decode_cpp_string_func_type
        utility_code_name = 'decode_cpp_string'
    else:
        # Python bytes/bytearray object
        if not stop:
            stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                     constant_result=ExprNodes.not_a_constant)
        helper_func_type = self._decode_bytes_func_type
        if string_type is Builtin.bytes_type:
            utility_code_name = 'decode_bytes'
        else:
            utility_code_name = 'decode_bytearray'

    # the C helper name and its utility code section share 'utility_code_name'
    node = ExprNodes.PythonCapiCallNode(
        node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
        args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
        is_temp=node.is_temp,
        utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
    )

    # wrap the call in the temps it depends on, innermost last
    for temp in temps[::-1]:
        node = UtilNodes.EvalWithTempExprNode(temp, node)
    return node
_unpack_string_and_cstring_node(self, node): 3876 if isinstance(node, ExprNodes.CoerceToPyTypeNode): 3877 node = node.arg 3878 if isinstance(node, ExprNodes.UnicodeNode): 3879 encoding = node.value 3880 node = ExprNodes.BytesNode( 3881 node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type) 3882 elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)): 3883 encoding = node.value.decode('ISO-8859-1') 3884 node = ExprNodes.BytesNode( 3885 node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type) 3886 elif node.type is Builtin.bytes_type: 3887 encoding = None 3888 node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env()) 3889 elif node.type.is_string: 3890 encoding = None 3891 else: 3892 encoding = node = None 3893 return encoding, node 3894 3895 def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method): 3896 return self._inject_tailmatch( 3897 node, function, args, is_unbound_method, 'str', 'endswith', 3898 str_tailmatch_utility_code, +1) 3899 3900 def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method): 3901 return self._inject_tailmatch( 3902 node, function, args, is_unbound_method, 'str', 'startswith', 3903 str_tailmatch_utility_code, -1) 3904 3905 def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method): 3906 return self._inject_tailmatch( 3907 node, function, args, is_unbound_method, 'bytes', 'endswith', 3908 bytes_tailmatch_utility_code, +1) 3909 3910 def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method): 3911 return self._inject_tailmatch( 3912 node, function, args, is_unbound_method, 'bytes', 'startswith', 3913 bytes_tailmatch_utility_code, -1) 3914 3915 ''' # disabled for now, enable when we consider it worth it (see StringTools.c) 3916 def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method): 3917 return self._inject_tailmatch( 3918 
node, function, args, is_unbound_method, 'bytearray', 'endswith', 3919 bytes_tailmatch_utility_code, +1) 3920 3921 def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method): 3922 return self._inject_tailmatch( 3923 node, function, args, is_unbound_method, 'bytearray', 'startswith', 3924 bytes_tailmatch_utility_code, -1) 3925 ''' 3926 3927 ### helpers 3928 3929 def _substitute_method_call(self, node, function, name, func_type, 3930 attr_name, is_unbound_method, args=(), 3931 utility_code=None, is_temp=None, 3932 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none, 3933 with_none_check=True): 3934 args = list(args) 3935 if with_none_check and args: 3936 args[0] = self._wrap_self_arg(args[0], function, is_unbound_method, attr_name) 3937 if is_temp is None: 3938 is_temp = node.is_temp 3939 return ExprNodes.PythonCapiCallNode( 3940 node.pos, name, func_type, 3941 args = args, 3942 is_temp = is_temp, 3943 utility_code = utility_code, 3944 may_return_none = may_return_none, 3945 result_is_used = node.result_is_used, 3946 ) 3947 3948 def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name): 3949 if self_arg.is_literal: 3950 return self_arg 3951 if is_unbound_method: 3952 self_arg = self_arg.as_none_safe_node( 3953 "descriptor '%s' requires a '%s' object but received a 'NoneType'", 3954 format_args=[attr_name, self_arg.type.name]) 3955 else: 3956 self_arg = self_arg.as_none_safe_node( 3957 "'NoneType' object has no attribute '%{0}s'".format('.30' if len(attr_name) <= 30 else ''), 3958 error="PyExc_AttributeError", 3959 format_args=[attr_name]) 3960 return self_arg 3961 3962 def _inject_int_default_argument(self, node, args, arg_index, type, default_value): 3963 assert len(args) >= arg_index 3964 if len(args) == arg_index: 3965 args.append(ExprNodes.IntNode(node.pos, value=str(default_value), 3966 type=type, constant_result=default_value)) 3967 else: 3968 args[arg_index] = args[arg_index].coerce_to(type, 
self.current_env()) 3969 3970 def _inject_bint_default_argument(self, node, args, arg_index, default_value): 3971 assert len(args) >= arg_index 3972 if len(args) == arg_index: 3973 default_value = bool(default_value) 3974 args.append(ExprNodes.BoolNode(node.pos, value=default_value, 3975 constant_result=default_value)) 3976 else: 3977 args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env()) 3978 3979 3980unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c') 3981bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c') 3982str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c') 3983 3984 3985class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): 3986 """Calculate the result of constant expressions to store it in 3987 ``expr_node.constant_result``, and replace trivial cases by their 3988 constant result. 3989 3990 General rules: 3991 3992 - We calculate float constants to make them available to the 3993 compiler, but we do not aggregate them into a single literal 3994 node to prevent any loss of precision. 3995 3996 - We recursively calculate constants from non-literal nodes to 3997 make them available to the compiler, but we only aggregate 3998 literal nodes at each step. Non-literal nodes are never merged 3999 into a single node. 4000 """ 4001 4002 def __init__(self, reevaluate=False): 4003 """ 4004 The reevaluate argument specifies whether constant values that were 4005 previously computed should be recomputed. 
def _calculate_const(self, node):
    """Calculate ``node.constant_result`` from the results of its children.

    Nodes that already have a result are skipped unless ``self.reevaluate``
    is set.  The children are visited first.  If any child result (or any
    item of a child result list) has no constant result, the node is left
    marked as "not a constant".  Otherwise, the node is asked to calculate
    its own constant result.
    """
    if not self.reevaluate:
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

    # make sure we always set the value
    unknown = ExprNodes.not_a_constant
    node.constant_result = unknown

    # check if all children are constant
    for visited in self.visitchildren(node).values():
        candidates = visited if type(visited) is list else [visited]
        for child in candidates:
            if getattr(child, 'constant_result', unknown) is unknown:
                return

    # now try to calculate the real constant value
    expected_errors = (ValueError, TypeError, KeyError, IndexError,
                       AttributeError, ArithmeticError)
    try:
        node.calculate_constant_result()
    except expected_errors:
        # ignore all 'normal' errors here => no constant result
        pass
    except Exception:
        # this looks like a real error => report it, but keep going
        import traceback
        import sys
        traceback.print_exc(file=sys.stdout)
def _handle_UnaryMinusNode(self, node):
    """Fold a unary minus into its literal operand where possible.

    Float literals are always folded.  Integer operands are folded when
    their type is a signed C integer, or when the operand is an IntNode
    of Python object type.  The new literal takes over the operand's type
    and the constant result of ``node``.  In all other cases, ``node`` is
    returned unchanged.
    """
    operand = node.operand
    operand_type = operand.type

    def flip_sign(literal):
        # works on the source text of the literal: toggle a leading '-'
        return literal[1:] if literal.startswith('-') else '-' + literal

    if isinstance(operand, ExprNodes.FloatNode):
        # this is a safe operation
        return ExprNodes.FloatNode(
            node.pos,
            value=flip_sign(operand.value),
            type=operand_type,
            constant_result=node.constant_result)

    is_signed_c_int = operand_type.is_int and operand_type.signed
    if is_signed_c_int or (
            isinstance(operand, ExprNodes.IntNode) and operand_type.is_pyobject):
        return ExprNodes.IntNode(
            node.pos,
            value=flip_sign(operand.value),
            type=operand_type,
            longness=operand.longness,
            constant_result=node.constant_result)

    return node
def visit_BoolBinopNode(self, node):
    """Short-circuit 'and' / 'or' when the first operand is constant.

    Returns ``node`` unchanged if the first operand has no constant result.
    Otherwise returns the operand that the expression evaluates to: the
    second operand for a true 'and' or a false 'or', and the first operand
    in the remaining cases.
    """
    self._calculate_const(node)
    first = node.operand1
    if not first.has_constant_result():
        return node

    first_is_true = bool(first.constant_result)
    is_and = node.operator == 'and'
    # 'and' moves on to the second operand on a true value, 'or' on a false one
    if first_is_true == is_and:
        return node.operand2
    return node.operand1
def visit_AddNode(self, node):
    """Fold the '+' concatenation of two string literals of the same kind.

    Two UnicodeNodes are merged into a new UnicodeNode (keeping a combined
    ``bytes_value`` only if both have one in the same encoding).  Two
    BytesNodes are merged if their encodings match.  Nodes without a
    constant result are returned unchanged; everything else is passed on
    to ``visit_BinopNode()``.
    """
    self._calculate_const(node)
    folded = node.constant_result
    if folded is ExprNodes.not_a_constant:
        return node

    left, right = node.operand1, node.operand2
    if not (left.is_string_literal and right.is_string_literal):
        return self.visit_BinopNode(node)

    # some people combine string literals with a '+'
    unicode_node, bytes_node = ExprNodes.UnicodeNode, ExprNodes.BytesNode
    if isinstance(left, unicode_node) and isinstance(right, unicode_node):
        left_bytes, right_bytes = left.bytes_value, right.bytes_value
        merged_bytes = None
        if (left_bytes is not None and right_bytes is not None
                and left_bytes.encoding == right_bytes.encoding):
            merged_bytes = bytes_literal(left_bytes + right_bytes, left_bytes.encoding)
        return unicode_node(
            left.pos, value=EncodedString(folded),
            constant_result=folded, bytes_value=merged_bytes)

    if isinstance(left, bytes_node) and isinstance(right, bytes_node):
        if left.value.encoding == right.value.encoding:
            merged = bytes_literal(folded, left.value.encoding)
            return bytes_node(left.pos, value=merged, constant_result=folded)

    # all other combinations are rather complicated
    # to get right in Py2/3: encodings, unicode escapes, ...
    return self.visit_BinopNode(node)
(-> arbitrary limit) 4244 return node 4245 4246 build_string = encoded_string 4247 if isinstance(string_node, ExprNodes.BytesNode): 4248 build_string = bytes_literal 4249 elif isinstance(string_node, ExprNodes.StringNode): 4250 if string_node.unicode_value is not None: 4251 string_node.unicode_value = encoded_string( 4252 string_node.unicode_value * multiplier, 4253 string_node.unicode_value.encoding) 4254 build_string = encoded_string if string_node.value.is_unicode else bytes_literal 4255 elif isinstance(string_node, ExprNodes.UnicodeNode): 4256 if string_node.bytes_value is not None: 4257 string_node.bytes_value = bytes_literal( 4258 string_node.bytes_value * multiplier, 4259 string_node.bytes_value.encoding) 4260 else: 4261 assert False, "unknown string node type: %s" % type(string_node) 4262 string_node.value = build_string( 4263 string_node.value * multiplier, 4264 string_node.value.encoding) 4265 # follow constant-folding and use unicode_value in preference 4266 if isinstance(string_node, ExprNodes.StringNode) and string_node.unicode_value is not None: 4267 string_node.constant_result = string_node.unicode_value 4268 else: 4269 string_node.constant_result = string_node.value 4270 return string_node 4271 4272 def _calculate_constant_seq(self, node, sequence_node, factor): 4273 if factor.constant_result != 1 and sequence_node.args: 4274 if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0: 4275 del sequence_node.args[:] 4276 sequence_node.mult_factor = None 4277 elif sequence_node.mult_factor is not None: 4278 if (isinstance(factor.constant_result, _py_int_types) and 4279 isinstance(sequence_node.mult_factor.constant_result, _py_int_types)): 4280 value = sequence_node.mult_factor.constant_result * factor.constant_result 4281 sequence_node.mult_factor = ExprNodes.IntNode( 4282 sequence_node.mult_factor.pos, 4283 value=str(value), constant_result=value) 4284 else: 4285 # don't know if we can combine the factors, so don't 4286 
return self.visit_BinopNode(node) 4287 else: 4288 sequence_node.mult_factor = factor 4289 return sequence_node 4290 4291 def visit_ModNode(self, node): 4292 self.visitchildren(node) 4293 if isinstance(node.operand1, ExprNodes.UnicodeNode) and isinstance(node.operand2, ExprNodes.TupleNode): 4294 if not node.operand2.mult_factor: 4295 fstring = self._build_fstring(node.operand1.pos, node.operand1.value, node.operand2.args) 4296 if fstring is not None: 4297 return fstring 4298 return self.visit_BinopNode(node) 4299 4300 _parse_string_format_regex = ( 4301 u'(%(?:' # %... 4302 u'(?:[-0-9]+|[ ])?' # width (optional) or space prefix fill character (optional) 4303 u'(?:[.][0-9]+)?' # precision (optional) 4304 u')?.)' # format type (or something different for unsupported formats) 4305 ) 4306 4307 def _build_fstring(self, pos, ustring, format_args): 4308 # Issues formatting warnings instead of errors since we really only catch a few errors by accident. 4309 args = iter(format_args) 4310 substrings = [] 4311 can_be_optimised = True 4312 for s in re.split(self._parse_string_format_regex, ustring): 4313 if not s: 4314 continue 4315 if s == u'%%': 4316 substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(u'%'), constant_result=u'%')) 4317 continue 4318 if s[0] != u'%': 4319 if s[-1] == u'%': 4320 warning(pos, "Incomplete format: '...%s'" % s[-3:], level=1) 4321 can_be_optimised = False 4322 substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(s), constant_result=s)) 4323 continue 4324 format_type = s[-1] 4325 try: 4326 arg = next(args) 4327 except StopIteration: 4328 warning(pos, "Too few arguments for format placeholders", level=1) 4329 can_be_optimised = False 4330 break 4331 if arg.is_starred: 4332 can_be_optimised = False 4333 break 4334 if format_type in u'asrfdoxX': 4335 format_spec = s[1:] 4336 conversion_char = None 4337 if format_type in u'doxX' and u'.' 
in format_spec: 4338 # Precision is not allowed for integers in format(), but ok in %-formatting. 4339 can_be_optimised = False 4340 elif format_type in u'ars': 4341 format_spec = format_spec[:-1] 4342 conversion_char = format_type 4343 if format_spec.startswith('0'): 4344 format_spec = '>' + format_spec[1:] # right-alignment '%05s' spells '{:>5}' 4345 elif format_type == u'd': 4346 # '%d' formatting supports float, but '{obj:d}' does not => convert to int first. 4347 conversion_char = 'd' 4348 4349 if format_spec.startswith('-'): 4350 format_spec = '<' + format_spec[1:] # left-alignment '%-5s' spells '{:<5}' 4351 4352 substrings.append(ExprNodes.FormattedValueNode( 4353 arg.pos, value=arg, 4354 conversion_char=conversion_char, 4355 format_spec=ExprNodes.UnicodeNode( 4356 pos, value=EncodedString(format_spec), constant_result=format_spec) 4357 if format_spec else None, 4358 )) 4359 else: 4360 # keep it simple for now ... 4361 can_be_optimised = False 4362 break 4363 4364 if not can_be_optimised: 4365 # Print all warnings we can find before finally giving up here. 
def visit_FormattedValueNode(self, node):
    """Replace trivially formatted literals by their string value.

    An empty format spec is dropped first.  Without a format spec, an
    IntNode whose literal text consists of digits only becomes a
    UnicodeNode of that text, and (for the default / 's' conversion) a
    UnicodeNode or StringNode value becomes a UnicodeNode of its unicode
    value.  In all other cases, ``node`` is returned.
    """
    self.visitchildren(node)

    spec = node.format_spec
    if isinstance(spec, ExprNodes.UnicodeNode) and not spec.value:
        node.format_spec = spec = None
    if spec is not None:
        return node

    operand = node.value
    if isinstance(operand, ExprNodes.IntNode):
        digits = EncodedString(operand.value)
        if digits.isdigit():
            return ExprNodes.UnicodeNode(operand.pos, value=digits, constant_result=digits)

    if (node.conversion_char or 's') != 's':
        return node

    if isinstance(operand, ExprNodes.UnicodeNode):
        text = operand.value
    elif isinstance(operand, ExprNodes.StringNode):
        text = operand.unicode_value
    else:
        text = None
    if text is None:
        return node
    return ExprNodes.UnicodeNode(operand.pos, value=text, constant_result=text)
4402 """ 4403 self.visitchildren(node) 4404 unicode_node = ExprNodes.UnicodeNode 4405 4406 values = [] 4407 for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)): 4408 if is_unode_group: 4409 substrings = list(substrings) 4410 unode = substrings[0] 4411 if len(substrings) > 1: 4412 value = EncodedString(u''.join(value.value for value in substrings)) 4413 unode = ExprNodes.UnicodeNode(unode.pos, value=value, constant_result=value) 4414 # ignore empty Unicode strings 4415 if unode.value: 4416 values.append(unode) 4417 else: 4418 values.extend(substrings) 4419 4420 if not values: 4421 value = EncodedString('') 4422 node = ExprNodes.UnicodeNode(node.pos, value=value, constant_result=value) 4423 elif len(values) == 1: 4424 node = values[0] 4425 elif len(values) == 2: 4426 # reduce to string concatenation 4427 node = ExprNodes.binop_node(node.pos, '+', *values) 4428 else: 4429 node.values = values 4430 return node 4431 4432 def visit_MergedDictNode(self, node): 4433 """Unpack **args in place if we can.""" 4434 self.visitchildren(node) 4435 args = [] 4436 items = [] 4437 4438 def add(arg): 4439 if arg.is_dict_literal: 4440 if items: 4441 items[0].key_value_pairs.extend(arg.key_value_pairs) 4442 else: 4443 items.append(arg) 4444 elif isinstance(arg, ExprNodes.MergedDictNode): 4445 for child_arg in arg.keyword_args: 4446 add(child_arg) 4447 else: 4448 if items: 4449 args.append(items[0]) 4450 del items[:] 4451 args.append(arg) 4452 4453 for arg in node.keyword_args: 4454 add(arg) 4455 if items: 4456 args.append(items[0]) 4457 4458 if len(args) == 1: 4459 arg = args[0] 4460 if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode): 4461 return arg 4462 node.keyword_args[:] = args 4463 self._calculate_const(node) 4464 return node 4465 4466 def visit_MergedSequenceNode(self, node): 4467 """Unpack *args in place if we can.""" 4468 self.visitchildren(node) 4469 4470 is_set = node.type is Builtin.set_type 4471 args = [] 
def visit_SequenceNode(self, node):
    """Unpack *args in place if we can.

    A starred argument whose target is a sequence constructor without a
    multiplication factor is replaced by the items of that constructor.
    All other arguments are kept as they are.  'self._calculate_const()'
    is called on the node after its argument list was rewritten.
    """
    self.visitchildren(node)
    args = []
    for arg in node.args:
        if not arg.is_starred:
            args.append(arg)
        elif arg.target.is_sequence_constructor and not arg.target.mult_factor:
            # starred sequence literal => splice its items into the argument list
            args.extend(arg.target.args)
        else:
            # starred, but not a plain literal => must stay starred
            args.append(arg)
    # slice assignment: the list object held in 'node.args' is updated in place
    node.args[:] = args
    self._calculate_const(node)
    return node


def visit_PrimaryCmpNode(self, node):
    """Fold the constant parts of a (possibly cascaded) comparison.

    Each comparison in the cascade gets its constant result calculated
    where both of its operands are constant.  A comparison without a
    cascade is replaced by a bool node if it is constant.  A cascade is
    split at its constant comparisons: a constant True link is dropped,
    a constant False link ends the cascade.  The remaining non-constant
    partial cascades are rebuilt and joined with 'and'.
    """
    # calculate constant partial results in the comparison cascade
    self.visitchildren(node, ['operand1'])
    left_node = node.operand1
    cmp_node = node
    while cmp_node is not None:
        self.visitchildren(cmp_node, ['operand2'])
        right_node = cmp_node.operand2
        # reset first, so that a failed calculation below leaves "not a constant"
        cmp_node.constant_result = not_a_constant
        if left_node.has_constant_result() and right_node.has_constant_result():
            try:
                cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
            except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                pass  # ignore all 'normal' errors here => no constant result
        # the right operand of this comparison is the left operand of the next one
        left_node = right_node
        cmp_node = cmp_node.cascade

    if not node.cascade:
        # simple, non-cascaded comparison
        if node.has_constant_result():
            return self._bool_node(node, node.constant_result)
        return node

    # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
    cascades = [[node.operand1]]
    # one-element list at most; a list so that the nested function can fill it
    final_false_result = []

    def split_cascades(cmp_node):
        # Walks the cascade recursively and distributes its comparison
        # nodes over 'cascades', stopping at the first constant False.
        if cmp_node.has_constant_result():
            if not cmp_node.constant_result:
                # False => short-circuit
                final_false_result.append(self._bool_node(cmp_node, False))
                return
            else:
                # True => discard and start new cascade
                cascades.append([cmp_node.operand2])
        else:
            # not constant => append to current cascade
            cascades[-1].append(cmp_node)
        if cmp_node.cascade:
            split_cascades(cmp_node.cascade)

    split_cascades(node)

    cmp_nodes = []
    for cascade in cascades:
        if len(cascade) < 2:
            # only a start value, no comparison left in this partial cascade
            continue
        cmp_node = cascade[1]
        # the first comparison becomes the head of a new, independent comparison
        pcmp_node = ExprNodes.PrimaryCmpNode(
            cmp_node.pos,
            operand1=cascade[0],
            operator=cmp_node.operator,
            operand2=cmp_node.operand2,
            constant_result=not_a_constant)
        cmp_nodes.append(pcmp_node)

        # re-link the remaining comparisons behind the new head
        last_cmp_node = pcmp_node
        for cmp_node in cascade[2:]:
            last_cmp_node.cascade = cmp_node
            last_cmp_node = cmp_node
        # cut off whatever followed the last comparison in the original cascade
        last_cmp_node.cascade = None

    if final_false_result:
        # last cascade was constant False
        cmp_nodes.append(final_false_result[0])
    elif not cmp_nodes:
        # only constants, but no False result
        return self._bool_node(node, True)
    node = cmp_nodes[0]
    if len(cmp_nodes) == 1:
        if node.has_constant_result():
            return self._bool_node(node, node.constant_result)
    else:
        # chain the partial results left to right: ((a and b) and c) ...
        for cmp_node in cmp_nodes[1:]:
            node = ExprNodes.BoolBinopNode(
                node.pos,
                operand1=node,
                operator='and',
                operand2=cmp_node,
                constant_result=not_a_constant)
    return node
def visit_IfStatNode(self, node):
    """
    Remove 'if' clauses whose condition is a known constant.

    Clauses with a constant false condition are dropped.  The first clause
    with a constant true condition turns into the 'else' clause and ends
    the scan.  If no clause with a runtime condition remains, the statement
    is replaced by its 'else' clause or by an empty statement list.
    """
    self.visitchildren(node)
    undecided = []
    for clause in node.if_clauses:
        test = clause.condition
        if not test.has_constant_result():
            # unknown result => normal runtime evaluation
            undecided.append(clause)
        elif test.constant_result:
            # always true => nothing behind this clause can ever run
            node.else_clause = clause.body
            break
        # constant false => the clause silently disappears

    if undecided:
        node.if_clauses = undecided
        return node
    if node.else_clause:
        return node.else_clause
    return Nodes.StatListNode(node.pos, stats=[])


def visit_SliceIndexNode(self, node):
    """
    Normalise the slice bounds and apply constant slices at compile time.

    A start/stop node whose constant result is None is replaced by None.
    If the slice node itself has a constant result, a sequence constructor
    without multiplication factor is cut down in place and a string literal
    is replaced by its sliced counterpart (if it provides one).
    """
    self._calculate_const(node)

    start = node.start
    if start is not None:
        start = start.constant_result
    if start is None:
        node.start = None

    stop = node.stop
    if stop is not None:
        stop = stop.constant_result
    if stop is None:
        node.stop = None

    if node.constant_result is not_a_constant:
        return node

    # cut down sliced constant sequences
    target = node.base
    if target.is_sequence_constructor and target.mult_factor is None:
        target.args = target.args[start:stop]
        return target
    if target.is_string_literal:
        sliced = target.as_sliced_node(start, stop)
        if sliced is not None:
            return sliced
    return node


def visit_ComprehensionNode(self, node):
    """
    Turn a list/set/dict comprehension whose loop was pruned down to an
    empty statement list into the corresponding empty literal.
    """
    self.visitchildren(node)
    loop = node.loop
    if not isinstance(loop, Nodes.StatListNode) or loop.stats:
        return node

    # loop was pruned already => transform into literal
    result_type = node.type
    if result_type is Builtin.list_type:
        return ExprNodes.ListNode(
            node.pos, args=[], constant_result=[])
    if result_type is Builtin.set_type:
        return ExprNodes.SetNode(
            node.pos, args=[], constant_result=set())
    if result_type is Builtin.dict_type:
        return ExprNodes.DictNode(
            node.pos, key_value_pairs=[], constant_result={})
    return node


def visit_ForInStatNode(self, node):
    """
    Simplify loops over literal sequences.

    A loop over an empty sequence literal is replaced by its 'else' clause,
    or by an empty statement list if it has none.  A non-empty list literal
    is iterated as a tuple instead.
    """
    self.visitchildren(node)
    iterable = node.iterator.sequence
    if not isinstance(iterable, ExprNodes.SequenceNode):
        return node
    if iterable.args:
        # iterating over a list literal? => tuples are more efficient
        if isinstance(iterable, ExprNodes.ListNode):
            node.iterator.sequence = iterable.as_tuple()
        return node
    # an empty StatListNode (rather than nothing) avoids breaking list comprehensions
    return node.else_clause or Nodes.StatListNode(node.pos, stats=[])


def visit_WhileStatNode(self, node):
    """
    Simplify 'while' loops with a constant condition.

    A constant true condition is removed together with the 'else' clause.
    A constant false condition replaces the loop by its 'else' clause
    (which may be None).
    """
    self.visitchildren(node)
    test = node.condition
    if not test or not test.has_constant_result():
        return node
    if not test.constant_result:
        return node.else_clause
    node.condition = None
    node.else_clause = None
    return node


def visit_ExprStatNode(self, node):
    """
    Drop expression statements that consist of an unused constant.
    """
    self.visitchildren(node)
    expr = node.expr
    # ParallelRangeTransform may put something other than an ExprNode here
    if isinstance(expr, ExprNodes.ExprNode) and expr.has_constant_result():
        return None
    return node
def visit_SingleAssignmentNode(self, node):
    """Avoid redundant initialisation of local variables before their
    first assignment.

    For an assignment that is flagged as 'first', the left-hand side gets
    its 'lhs_of_first_assignment' flag set.  The node itself is returned
    unchanged.
    """
    self.visitchildren(node)
    if node.first:
        lhs = node.lhs
        lhs.lhs_of_first_assignment = True
    return node


def visit_SimpleCallNode(self, node):
    """
    Replace generic calls to isinstance(x, type) by a more efficient type check.
    Replace likely Python method calls by a specialised PyMethodCallNode.

    The second replacement is only tried if the corresponding
    'optimize.unpack_method_calls*' directive is enabled and the call
    target cannot be ruled out as a method (see the checks below).
    """
    self.visitchildren(node)
    function = node.function
    if function.type.is_cfunction and function.is_name:
        if function.name == 'isinstance' and len(node.args) == 2:
            type_arg = node.args[1]
            if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                # second argument is typed as 'type' => call PyObject_TypeCheck()
                # instead and pass the type argument as a 'PyTypeObject*'
                cython_scope = self.context.cython_scope
                function.entry = cython_scope.lookup('PyObject_TypeCheck')
                function.type = function.entry.type
                PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
                node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
    # The directive that is looked up depends on where the call lives:
    # module scope code outside of loops uses the "..._in_pyinit" variant.
    elif (node.is_temp and function.type.is_pyobject and self.current_directives.get(
            "optimize.unpack_method_calls_in_pyinit"
            if not self.in_loop and self.current_env().is_module_scope
            else "optimize.unpack_method_calls")):
        # optimise simple Python methods calls
        if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not (
                node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and len(node.arg_tuple.args) > 1)):
            # simple call, now exclude calls to objects that are definitely not methods
            may_be_a_method = True
            if function.type is Builtin.type_type:
                may_be_a_method = False
            elif function.is_attribute:
                if function.entry and function.entry.type.is_cfunction:
                    # optimised builtin method
                    may_be_a_method = False
            elif function.is_name:
                entry = function.entry
                if entry.is_builtin or entry.type.is_cfunction:
                    may_be_a_method = False
                elif entry.cf_assignments:
                    # local functions/classes are definitely not methods
                    non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode)
                    # still a candidate if any assignment has an rhs that is
                    # not a function/class definition node
                    may_be_a_method = any(
                        assignment.rhs and not isinstance(assignment.rhs, non_method_nodes)
                        for assignment in entry.cf_assignments)
            if may_be_a_method:
                if (node.self and function.is_attribute and
                        isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self):
                    # function self object was moved into a CloneNode => undo
                    function.obj = function.obj.arg
                node = self.replace(node, ExprNodes.PyMethodCallNode.from_node(
                    node, function=function, arg_tuple=node.arg_tuple, type=node.type))
    return node


def visit_NumPyMethodCallNode(self, node):
    """Visit the children only and keep the node itself as it is."""
    # Exclude from replacement above.
    self.visitchildren(node)
    return node


def visit_PyTypeTestNode(self, node):
    """Remove tests for alternatively allowed None values from
    type tests when we know that the argument cannot be None
    anyway.

    This is done by setting the 'notnone' flag of the node; the node
    itself stays in the tree.
    """
    self.visitchildren(node)
    if not node.notnone:
        if not node.arg.may_be_none():
            node.notnone = True
    return node
def visit_LoopNode(self, node):
    """Flag the visitor as being inside a loop while the loop's children are
    visited, as some expensive optimisations might still be worth it there.
    """
    enclosing_state = self.in_loop
    self.in_loop = True
    self.visitchildren(node)
    self.in_loop = enclosing_state
    return node


class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    Share one overflow check among all nodes of a nested arithmetic
    expression.  For example, given the expression a*b + c, where a, b,
    and c are all possibly overflowing ints, the entire sequence will be
    evaluated and the overflow bit checked only at the end.
    """
    # outermost folding NumBinopNode of the expression currently being
    # visited, or None outside of such an expression
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit the children of a non-arithmetic node with a cleared
        'overflow_bit_node', restoring the previous value afterwards.
        """
        outer = self.overflow_bit_node
        if outer is None:
            self.visitchildren(node)
            return node
        self.overflow_bit_node = None
        self.visitchildren(node)
        self.overflow_bit_node = outer
        return node

    def visit_NumBinopNode(self, node):
        """Attach a folding binary operation to the overflow check of its
        enclosing expression, or make it the owner of the check if it is
        the outermost one.
        """
        if not (node.overflow_check and node.overflow_fold):
            self.visitchildren(node)
            return node

        shared = self.overflow_bit_node
        if shared is None:
            # outermost operation: its children report to this node
            self.overflow_bit_node = node
            self.visitchildren(node)
            self.overflow_bit_node = None
        else:
            # nested operation: delegate the check to the outermost node
            node.overflow_bit_node = shared
            node.overflow_check = False
            self.visitchildren(node)
        return node