from __future__ import absolute_import

import copy

from . import (ExprNodes, PyrexTypes, MemoryView,
               ParseTreeTransforms, StringEncoding, Errors)
from .ExprNodes import CloneNode, ProxyNode, TupleNode
from .Nodes import FuncDefNode, CFuncDefNode, StatListNode, DefNode
from ..Utils import OrderedSet


class FusedCFuncDefNode(StatListNode):
    """
    This node replaces a function with fused arguments. It deep-copies the
    function for every permutation of fused types, and allocates a new local
    scope for it. It keeps track of the original function in self.node, and
    the entry of the original function in the symbol table is given the
    'fused_cfunction' attribute which points back to us.
    Then when a function lookup occurs (to e.g. call it), the call can be
    dispatched to the right function.

    node    FuncDefNode    the original function
    nodes   [FuncDefNode]  list of copies of node with different specific types
    py_func DefNode        the fused python function subscriptable from
                           Python space
    __signatures__         A DictNode mapping signature specialization strings
                           to PyCFunction nodes
    resulting_fused_function  PyCFunction for the fused DefNode that delegates
                              to specializations
    fused_func_assignment   Assignment of the fused function to the function name
    defaults_tuple          TupleNode of defaults (letting PyCFunctionNode build
                            defaults would result in many different tuples)
    specialized_pycfuncs    List of synthesized pycfunction nodes for the
                            specializations
    code_object             CodeObjectNode shared by all specializations and the
                            fused function

    fused_compound_types    All fused (compound) types (e.g. floating[:])
    """

    __signatures__ = None
    resulting_fused_function = None
    fused_func_assignment = None
    defaults_tuple = None
    decorators = None

    child_attrs = StatListNode.child_attrs + [
        '__signatures__', 'resulting_fused_function', 'fused_func_assignment']

    def __init__(self, node, env):
        """
        Specialize `node` (a DefNode or CFuncDefNode with fused arguments)
        for every permutation of its fused types, declaring the copies in
        `env`.
        """
        super(FusedCFuncDefNode, self).__init__(node.pos)

        self.nodes = []
        self.node = node

        is_def = isinstance(self.node, DefNode)
        if is_def:
            # self.node.decorators = []
            self.copy_def(env)
        else:
            self.copy_cdef(env)

        # Perform some sanity checks. If anything fails, it's a bug
        for n in self.nodes:
            assert not n.entry.type.is_fused
            assert not n.local_scope.return_type.is_fused
            if node.return_type.is_fused:
                assert not n.return_type.is_fused

            if not is_def and n.cfunc_declarator.optional_arg_count:
                assert n.type.op_arg_struct

        node.entry.fused_cfunction = self
        # Copy the nodes as AnalyseDeclarationsTransform will prepend
        # self.py_func to self.stats, as we only want specialized
        # CFuncDefNodes in self.nodes
        self.stats = self.nodes[:]

    def copy_def(self, env):
        """
        Create a copy of the original def or lambda function for specialized
        versions.
        """
        fused_compound_types = PyrexTypes.unique(
            [arg.type for arg in self.node.args if arg.type.is_fused])
        fused_types = self._get_fused_base_types(fused_compound_types)
        permutations = PyrexTypes.get_all_specialized_permutations(fused_types)

        self.fused_compound_types = fused_compound_types

        if self.node.entry in env.pyfunc_entries:
            env.pyfunc_entries.remove(self.node.entry)

        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)
            # keep signature object identity for special casing in DefNode.analyse_declarations()
            copied_node.entry.signature = self.node.entry.signature

            self._specialize_function_args(copied_node.args, fused_to_specific)
            copied_node.return_type = self.node.return_type.specialize(
                fused_to_specific)

            copied_node.analyse_declarations(env)
            # copied_node.is_staticmethod = self.node.is_staticmethod
            # copied_node.is_classmethod = self.node.is_classmethod
            self.create_new_local_scope(copied_node, env, fused_to_specific)
            self.specialize_copied_def(copied_node, cname, self.node.entry,
                                       fused_to_specific, fused_compound_types)

            PyrexTypes.specialize_entry(copied_node.entry, cname)
            copied_node.entry.used = True
            env.entries[copied_node.entry.name] = copied_node.entry

            # Stop specializing once errors show up to avoid an error flood.
            if not self.replace_fused_typechecks(copied_node):
                break

        self.orig_py_func = self.node
        self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)

    def copy_cdef(self, env):
        """
        Create a copy of the original c(p)def function for all specialized
        versions.
        """
        permutations = self.node.type.get_all_specialized_permutations()
        # print 'Node %s has %d specializations:' % (self.node.entry.name,
        #                                            len(permutations))
        # import pprint; pprint.pprint([d for cname, d in permutations])

        # Prevent copying of the python function
        self.orig_py_func = orig_py_func = self.node.py_func
        self.node.py_func = None
        if orig_py_func:
            env.pyfunc_entries.remove(orig_py_func.entry)

        fused_types = self.node.type.get_fused_types()
        self.fused_compound_types = fused_types

        new_cfunc_entries = []
        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)

            # Make the types in our CFuncType specific.
            type = copied_node.type.specialize(fused_to_specific)
            entry = copied_node.entry
            type.specialize_entry(entry, cname)

            # Reuse existing Entries (e.g. from .pxd files).
            for i, orig_entry in enumerate(env.cfunc_entries):
                if entry.cname == orig_entry.cname and type.same_as_resolved_type(orig_entry.type):
                    copied_node.entry = env.cfunc_entries[i]
                    if not copied_node.entry.func_cname:
                        copied_node.entry.func_cname = entry.func_cname
                    entry = copied_node.entry
                    type = entry.type
                    break
            else:
                # No existing entry matched: this specialization is new.
                new_cfunc_entries.append(entry)

            copied_node.type = type
            entry.type, type.entry = type, entry

            entry.used = (entry.used or
                          self.node.entry.defined_in_pxd or
                          env.is_c_class_scope or
                          entry.is_cmethod)

            if self.node.cfunc_declarator.optional_arg_count:
                self.node.cfunc_declarator.declare_optional_arg_struct(
                    type, env, fused_cname=cname)

            copied_node.return_type = type.return_type
            self.create_new_local_scope(copied_node, env, fused_to_specific)

            # Make the argument types in the CFuncDeclarator specific
            self._specialize_function_args(copied_node.cfunc_declarator.args,
                                           fused_to_specific)

            # If a cpdef, declare all specialized cpdefs (this
            # also calls analyse_declarations)
            copied_node.declare_cpdef_wrapper(env)
            if copied_node.py_func:
                env.pyfunc_entries.remove(copied_node.py_func.entry)

                self.specialize_copied_def(
                    copied_node.py_func, cname, self.node.entry.as_variable,
                    fused_to_specific, fused_types)

            # Stop specializing once errors show up to avoid an error flood.
            if not self.replace_fused_typechecks(copied_node):
                break

        # replace old entry with new entries
        try:
            cindex = env.cfunc_entries.index(self.node.entry)
        except ValueError:
            env.cfunc_entries.extend(new_cfunc_entries)
        else:
            env.cfunc_entries[cindex:cindex+1] = new_cfunc_entries

        if orig_py_func:
            self.py_func = self.make_fused_cpdef(orig_py_func, env,
                                                 is_def=False)
        else:
            self.py_func = orig_py_func

    def _get_fused_base_types(self, fused_compound_types):
        """
        Get a list of unique basic fused types, from a list of
        (possibly) compound fused types.
        """
        base_types = []
        seen = set()
        for fused_type in fused_compound_types:
            fused_type.get_fused_types(result=base_types, seen=seen)
        return base_types

    def _specialize_function_args(self, args, fused_to_specific):
        """
        Replace the fused type of each argument in `args` (in place) with its
        specialization from `fused_to_specific`, validating the dtype of any
        resulting memoryview slice type.
        """
        for arg in args:
            if arg.type.is_fused:
                arg.type = arg.type.specialize(fused_to_specific)
                if arg.type.is_memoryviewslice:
                    arg.type.validate_memslice_dtype(arg.pos)

    def create_new_local_scope(self, node, env, f2s):
        """
        Create a new local scope for the copied node and append it to
        self.nodes. A new local scope is needed because the arguments with the
        fused types are already in the local scope, and we need the specialized
        entries created after analyse_declarations on each specialized version
        of the (CFunc)DefNode.
        f2s is a dict mapping each fused type to its specialized version
        """
        node.create_local_scope(env)
        node.local_scope.fused_to_specific = f2s

        # This is copied from the original function, set it to false to
        # stop recursion
        node.has_fused_arguments = False
        self.nodes.append(node)

    def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types):
        """Specialize the copy of a DefNode given the copied node,
        the specialization cname and the original DefNode entry.

        Sets node.specialized_signature_string to the '|'-joined
        specialization strings of the base fused types, mangles the
        pymethdef cname with `cname`, and copies the docstring from
        `py_entry`.
        """
        fused_types = self._get_fused_base_types(fused_compound_types)
        type_strings = [
            PyrexTypes.specialization_signature_string(fused_type, f2s)
            for fused_type in fused_types
        ]

        node.specialized_signature_string = '|'.join(type_strings)

        node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
            cname, node.entry.pymethdef_cname)
        node.entry.doc = py_entry.doc
        node.entry.doc_cname = py_entry.doc_cname

    def replace_fused_typechecks(self, copied_node):
        """
        Branch-prune fused type checks like

            if fused_t is int:
                ...

        Returns False if new errors were reported while doing so (callers
        then stop creating further specializations to prevent a flood of
        errors), and True otherwise.
        """
        num_errors = Errors.num_errors
        transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
            copied_node.local_scope)
        transform(copied_node)

        if Errors.num_errors > num_errors:
            return False

        return True

    def _fused_instance_checks(self, normal_types, pyx_code, env):
        """
        Generate Cython code for instance checks, matching an object to
        specialized types.
        """
        for specialized_type in normal_types:
            # all_numeric = all_numeric and specialized_type.is_numeric
            pyx_code.context.update(
                py_type_name=specialized_type.py_type_name(),
                specialized_type_name=specialized_type.specialization_string,
            )
            pyx_code.put_chunk(
                u"""
                    if isinstance(arg, {{py_type_name}}):
                        dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'; break
                """)

    def _dtype_name(self, dtype):
        """
        Name usable as an identifier for `dtype` in generated code: a
        '___pyx_'-prefixed alias for typedefs, otherwise the type name with
        spaces replaced by underscores.
        """
        if dtype.is_typedef:
            return '___pyx_%s' % dtype
        return str(dtype).replace(' ', '_')

    def _dtype_type(self, dtype):
        """
        Type spelling for `dtype` in generated code (the alias from
        _dtype_name() for typedefs).
        """
        if dtype.is_typedef:
            return self._dtype_name(dtype)
        return str(dtype)

    def _sizeof_dtype(self, dtype):
        """Generated-code expression for the item size of `dtype`."""
        if dtype.is_pyobject:
            return 'sizeof(void *)'
        else:
            return "sizeof(%s)" % self._dtype_type(dtype)

    def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
        """
        Set up branches on the numpy dtype kind ('i'/'u', 'f', 'c', 'O'),
        each exposing a named insertion point (dtype_int, dtype_float,
        dtype_complex, dtype_object) where specialization checks are added.
        """
        if pyx_code.indenter("if kind in b'iu':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_int")
            pyx_code.dedent()

        if pyx_code.indenter("elif kind == b'f':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_float")
            pyx_code.dedent()

        if pyx_code.indenter("elif kind == b'c':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_complex")
            pyx_code.dedent()

        if pyx_code.indenter("elif kind == b'O':"):
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_object")
            pyx_code.dedent()

    # Tempita templates for the generated line that records the outcome for
    # the current fused argument in dest_sig.
    match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
    no_match = "dest_sig[{{dest_sig_idx}}] = None"

    def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types, pythran_types):
        """
        Match a numpy dtype object to the individual specializations.
        """
        self._buffer_check_numpy_dtype_setup_cases(pyx_code)

        for specialized_type in pythran_types+specialized_buffer_types:
            final_type = specialized_type
            if specialized_type.is_pythran_expr:
                specialized_type = specialized_type.org_buffer
            dtype = specialized_type.dtype
            pyx_code.context.update(
                itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
                signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
                dtype=dtype,
                specialized_type_name=final_type.specialization_string)

            dtypes = [
                (dtype.is_int, pyx_code.dtype_int),
                (dtype.is_float, pyx_code.dtype_float),
                (dtype.is_complex, pyx_code.dtype_complex)
            ]

            for dtype_category, codewriter in dtypes:
                if dtype_category:
                    cond = '{{itemsize_match}} and (<Py_ssize_t>arg.ndim) == %d' % (
                        specialized_type.ndim,)
                    if dtype.is_int:
                        cond += ' and {{signed_match}}'

                    if final_type.is_pythran_expr:
                        cond += ' and arg_is_pythran_compatible'

                    if codewriter.indenter("if %s:" % cond):
                        #codewriter.putln("print 'buffer match found based on numpy dtype'")
                        codewriter.putln(self.match)
                        codewriter.putln("break")
                        codewriter.dedent()

    def _buffer_parse_format_string_check(self, pyx_code, decl_code,
                                          specialized_type, env):
        """
        Generate code that tries to coerce the object to a memoryview slice
        of `specialized_type`. This means obtaining a buffer and parsing the
        format string; a failed coercion clears the Python error.
        TODO: separate buffer acquisition from format parsing
        """
        dtype = specialized_type.dtype
        if specialized_type.is_buffer:
            axes = [('direct', 'strided')] * specialized_type.ndim
        else:
            axes = specialized_type.axes

        memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
        memslice_type.create_from_py_utility_code(env)
        pyx_code.context.update(
            coerce_from_py_func=memslice_type.from_py_function,
            dtype=dtype)
        decl_code.putln(
            "{{memviewslice_cname}} {{coerce_from_py_func}}(object, int)")

        pyx_code.context.update(
            specialized_type_name=specialized_type.specialization_string,
            sizeof_dtype=self._sizeof_dtype(dtype))

        pyx_code.put_chunk(
            u"""
                # try {{dtype}}
                if itemsize == -1 or itemsize == {{sizeof_dtype}}:
                    memslice = {{coerce_from_py_func}}(arg, 0)
                    if memslice.memview:
                        __PYX_XDEC_MEMVIEW(&memslice, 1)
                        # print 'found a match for the buffer through format parsing'
                        %s
                        break
                    else:
                        __pyx_PyErr_Clear()
            """ % self.match)

    def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, env):
        """
        Generate Cython code to match objects to buffer specializations.
        First try to get a numpy dtype object and match it against the individual
        specializations. If that fails, try naively to coerce the object
        to each specialization, which obtains the buffer each time and tries
        to match the format string.
        """
        # The first thing to find a match in this loop breaks out of the loop
        pyx_code.put_chunk(
            u"""
                """ + (u"arg_is_pythran_compatible = False" if pythran_types else u"") + u"""
                if ndarray is not None:
                    if isinstance(arg, ndarray):
                        dtype = arg.dtype
                        """ + (u"arg_is_pythran_compatible = True" if pythran_types else u"") + u"""
                    elif __pyx_memoryview_check(arg):
                        arg_base = arg.base
                        if isinstance(arg_base, ndarray):
                            dtype = arg_base.dtype
                        else:
                            dtype = None
                    else:
                        dtype = None

                    itemsize = -1
                    if dtype is not None:
                        itemsize = dtype.itemsize
                        kind = ord(dtype.kind)
                        dtype_signed = kind == 'i'
            """)
        # Following code goes inside "if dtype is not None:" above.
        pyx_code.indent(2)
        if pythran_types:
            pyx_code.put_chunk(
                u"""
                    # Pythran only supports the endianness of the current compiler
                    byteorder = dtype.byteorder
                    if byteorder == "<" and not __Pyx_Is_Little_Endian():
                        arg_is_pythran_compatible = False
                    elif byteorder == ">" and __Pyx_Is_Little_Endian():
                        arg_is_pythran_compatible = False
                    if arg_is_pythran_compatible:
                        cur_stride = itemsize
                        shape = arg.shape
                        strides = arg.strides
                        for i in range(arg.ndim-1, -1, -1):
                            if (<Py_ssize_t>strides[i]) != cur_stride:
                                arg_is_pythran_compatible = False
                                break
                            cur_stride *= <Py_ssize_t> shape[i]
                        else:
                            arg_is_pythran_compatible = not (arg.flags.f_contiguous and (<Py_ssize_t>arg.ndim) > 1)
                """)
        pyx_code.named_insertion_point("numpy_dtype_checks")
        self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types)
        pyx_code.dedent(2)

        for specialized_type in buffer_types:
            self._buffer_parse_format_string_check(
                pyx_code, decl_code, specialized_type, env)

    def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types):
        """
        If we have any buffer specializations, write out some variable
        declarations and imports.

        `pythran_types` is only tested for truthiness: when non-empty, the
        locals used by the Pythran compatibility check are declared.
        """
        decl_code.put_chunk(
            u"""
                ctypedef struct {{memviewslice_cname}}:
                    void *memview

                void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
                bint __pyx_memoryview_check(object)
            """)

        pyx_code.local_variable_declarations.put_chunk(
            u"""
                cdef {{memviewslice_cname}} memslice
                cdef Py_ssize_t itemsize
                cdef bint dtype_signed
                cdef char kind

                itemsize = -1
            """)

        if pythran_types:
            pyx_code.local_variable_declarations.put_chunk(u"""
                cdef bint arg_is_pythran_compatible
                cdef Py_ssize_t cur_stride
            """)

        pyx_code.imports.put_chunk(
            u"""
                cdef type ndarray
                ndarray = __Pyx_ImportNumPyArrayTypeIfAvailable()
            """)

        seen_typedefs = set()
        seen_int_dtypes = set()
        for buffer_type in all_buffer_types:
            dtype = buffer_type.dtype
            dtype_name = self._dtype_name(dtype)
            if dtype.is_typedef:
                if dtype_name not in seen_typedefs:
                    seen_typedefs.add(dtype_name)
                    decl_code.putln(
                        'ctypedef %s %s "%s"' % (dtype.resolve(), dtype_name,
                                                 dtype.empty_declaration_code()))

            if buffer_type.dtype.is_int:
                if str(dtype) not in seen_int_dtypes:
                    seen_int_dtypes.add(str(dtype))
                    pyx_code.context.update(dtype_name=dtype_name,
                                            dtype_type=self._dtype_type(dtype))
                    pyx_code.local_variable_declarations.put_chunk(
                        u"""
                            cdef bint {{dtype_name}}_is_signed
                            {{dtype_name}}_is_signed = not (<{{dtype_type}}> -1 > 0)
                        """)

    def _split_fused_types(self, arg):
        """
        Specialize the fused type of `arg` and split the specializations.

        Returns (normal_types, buffer_types, pythran_types,
        has_object_fallback). normal_types holds types with a distinct
        Python type name (first one wins after sorting), excluding 'object',
        which only sets has_object_fallback.
        """
        specialized_types = PyrexTypes.get_specialized_types(arg.type)

        # Prefer long over int, etc by sorting (see type classes in PyrexTypes.py)
        specialized_types.sort()

        seen_py_type_names = set()
        normal_types, buffer_types, pythran_types = [], [], []
        has_object_fallback = False
        for specialized_type in specialized_types:
            py_type_name = specialized_type.py_type_name()
            if py_type_name:
                if py_type_name in seen_py_type_names:
                    continue
                seen_py_type_names.add(py_type_name)
                if py_type_name == 'object':
                    has_object_fallback = True
                else:
                    normal_types.append(specialized_type)
            elif specialized_type.is_pythran_expr:
                pythran_types.append(specialized_type)
            elif specialized_type.is_buffer or specialized_type.is_memoryviewslice:
                buffer_types.append(specialized_type)

        return normal_types, buffer_types, pythran_types, has_object_fallback

    def _unpack_argument(self, pyx_code):
        """
        Generate code that binds `arg` to the current fused argument, taken
        from the positional args, then kwargs, then the defaults tuple;
        otherwise a TypeError is raised.
        """
        # In the {{else}} branch below arg.default is falsy, so the missing
        # keyword-only argument must be reported by its *name*.
        pyx_code.put_chunk(
            u"""
                # PROCESSING ARGUMENT {{arg_tuple_idx}}
                if {{arg_tuple_idx}} < len(<tuple>args):
                    arg = (<tuple>args)[{{arg_tuple_idx}}]
                elif kwargs is not None and '{{arg.name}}' in <dict>kwargs:
                    arg = (<dict>kwargs)['{{arg.name}}']
                else:
                {{if arg.default}}
                    arg = (<tuple>defaults)[{{default_idx}}]
                {{else}}
                    {{if arg_tuple_idx < min_positional_args}}
                        raise TypeError("Expected at least %d argument%s, got %d" % (
                            {{min_positional_args}}, {{'"s"' if min_positional_args != 1 else '""'}}, len(<tuple>args)))
                    {{else}}
                        raise TypeError("Missing keyword-only argument: '%s'" % "{{arg.name}}")
                    {{endif}}
                {{endif}}
            """)

    def make_fused_cpdef(self, orig_py_func, env, is_def):
        """
        This creates the function that is indexable from Python and does
        runtime dispatch based on the argument types. The function gets the
        signatures dict, the arg tuple, the kwargs dict (or None) and the
        defaults tuple as arguments from the Binding Fused Function's tp_call.

        Returns the DefNode of the generated `__pyx_fused_cpdef` function.
        """
        from . import TreeFragment, Code, UtilityCode

        fused_types = self._get_fused_base_types([
            arg.type for arg in self.node.args if arg.type.is_fused])

        context = {
            'memviewslice_cname': MemoryView.memviewslice_cname,
            'func_args': self.node.args,
            'n_fused': len(fused_types),
            'min_positional_args':
                self.node.num_required_args - self.node.num_required_kw_args
                if is_def else
                sum(1 for arg in self.node.args if arg.default is None),
            'name': orig_py_func.entry.name,
        }

        pyx_code = Code.PyxCodeWriter(context=context)
        decl_code = Code.PyxCodeWriter(context=context)
        decl_code.put_chunk(
            u"""
                cdef extern from *:
                    void __pyx_PyErr_Clear "PyErr_Clear" ()
                    type __Pyx_ImportNumPyArrayTypeIfAvailable()
                    int __Pyx_Is_Little_Endian()
            """)
        decl_code.indent()

        pyx_code.put_chunk(
            u"""
                def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
                    # FIXME: use a typed signature - currently fails badly because
                    #        default arguments inherit the types we specify here!

                    dest_sig = [None] * {{n_fused}}

                    if kwargs is not None and not kwargs:
                        kwargs = None

                    cdef Py_ssize_t i

                    # instance check body
            """)

        pyx_code.indent()  # indent following code to function body
        pyx_code.named_insertion_point("imports")
        pyx_code.named_insertion_point("func_defs")
        pyx_code.named_insertion_point("local_variable_declarations")

        fused_index = 0
        default_idx = 0
        all_buffer_types = OrderedSet()
        # Pythran specializations of *all* fused arguments; the declarations
        # below must not depend on the last argument alone.
        all_pythran_types = []
        seen_fused_types = set()
        for i, arg in enumerate(self.node.args):
            if arg.type.is_fused:
                arg_fused_types = arg.type.get_fused_types()
                if len(arg_fused_types) > 1:
                    raise NotImplementedError("Determination of more than one fused base "
                                              "type per argument is not implemented.")
                fused_type = arg_fused_types[0]

            if arg.type.is_fused and fused_type not in seen_fused_types:
                seen_fused_types.add(fused_type)

                context.update(
                    arg_tuple_idx=i,
                    arg=arg,
                    dest_sig_idx=fused_index,
                    default_idx=default_idx,
                )

                normal_types, buffer_types, pythran_types, has_object_fallback = self._split_fused_types(arg)
                self._unpack_argument(pyx_code)

                # 'unrolled' loop, first match breaks out of it
                if pyx_code.indenter("while 1:"):
                    if normal_types:
                        self._fused_instance_checks(normal_types, pyx_code, env)
                    if buffer_types or pythran_types:
                        env.use_utility_code(Code.UtilityCode.load_cached("IsLittleEndian", "ModuleSetupCode.c"))
                        self._buffer_checks(buffer_types, pythran_types, pyx_code, decl_code, env)
                    if has_object_fallback:
                        pyx_code.context.update(specialized_type_name='object')
                        pyx_code.putln(self.match)
                    else:
                        pyx_code.putln(self.no_match)
                    pyx_code.putln("break")
                    pyx_code.dedent()

                fused_index += 1
                all_buffer_types.update(buffer_types)
                all_buffer_types.update(ty.org_buffer for ty in pythran_types)
                all_pythran_types.extend(pythran_types)

            if arg.default:
                default_idx += 1

        if all_buffer_types:
            self._buffer_declarations(pyx_code, decl_code, all_buffer_types, all_pythran_types)
            env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
            env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))

        pyx_code.put_chunk(
            u"""
                candidates = []
                for sig in <dict>signatures:
                    match_found = False
                    src_sig = sig.strip('()').split('|')
                    for i in range(len(dest_sig)):
                        dst_type = dest_sig[i]
                        if dst_type is not None:
                            if src_sig[i] == dst_type:
                                match_found = True
                            else:
                                match_found = False
                                break

                    if match_found:
                        candidates.append(sig)

                if not candidates:
                    raise TypeError("No matching signature found")
                elif len(candidates) > 1:
                    raise TypeError("Function call with ambiguous argument types")
                else:
                    return (<dict>signatures)[candidates[0]]
            """)

        fragment_code = pyx_code.getvalue()
        # print decl_code.getvalue()
        # print fragment_code
        from .Optimize import ConstantFolding
        fragment = TreeFragment.TreeFragment(
            fragment_code, level='module', pipeline=[ConstantFolding()])
        ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
        UtilityCode.declare_declarations_in_scope(
            decl_code.getvalue(), env.global_scope())
        ast.scope = env
        # FIXME: for static methods of cdef classes, we build the wrong signature here: first arg becomes 'self'
        ast.analyse_declarations(env)
        py_func = ast.stats[-1]  # the DefNode
        self.fragment_scope = ast.scope

        if isinstance(self.node, DefNode):
            py_func.specialized_cpdefs = self.nodes[:]
        else:
            py_func.specialized_cpdefs = [n.py_func for n in self.nodes]

        return py_func

    def update_fused_defnode_entry(self, env):
        """
        Let the generated dispatcher (self.py_func) take over the identity of
        the original Python function: copy the relevant entry attributes,
        register its entry in `env`, link every specialization back to it,
        and build the __signatures__ dict node (appended to self.stats).
        """
        copy_attributes = (
            'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
            'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
            'scope'
        )

        entry = self.py_func.entry

        for attr in copy_attributes:
            setattr(entry, attr,
                    getattr(self.orig_py_func.entry, attr))

        self.py_func.name = self.orig_py_func.name
        self.py_func.doc = self.orig_py_func.doc

        env.entries.pop('__pyx_fused_cpdef', None)
        if isinstance(self.node, DefNode):
            env.entries[entry.name] = entry
        else:
            env.entries[entry.name].as_variable = entry

        env.pyfunc_entries.append(entry)

        self.py_func.entry.fused_cfunction = self
        for node in self.nodes:
            if isinstance(self.node, DefNode):
                node.fused_py_func = self.py_func
            else:
                node.py_func.fused_py_func = self.py_func
                node.entry.as_variable = entry

        self.synthesize_defnodes()
        self.stats.append(self.__signatures__)

    def analyse_expressions(self, env):
        """
        Analyse the expressions. Take care to only evaluate default arguments
        once and clone the result for all specializations
        """
        # Complex specializations need their declaration utility code.
        for fused_compound_type in self.fused_compound_types:
            for fused_type in fused_compound_type.get_fused_types():
                for specialization_type in fused_type.types:
                    if specialization_type.is_complex:
                        specialization_type.create_declaration_utility_code(env)

        if self.py_func:
            self.__signatures__ = self.__signatures__.analyse_expressions(env)
            self.py_func = self.py_func.analyse_expressions(env)
            self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
            self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)

        self.defaults = defaults = []

        for arg in self.node.args:
            if arg.default:
                arg.default = arg.default.analyse_expressions(env)
                defaults.append(ProxyNode(arg.default))
            else:
                defaults.append(None)

        for i, stat in enumerate(self.stats):
            stat = self.stats[i] = stat.analyse_expressions(env)
            if isinstance(stat, FuncDefNode):
                for arg, default in zip(stat.args, defaults):
                    if default is not None:
                        arg.default = CloneNode(default).coerce_to(arg.type, env)

        if self.py_func:
            args = [CloneNode(default) for default in defaults if default]
            self.defaults_tuple = TupleNode(self.pos, args=args)
            self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True).coerce_to_pyobject(env)
            self.defaults_tuple = ProxyNode(self.defaults_tuple)
            self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)

            fused_func = self.resulting_fused_function.arg
            fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
            fused_func.code_object = CloneNode(self.code_object)

            for i, pycfunc in enumerate(self.specialized_pycfuncs):
                pycfunc.code_object = CloneNode(self.code_object)
                pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
                pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
        return self

    def synthesize_defnodes(self):
        """
        Create the __signatures__ dict of PyCFunctionNode specializations.
        """
        if isinstance(self.nodes[0], CFuncDefNode):
            nodes = [node.py_func for node in self.nodes]
        else:
            nodes = self.nodes

        signatures = [StringEncoding.EncodedString(node.specialized_signature_string)
                      for node in nodes]
        keys = [ExprNodes.StringNode(node.pos, value=sig)
                for node, sig in zip(nodes, signatures)]
        values = [ExprNodes.PyCFunctionNode.from_defnode(node, binding=True)
                  for node in nodes]

        self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos, zip(keys, values))

        self.specialized_pycfuncs = values
        for pycfuncnode in values:
            pycfuncnode.is_specialization = True

    def generate_function_definitions(self, env, code):
        """
        Emit the function definitions of the fused dispatcher (if any) and of
        every used specialization.
        """
        if self.py_func:
            self.py_func.pymethdef_required = True
            self.fused_func_assignment.generate_function_definitions(env, code)

        for stat in self.stats:
            if isinstance(stat, FuncDefNode) and stat.entry.used:
                code.mark_pos(stat.pos)
                stat.generate_function_definitions(env, code)

    def generate_execution_code(self, code):
        """
        Evaluate the shared default values once, run all statements, then
        attach the __signatures__ dict to the fused function object and
        dispose of the temporary results.
        """
        # Note: all def function specializations are wrapped in PyCFunction
        # nodes in the self.__signatures__ dictnode.
        for default in self.defaults:
            if default is not None:
                default.generate_evaluation_code(code)

        if self.py_func:
            self.defaults_tuple.generate_evaluation_code(code)
            self.code_object.generate_evaluation_code(code)

        for stat in self.stats:
            code.mark_pos(stat.pos)
            if isinstance(stat, ExprNodes.ExprNode):
                stat.generate_evaluation_code(code)
            else:
                stat.generate_execution_code(code)

        if self.__signatures__:
            self.resulting_fused_function.generate_evaluation_code(code)

            code.putln(
                "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
                (self.resulting_fused_function.result(),
                 self.__signatures__.result()))
            # The fused function object now owns the __signatures__ reference.
            code.put_giveref(self.__signatures__.result())
            self.__signatures__.generate_post_assignment_code(code)
            self.__signatures__.free_temps(code)

            self.fused_func_assignment.generate_execution_code(code)

            # Dispose of results
            self.resulting_fused_function.generate_disposal_code(code)
            self.resulting_fused_function.free_temps(code)
            self.defaults_tuple.generate_disposal_code(code)
            self.defaults_tuple.free_temps(code)
            self.code_object.generate_disposal_code(code)
            self.code_object.free_temps(code)

        for default in self.defaults:
            if default is not None:
                default.generate_disposal_code(code)
                default.free_temps(code)

    def annotate(self, code):
        """Delegate annotation to every statement in self.stats."""
        for stat in self.stats:
            stat.annotate(code)