1from __future__ import absolute_import 2 3import copy 4 5from . import (ExprNodes, PyrexTypes, MemoryView, 6 ParseTreeTransforms, StringEncoding, Errors) 7from .ExprNodes import CloneNode, ProxyNode, TupleNode 8from .Nodes import FuncDefNode, CFuncDefNode, StatListNode, DefNode 9from ..Utils import OrderedSet 10from .Errors import error, CannotSpecialize 11 12 13class FusedCFuncDefNode(StatListNode): 14 """ 15 This node replaces a function with fused arguments. It deep-copies the 16 function for every permutation of fused types, and allocates a new local 17 scope for it. It keeps track of the original function in self.node, and 18 the entry of the original function in the symbol table is given the 19 'fused_cfunction' attribute which points back to us. 20 Then when a function lookup occurs (to e.g. call it), the call can be 21 dispatched to the right function. 22 23 node FuncDefNode the original function 24 nodes [FuncDefNode] list of copies of node with different specific types 25 py_func DefNode the fused python function subscriptable from 26 Python space 27 __signatures__ A DictNode mapping signature specialization strings 28 to PyCFunction nodes 29 resulting_fused_function PyCFunction for the fused DefNode that delegates 30 to specializations 31 fused_func_assignment Assignment of the fused function to the function name 32 defaults_tuple TupleNode of defaults (letting PyCFunctionNode build 33 defaults would result in many different tuples) 34 specialized_pycfuncs List of synthesized pycfunction nodes for the 35 specializations 36 code_object CodeObjectNode shared by all specializations and the 37 fused function 38 39 fused_compound_types All fused (compound) types (e.g. floating[:]) 40 """ 41 42 __signatures__ = None 43 resulting_fused_function = None 44 fused_func_assignment = None 45 defaults_tuple = None 46 decorators = None 47 48 child_attrs = StatListNode.child_attrs + [ 49 '__signatures__', 'resulting_fused_function', 'fused_func_assignment'] 50 51 def __init__(self, node, env): 52 super(FusedCFuncDefNode, self).__init__(node.pos) 53 54 self.nodes = [] 55 self.node = node 56 57 is_def = isinstance(self.node, DefNode) 58 if is_def: 59 # self.node.decorators = [] 60 self.copy_def(env) 61 else: 62 self.copy_cdef(env) 63 64 # Perform some sanity checks. If anything fails, it's a bug 65 for n in self.nodes: 66 assert not n.entry.type.is_fused 67 assert not n.local_scope.return_type.is_fused 68 if node.return_type.is_fused: 69 assert not n.return_type.is_fused 70 71 if not is_def and n.cfunc_declarator.optional_arg_count: 72 assert n.type.op_arg_struct 73 74 node.entry.fused_cfunction = self 75 # Copy the nodes as AnalyseDeclarationsTransform will prepend 76 # self.py_func to self.stats, as we only want specialized 77 # CFuncDefNodes in self.nodes 78 self.stats = self.nodes[:] 79 80 def copy_def(self, env): 81 """ 82 Create a copy of the original def or lambda function for specialized 83 versions. 84 """ 85 fused_compound_types = PyrexTypes.unique( 86 [arg.type for arg in self.node.args if arg.type.is_fused]) 87 fused_types = self._get_fused_base_types(fused_compound_types) 88 permutations = PyrexTypes.get_all_specialized_permutations(fused_types) 89 90 self.fused_compound_types = fused_compound_types 91 92 if self.node.entry in env.pyfunc_entries: 93 env.pyfunc_entries.remove(self.node.entry) 94 95 for cname, fused_to_specific in permutations: 96 copied_node = copy.deepcopy(self.node) 97 # keep signature object identity for special casing in DefNode.analyse_declarations() 98 copied_node.entry.signature = self.node.entry.signature 99 100 self._specialize_function_args(copied_node.args, fused_to_specific) 101 copied_node.return_type = self.node.return_type.specialize( 102 fused_to_specific) 103 104 copied_node.analyse_declarations(env) 105 # copied_node.is_staticmethod = self.node.is_staticmethod 106 # copied_node.is_classmethod = self.node.is_classmethod 107 self.create_new_local_scope(copied_node, env, fused_to_specific) 108 self.specialize_copied_def(copied_node, cname, self.node.entry, 109 fused_to_specific, fused_compound_types) 110 111 PyrexTypes.specialize_entry(copied_node.entry, cname) 112 copied_node.entry.used = True 113 env.entries[copied_node.entry.name] = copied_node.entry 114 115 if not self.replace_fused_typechecks(copied_node): 116 break 117 118 self.orig_py_func = self.node 119 self.py_func = self.make_fused_cpdef(self.node, env, is_def=True) 120 121 def copy_cdef(self, env): 122 """ 123 Create a copy of the original c(p)def function for all specialized 124 versions. 125 """ 126 permutations = self.node.type.get_all_specialized_permutations() 127 # print 'Node %s has %d specializations:' % (self.node.entry.name, 128 # len(permutations)) 129 # import pprint; pprint.pprint([d for cname, d in permutations]) 130 131 # Prevent copying of the python function 132 self.orig_py_func = orig_py_func = self.node.py_func 133 self.node.py_func = None 134 if orig_py_func: 135 env.pyfunc_entries.remove(orig_py_func.entry) 136 137 fused_types = self.node.type.get_fused_types() 138 self.fused_compound_types = fused_types 139 140 new_cfunc_entries = [] 141 for cname, fused_to_specific in permutations: 142 copied_node = copy.deepcopy(self.node) 143 144 # Make the types in our CFuncType specific. 145 try: 146 type = copied_node.type.specialize(fused_to_specific) 147 except CannotSpecialize: 148 # unlike for the argument types, specializing the return type can fail 149 error(copied_node.pos, "Return type is a fused type that cannot " 150 "be determined from the function arguments") 151 self.py_func = None # this is just to let the compiler exit gracefully 152 return 153 entry = copied_node.entry 154 type.specialize_entry(entry, cname) 155 156 # Reuse existing Entries (e.g. from .pxd files). 157 for i, orig_entry in enumerate(env.cfunc_entries): 158 if entry.cname == orig_entry.cname and type.same_as_resolved_type(orig_entry.type): 159 copied_node.entry = env.cfunc_entries[i] 160 if not copied_node.entry.func_cname: 161 copied_node.entry.func_cname = entry.func_cname 162 entry = copied_node.entry 163 type = entry.type 164 break 165 else: 166 new_cfunc_entries.append(entry) 167 168 copied_node.type = type 169 entry.type, type.entry = type, entry 170 171 entry.used = (entry.used or 172 self.node.entry.defined_in_pxd or 173 env.is_c_class_scope or 174 entry.is_cmethod) 175 176 if self.node.cfunc_declarator.optional_arg_count: 177 self.node.cfunc_declarator.declare_optional_arg_struct( 178 type, env, fused_cname=cname) 179 180 copied_node.return_type = type.return_type 181 self.create_new_local_scope(copied_node, env, fused_to_specific) 182 183 # Make the argument types in the CFuncDeclarator specific 184 self._specialize_function_args(copied_node.cfunc_declarator.args, 185 fused_to_specific) 186 187 # If a cpdef, declare all specialized cpdefs (this 188 # also calls analyse_declarations) 189 copied_node.declare_cpdef_wrapper(env) 190 if copied_node.py_func: 191 env.pyfunc_entries.remove(copied_node.py_func.entry) 192 193 self.specialize_copied_def( 194 copied_node.py_func, cname, self.node.entry.as_variable, 195 fused_to_specific, fused_types) 196 197 if not self.replace_fused_typechecks(copied_node): 198 break 199 200 # replace old entry with new entries 201 try: 202 cindex = env.cfunc_entries.index(self.node.entry) 203 except ValueError: 204 env.cfunc_entries.extend(new_cfunc_entries) 205 else: 206 env.cfunc_entries[cindex:cindex+1] = new_cfunc_entries 207 208 if orig_py_func: 209 self.py_func = self.make_fused_cpdef(orig_py_func, env, 210 is_def=False) 211 else: 212 self.py_func = orig_py_func 213 214 def _get_fused_base_types(self, fused_compound_types): 215 """ 216 Get a list of unique basic fused types, from a list of 217 (possibly) compound fused types. 218 """ 219 base_types = [] 220 seen = set() 221 for fused_type in fused_compound_types: 222 fused_type.get_fused_types(result=base_types, seen=seen) 223 return base_types 224 225 def _specialize_function_args(self, args, fused_to_specific): 226 for arg in args: 227 if arg.type.is_fused: 228 arg.type = arg.type.specialize(fused_to_specific) 229 if arg.type.is_memoryviewslice: 230 arg.type.validate_memslice_dtype(arg.pos) 231 if arg.annotation: 232 # TODO might be nice if annotations were specialized instead? 233 # (Or might be hard to do reliably) 234 arg.annotation.untyped = True 235 236 def create_new_local_scope(self, node, env, f2s): 237 """ 238 Create a new local scope for the copied node and append it to 239 self.nodes. A new local scope is needed because the arguments with the 240 fused types are already in the local scope, and we need the specialized 241 entries created after analyse_declarations on each specialized version 242 of the (CFunc)DefNode. 243 f2s is a dict mapping each fused type to its specialized version 244 """ 245 node.create_local_scope(env) 246 node.local_scope.fused_to_specific = f2s 247 248 # This is copied from the original function, set it to false to 249 # stop recursion 250 node.has_fused_arguments = False 251 self.nodes.append(node) 252 253 def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types): 254 """Specialize the copy of a DefNode given the copied node, 255 the specialization cname and the original DefNode entry""" 256 fused_types = self._get_fused_base_types(fused_compound_types) 257 type_strings = [ 258 PyrexTypes.specialization_signature_string(fused_type, f2s) 259 for fused_type in fused_types 260 ] 261 262 node.specialized_signature_string = '|'.join(type_strings) 263 264 node.entry.pymethdef_cname = PyrexTypes.get_fused_cname( 265 cname, node.entry.pymethdef_cname) 266 node.entry.doc = py_entry.doc 267 node.entry.doc_cname = py_entry.doc_cname 268 269 def replace_fused_typechecks(self, copied_node): 270 """ 271 Branch-prune fused type checks like 272 273 if fused_t is int: 274 ... 275 276 Returns whether an error was issued and whether we should stop in 277 in order to prevent a flood of errors. 278 """ 279 num_errors = Errors.num_errors 280 transform = ParseTreeTransforms.ReplaceFusedTypeChecks( 281 copied_node.local_scope) 282 transform(copied_node) 283 284 if Errors.num_errors > num_errors: 285 return False 286 287 return True 288 289 def _fused_instance_checks(self, normal_types, pyx_code, env): 290 """ 291 Generate Cython code for instance checks, matching an object to 292 specialized types. 293 """ 294 for specialized_type in normal_types: 295 # all_numeric = all_numeric and specialized_type.is_numeric 296 pyx_code.context.update( 297 py_type_name=specialized_type.py_type_name(), 298 specialized_type_name=specialized_type.specialization_string, 299 ) 300 pyx_code.put_chunk( 301 u""" 302 if isinstance(arg, {{py_type_name}}): 303 dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'; break 304 """) 305 306 def _dtype_name(self, dtype): 307 if dtype.is_typedef: 308 return '___pyx_%s' % dtype 309 return str(dtype).replace(' ', '_') 310 311 def _dtype_type(self, dtype): 312 if dtype.is_typedef: 313 return self._dtype_name(dtype) 314 return str(dtype) 315 316 def _sizeof_dtype(self, dtype): 317 if dtype.is_pyobject: 318 return 'sizeof(void *)' 319 else: 320 return "sizeof(%s)" % self._dtype_type(dtype) 321 322 def _buffer_check_numpy_dtype_setup_cases(self, pyx_code): 323 "Setup some common cases to match dtypes against specializations" 324 if pyx_code.indenter("if kind in b'iu':"): 325 pyx_code.putln("pass") 326 pyx_code.named_insertion_point("dtype_int") 327 pyx_code.dedent() 328 329 if pyx_code.indenter("elif kind == b'f':"): 330 pyx_code.putln("pass") 331 pyx_code.named_insertion_point("dtype_float") 332 pyx_code.dedent() 333 334 if pyx_code.indenter("elif kind == b'c':"): 335 pyx_code.putln("pass") 336 pyx_code.named_insertion_point("dtype_complex") 337 pyx_code.dedent() 338 339 if pyx_code.indenter("elif kind == b'O':"): 340 pyx_code.putln("pass") 341 pyx_code.named_insertion_point("dtype_object") 342 pyx_code.dedent() 343 344 match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'" 345 no_match = "dest_sig[{{dest_sig_idx}}] = None" 346 def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types, pythran_types): 347 """ 348 Match a numpy dtype object to the individual specializations. 349 """ 350 self._buffer_check_numpy_dtype_setup_cases(pyx_code) 351 352 for specialized_type in pythran_types+specialized_buffer_types: 353 final_type = specialized_type 354 if specialized_type.is_pythran_expr: 355 specialized_type = specialized_type.org_buffer 356 dtype = specialized_type.dtype 357 pyx_code.context.update( 358 itemsize_match=self._sizeof_dtype(dtype) + " == itemsize", 359 signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype), 360 dtype=dtype, 361 specialized_type_name=final_type.specialization_string) 362 363 dtypes = [ 364 (dtype.is_int, pyx_code.dtype_int), 365 (dtype.is_float, pyx_code.dtype_float), 366 (dtype.is_complex, pyx_code.dtype_complex) 367 ] 368 369 for dtype_category, codewriter in dtypes: 370 if dtype_category: 371 cond = '{{itemsize_match}} and (<Py_ssize_t>arg.ndim) == %d' % ( 372 specialized_type.ndim,) 373 if dtype.is_int: 374 cond += ' and {{signed_match}}' 375 376 if final_type.is_pythran_expr: 377 cond += ' and arg_is_pythran_compatible' 378 379 if codewriter.indenter("if %s:" % cond): 380 #codewriter.putln("print 'buffer match found based on numpy dtype'") 381 codewriter.putln(self.match) 382 codewriter.putln("break") 383 codewriter.dedent() 384 385 def _buffer_parse_format_string_check(self, pyx_code, decl_code, 386 specialized_type, env): 387 """ 388 For each specialized type, try to coerce the object to a memoryview 389 slice of that type. This means obtaining a buffer and parsing the 390 format string. 391 TODO: separate buffer acquisition from format parsing 392 """ 393 dtype = specialized_type.dtype 394 if specialized_type.is_buffer: 395 axes = [('direct', 'strided')] * specialized_type.ndim 396 else: 397 axes = specialized_type.axes 398 399 memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes) 400 memslice_type.create_from_py_utility_code(env) 401 pyx_code.context.update( 402 coerce_from_py_func=memslice_type.from_py_function, 403 dtype=dtype) 404 decl_code.putln( 405 "{{memviewslice_cname}} {{coerce_from_py_func}}(object, int)") 406 407 pyx_code.context.update( 408 specialized_type_name=specialized_type.specialization_string, 409 sizeof_dtype=self._sizeof_dtype(dtype)) 410 411 pyx_code.put_chunk( 412 u""" 413 # try {{dtype}} 414 if itemsize == -1 or itemsize == {{sizeof_dtype}}: 415 memslice = {{coerce_from_py_func}}(arg, 0) 416 if memslice.memview: 417 __PYX_XCLEAR_MEMVIEW(&memslice, 1) 418 # print 'found a match for the buffer through format parsing' 419 %s 420 break 421 else: 422 __pyx_PyErr_Clear() 423 """ % self.match) 424 425 def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, env): 426 """ 427 Generate Cython code to match objects to buffer specializations. 428 First try to get a numpy dtype object and match it against the individual 429 specializations. If that fails, try naively to coerce the object 430 to each specialization, which obtains the buffer each time and tries 431 to match the format string. 432 """ 433 # The first thing to find a match in this loop breaks out of the loop 434 pyx_code.put_chunk( 435 u""" 436 """ + (u"arg_is_pythran_compatible = False" if pythran_types else u"") + u""" 437 if ndarray is not None: 438 if isinstance(arg, ndarray): 439 dtype = arg.dtype 440 """ + (u"arg_is_pythran_compatible = True" if pythran_types else u"") + u""" 441 elif __pyx_memoryview_check(arg): 442 arg_base = arg.base 443 if isinstance(arg_base, ndarray): 444 dtype = arg_base.dtype 445 else: 446 dtype = None 447 else: 448 dtype = None 449 450 itemsize = -1 451 if dtype is not None: 452 itemsize = dtype.itemsize 453 kind = ord(dtype.kind) 454 dtype_signed = kind == 'i' 455 """) 456 pyx_code.indent(2) 457 if pythran_types: 458 pyx_code.put_chunk( 459 u""" 460 # Pythran only supports the endianness of the current compiler 461 byteorder = dtype.byteorder 462 if byteorder == "<" and not __Pyx_Is_Little_Endian(): 463 arg_is_pythran_compatible = False 464 elif byteorder == ">" and __Pyx_Is_Little_Endian(): 465 arg_is_pythran_compatible = False 466 if arg_is_pythran_compatible: 467 cur_stride = itemsize 468 shape = arg.shape 469 strides = arg.strides 470 for i in range(arg.ndim-1, -1, -1): 471 if (<Py_ssize_t>strides[i]) != cur_stride: 472 arg_is_pythran_compatible = False 473 break 474 cur_stride *= <Py_ssize_t> shape[i] 475 else: 476 arg_is_pythran_compatible = not (arg.flags.f_contiguous and (<Py_ssize_t>arg.ndim) > 1) 477 """) 478 pyx_code.named_insertion_point("numpy_dtype_checks") 479 self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types) 480 pyx_code.dedent(2) 481 482 for specialized_type in buffer_types: 483 self._buffer_parse_format_string_check( 484 pyx_code, decl_code, specialized_type, env) 485 486 def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types): 487 """ 488 If we have any buffer specializations, write out some variable 489 declarations and imports. 490 """ 491 decl_code.put_chunk( 492 u""" 493 ctypedef struct {{memviewslice_cname}}: 494 void *memview 495 496 void __PYX_XCLEAR_MEMVIEW({{memviewslice_cname}} *, int have_gil) 497 bint __pyx_memoryview_check(object) 498 """) 499 500 pyx_code.local_variable_declarations.put_chunk( 501 u""" 502 cdef {{memviewslice_cname}} memslice 503 cdef Py_ssize_t itemsize 504 cdef bint dtype_signed 505 cdef char kind 506 507 itemsize = -1 508 """) 509 510 if pythran_types: 511 pyx_code.local_variable_declarations.put_chunk(u""" 512 cdef bint arg_is_pythran_compatible 513 cdef Py_ssize_t cur_stride 514 """) 515 516 pyx_code.imports.put_chunk( 517 u""" 518 cdef type ndarray 519 ndarray = __Pyx_ImportNumPyArrayTypeIfAvailable() 520 """) 521 522 seen_typedefs = set() 523 seen_int_dtypes = set() 524 for buffer_type in all_buffer_types: 525 dtype = buffer_type.dtype 526 dtype_name = self._dtype_name(dtype) 527 if dtype.is_typedef: 528 if dtype_name not in seen_typedefs: 529 seen_typedefs.add(dtype_name) 530 decl_code.putln( 531 'ctypedef %s %s "%s"' % (dtype.resolve(), dtype_name, 532 dtype.empty_declaration_code())) 533 534 if buffer_type.dtype.is_int: 535 if str(dtype) not in seen_int_dtypes: 536 seen_int_dtypes.add(str(dtype)) 537 pyx_code.context.update(dtype_name=dtype_name, 538 dtype_type=self._dtype_type(dtype)) 539 pyx_code.local_variable_declarations.put_chunk( 540 u""" 541 cdef bint {{dtype_name}}_is_signed 542 {{dtype_name}}_is_signed = not (<{{dtype_type}}> -1 > 0) 543 """) 544 545 def _split_fused_types(self, arg): 546 """ 547 Specialize fused types and split into normal types and buffer types. 548 """ 549 specialized_types = PyrexTypes.get_specialized_types(arg.type) 550 551 # Prefer long over int, etc by sorting (see type classes in PyrexTypes.py) 552 specialized_types.sort() 553 554 seen_py_type_names = set() 555 normal_types, buffer_types, pythran_types = [], [], [] 556 has_object_fallback = False 557 for specialized_type in specialized_types: 558 py_type_name = specialized_type.py_type_name() 559 if py_type_name: 560 if py_type_name in seen_py_type_names: 561 continue 562 seen_py_type_names.add(py_type_name) 563 if py_type_name == 'object': 564 has_object_fallback = True 565 else: 566 normal_types.append(specialized_type) 567 elif specialized_type.is_pythran_expr: 568 pythran_types.append(specialized_type) 569 elif specialized_type.is_buffer or specialized_type.is_memoryviewslice: 570 buffer_types.append(specialized_type) 571 572 return normal_types, buffer_types, pythran_types, has_object_fallback 573 574 def _unpack_argument(self, pyx_code): 575 pyx_code.put_chunk( 576 u""" 577 # PROCESSING ARGUMENT {{arg_tuple_idx}} 578 if {{arg_tuple_idx}} < len(<tuple>args): 579 arg = (<tuple>args)[{{arg_tuple_idx}}] 580 elif kwargs is not None and '{{arg.name}}' in <dict>kwargs: 581 arg = (<dict>kwargs)['{{arg.name}}'] 582 else: 583 {{if arg.default}} 584 arg = (<tuple>defaults)[{{default_idx}}] 585 {{else}} 586 {{if arg_tuple_idx < min_positional_args}} 587 raise TypeError("Expected at least %d argument%s, got %d" % ( 588 {{min_positional_args}}, {{'"s"' if min_positional_args != 1 else '""'}}, len(<tuple>args))) 589 {{else}} 590 raise TypeError("Missing keyword-only argument: '%s'" % "{{arg.default}}") 591 {{endif}} 592 {{endif}} 593 """) 594 595 def _fused_signature_index(self, pyx_code): 596 """ 597 Generate Cython code for constructing a persistent nested dictionary index of 598 fused type specialization signatures. 599 """ 600 pyx_code.put_chunk( 601 u""" 602 if not _fused_sigindex: 603 for sig in <dict>signatures: 604 sigindex_node = _fused_sigindex 605 *sig_series, last_type = sig.strip('()').split('|') 606 for sig_type in sig_series: 607 if sig_type not in sigindex_node: 608 sigindex_node[sig_type] = sigindex_node = {} 609 else: 610 sigindex_node = sigindex_node[sig_type] 611 sigindex_node[last_type] = sig 612 """ 613 ) 614 615 def make_fused_cpdef(self, orig_py_func, env, is_def): 616 """ 617 This creates the function that is indexable from Python and does 618 runtime dispatch based on the argument types. The function gets the 619 arg tuple and kwargs dict (or None) and the defaults tuple 620 as arguments from the Binding Fused Function's tp_call. 621 """ 622 from . import TreeFragment, Code, UtilityCode 623 624 fused_types = self._get_fused_base_types([ 625 arg.type for arg in self.node.args if arg.type.is_fused]) 626 627 context = { 628 'memviewslice_cname': MemoryView.memviewslice_cname, 629 'func_args': self.node.args, 630 'n_fused': len(fused_types), 631 'min_positional_args': 632 self.node.num_required_args - self.node.num_required_kw_args 633 if is_def else 634 sum(1 for arg in self.node.args if arg.default is None), 635 'name': orig_py_func.entry.name, 636 } 637 638 pyx_code = Code.PyxCodeWriter(context=context) 639 decl_code = Code.PyxCodeWriter(context=context) 640 decl_code.put_chunk( 641 u""" 642 cdef extern from *: 643 void __pyx_PyErr_Clear "PyErr_Clear" () 644 type __Pyx_ImportNumPyArrayTypeIfAvailable() 645 int __Pyx_Is_Little_Endian() 646 """) 647 decl_code.indent() 648 649 pyx_code.put_chunk( 650 u""" 651 def __pyx_fused_cpdef(signatures, args, kwargs, defaults, _fused_sigindex={}): 652 # FIXME: use a typed signature - currently fails badly because 653 # default arguments inherit the types we specify here! 654 655 cdef list search_list 656 657 cdef dict sn, sigindex_node 658 659 dest_sig = [None] * {{n_fused}} 660 661 if kwargs is not None and not kwargs: 662 kwargs = None 663 664 cdef Py_ssize_t i 665 666 # instance check body 667 """) 668 669 pyx_code.indent() # indent following code to function body 670 pyx_code.named_insertion_point("imports") 671 pyx_code.named_insertion_point("func_defs") 672 pyx_code.named_insertion_point("local_variable_declarations") 673 674 fused_index = 0 675 default_idx = 0 676 all_buffer_types = OrderedSet() 677 seen_fused_types = set() 678 for i, arg in enumerate(self.node.args): 679 if arg.type.is_fused: 680 arg_fused_types = arg.type.get_fused_types() 681 if len(arg_fused_types) > 1: 682 raise NotImplementedError("Determination of more than one fused base " 683 "type per argument is not implemented.") 684 fused_type = arg_fused_types[0] 685 686 if arg.type.is_fused and fused_type not in seen_fused_types: 687 seen_fused_types.add(fused_type) 688 689 context.update( 690 arg_tuple_idx=i, 691 arg=arg, 692 dest_sig_idx=fused_index, 693 default_idx=default_idx, 694 ) 695 696 normal_types, buffer_types, pythran_types, has_object_fallback = self._split_fused_types(arg) 697 self._unpack_argument(pyx_code) 698 699 # 'unrolled' loop, first match breaks out of it 700 if pyx_code.indenter("while 1:"): 701 if normal_types: 702 self._fused_instance_checks(normal_types, pyx_code, env) 703 if buffer_types or pythran_types: 704 env.use_utility_code(Code.UtilityCode.load_cached("IsLittleEndian", "ModuleSetupCode.c")) 705 self._buffer_checks(buffer_types, pythran_types, pyx_code, decl_code, env) 706 if has_object_fallback: 707 pyx_code.context.update(specialized_type_name='object') 708 pyx_code.putln(self.match) 709 else: 710 pyx_code.putln(self.no_match) 711 pyx_code.putln("break") 712 pyx_code.dedent() 713 714 fused_index += 1 715 all_buffer_types.update(buffer_types) 716 all_buffer_types.update(ty.org_buffer for ty in pythran_types) 717 718 if arg.default: 719 default_idx += 1 720 721 if all_buffer_types: 722 self._buffer_declarations(pyx_code, decl_code, all_buffer_types, pythran_types) 723 env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c")) 724 env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c")) 725 726 self._fused_signature_index(pyx_code) 727 728 pyx_code.put_chunk( 729 u""" 730 sigindex_matches = [] 731 sigindex_candidates = [_fused_sigindex] 732 733 for dst_type in dest_sig: 734 found_matches = [] 735 found_candidates = [] 736 # Make two seperate lists: One for signature sub-trees 737 # with at least one definite match, and another for 738 # signature sub-trees with only ambiguous matches 739 # (where `dest_sig[i] is None`). 740 if dst_type is None: 741 for sn in sigindex_matches: 742 found_matches.extend(sn.values()) 743 for sn in sigindex_candidates: 744 found_candidates.extend(sn.values()) 745 else: 746 for search_list in (sigindex_matches, sigindex_candidates): 747 for sn in search_list: 748 if dst_type in sn: 749 found_matches.append(sn[dst_type]) 750 sigindex_matches = found_matches 751 sigindex_candidates = found_candidates 752 if not (found_matches or found_candidates): 753 break 754 755 candidates = sigindex_matches 756 757 if not candidates: 758 raise TypeError("No matching signature found") 759 elif len(candidates) > 1: 760 raise TypeError("Function call with ambiguous argument types") 761 else: 762 return (<dict>signatures)[candidates[0]] 763 """) 764 765 fragment_code = pyx_code.getvalue() 766 # print decl_code.getvalue() 767 # print fragment_code 768 from .Optimize import ConstantFolding 769 fragment = TreeFragment.TreeFragment( 770 fragment_code, level='module', pipeline=[ConstantFolding()]) 771 ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root) 772 UtilityCode.declare_declarations_in_scope( 773 decl_code.getvalue(), env.global_scope()) 774 ast.scope = env 775 # FIXME: for static methods of cdef classes, we build the wrong signature here: first arg becomes 'self' 776 ast.analyse_declarations(env) 777 py_func = ast.stats[-1] # the DefNode 778 self.fragment_scope = ast.scope 779 780 if isinstance(self.node, DefNode): 781 py_func.specialized_cpdefs = self.nodes[:] 782 else: 783 py_func.specialized_cpdefs = [n.py_func for n in self.nodes] 784 785 return py_func 786 787 def update_fused_defnode_entry(self, env): 788 copy_attributes = ( 789 'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname', 790 'pymethdef_cname', 'doc', 'doc_cname', 'is_member', 791 'scope' 792 ) 793 794 entry = self.py_func.entry 795 796 for attr in copy_attributes: 797 setattr(entry, attr, 798 getattr(self.orig_py_func.entry, attr)) 799 800 self.py_func.name = self.orig_py_func.name 801 self.py_func.doc = self.orig_py_func.doc 802 803 env.entries.pop('__pyx_fused_cpdef', None) 804 if isinstance(self.node, DefNode): 805 env.entries[entry.name] = entry 806 else: 807 env.entries[entry.name].as_variable = entry 808 809 env.pyfunc_entries.append(entry) 810 811 self.py_func.entry.fused_cfunction = self 812 for node in self.nodes: 813 if isinstance(self.node, DefNode): 814 node.fused_py_func = self.py_func 815 else: 816 node.py_func.fused_py_func = self.py_func 817 node.entry.as_variable = entry 818 819 self.synthesize_defnodes() 820 self.stats.append(self.__signatures__) 821 822 def analyse_expressions(self, env): 823 """ 824 Analyse the expressions. Take care to only evaluate default arguments 825 once and clone the result for all specializations 826 """ 827 for fused_compound_type in self.fused_compound_types: 828 for fused_type in fused_compound_type.get_fused_types(): 829 for specialization_type in fused_type.types: 830 if specialization_type.is_complex: 831 specialization_type.create_declaration_utility_code(env) 832 833 if self.py_func: 834 self.__signatures__ = self.__signatures__.analyse_expressions(env) 835 self.py_func = self.py_func.analyse_expressions(env) 836 self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env) 837 self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env) 838 839 self.defaults = defaults = [] 840 841 for arg in self.node.args: 842 if arg.default: 843 arg.default = arg.default.analyse_expressions(env) 844 # coerce the argument to temp since CloneNode really requires a temp 845 defaults.append(ProxyNode(arg.default.coerce_to_temp(env))) 846 else: 847 defaults.append(None) 848 849 for i, stat in enumerate(self.stats): 850 stat = self.stats[i] = stat.analyse_expressions(env) 851 if isinstance(stat, FuncDefNode) and stat is not self.py_func: 852 # the dispatcher specifically doesn't want its defaults overriding 853 for arg, default in zip(stat.args, defaults): 854 if default is not None: 855 arg.default = CloneNode(default).analyse_expressions(env).coerce_to(arg.type, env) 856 857 if self.py_func: 858 args = [CloneNode(default) for default in defaults if default] 859 self.defaults_tuple = TupleNode(self.pos, args=args) 860 self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True).coerce_to_pyobject(env) 861 self.defaults_tuple = ProxyNode(self.defaults_tuple) 862 self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object) 863 864 fused_func = self.resulting_fused_function.arg 865 fused_func.defaults_tuple = CloneNode(self.defaults_tuple) 866 fused_func.code_object = CloneNode(self.code_object) 867 868 for i, pycfunc in enumerate(self.specialized_pycfuncs): 869 pycfunc.code_object = CloneNode(self.code_object) 870 pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env) 871 pycfunc.defaults_tuple = CloneNode(self.defaults_tuple) 872 return self 873 874 def synthesize_defnodes(self): 875 """ 876 Create the __signatures__ dict of PyCFunctionNode specializations. 877 """ 878 if isinstance(self.nodes[0], CFuncDefNode): 879 nodes = [node.py_func for node in self.nodes] 880 else: 881 nodes = self.nodes 882 883 # For the moment, fused functions do not support METH_FASTCALL 884 for node in nodes: 885 node.entry.signature.use_fastcall = False 886 887 signatures = [StringEncoding.EncodedString(node.specialized_signature_string) 888 for node in nodes] 889 keys = [ExprNodes.StringNode(node.pos, value=sig) 890 for node, sig in zip(nodes, signatures)] 891 values = [ExprNodes.PyCFunctionNode.from_defnode(node, binding=True) 892 for node in nodes] 893 894 self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos, zip(keys, values)) 895 896 self.specialized_pycfuncs = values 897 for pycfuncnode in values: 898 pycfuncnode.is_specialization = True 899 900 def generate_function_definitions(self, env, code): 901 if self.py_func: 902 self.py_func.pymethdef_required = True 903 self.fused_func_assignment.generate_function_definitions(env, code) 904 905 for stat in self.stats: 906 if isinstance(stat, FuncDefNode) and stat.entry.used: 907 code.mark_pos(stat.pos) 908 stat.generate_function_definitions(env, code) 909 910 def generate_execution_code(self, code): 911 # Note: all def function specialization are wrapped in PyCFunction 912 # nodes in the self.__signatures__ dictnode. 913 for default in self.defaults: 914 if default is not None: 915 default.generate_evaluation_code(code) 916 917 if self.py_func: 918 self.defaults_tuple.generate_evaluation_code(code) 919 self.code_object.generate_evaluation_code(code) 920 921 for stat in self.stats: 922 code.mark_pos(stat.pos) 923 if isinstance(stat, ExprNodes.ExprNode): 924 stat.generate_evaluation_code(code) 925 else: 926 stat.generate_execution_code(code) 927 928 if self.__signatures__: 929 self.resulting_fused_function.generate_evaluation_code(code) 930 931 code.putln( 932 "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" % 933 (self.resulting_fused_function.result(), 934 self.__signatures__.result())) 935 self.__signatures__.generate_giveref(code) 936 self.__signatures__.generate_post_assignment_code(code) 937 self.__signatures__.free_temps(code) 938 939 self.fused_func_assignment.generate_execution_code(code) 940 941 # Dispose of results 942 self.resulting_fused_function.generate_disposal_code(code) 943 self.resulting_fused_function.free_temps(code) 944 self.defaults_tuple.generate_disposal_code(code) 945 self.defaults_tuple.free_temps(code) 946 self.code_object.generate_disposal_code(code) 947 self.code_object.free_temps(code) 948 949 for default in self.defaults: 950 if default is not None: 951 default.generate_disposal_code(code) 952 default.free_temps(code) 953 954 def annotate(self, code): 955 for stat in self.stats: 956 stat.annotate(code) 957