1from __future__ import absolute_import
2
3import copy
4
5from . import (ExprNodes, PyrexTypes, MemoryView,
6               ParseTreeTransforms, StringEncoding, Errors)
7from .ExprNodes import CloneNode, ProxyNode, TupleNode
8from .Nodes import FuncDefNode, CFuncDefNode, StatListNode, DefNode
9from ..Utils import OrderedSet
10
11
12class FusedCFuncDefNode(StatListNode):
13    """
14    This node replaces a function with fused arguments. It deep-copies the
15    function for every permutation of fused types, and allocates a new local
16    scope for it. It keeps track of the original function in self.node, and
17    the entry of the original function in the symbol table is given the
18    'fused_cfunction' attribute which points back to us.
19    Then when a function lookup occurs (to e.g. call it), the call can be
20    dispatched to the right function.
21
22    node    FuncDefNode    the original function
23    nodes   [FuncDefNode]  list of copies of node with different specific types
24    py_func DefNode        the fused python function subscriptable from
25                           Python space
26    __signatures__         A DictNode mapping signature specialization strings
27                           to PyCFunction nodes
28    resulting_fused_function  PyCFunction for the fused DefNode that delegates
29                              to specializations
30    fused_func_assignment   Assignment of the fused function to the function name
31    defaults_tuple          TupleNode of defaults (letting PyCFunctionNode build
32                            defaults would result in many different tuples)
33    specialized_pycfuncs    List of synthesized pycfunction nodes for the
34                            specializations
35    code_object             CodeObjectNode shared by all specializations and the
36                            fused function
37
38    fused_compound_types    All fused (compound) types (e.g. floating[:])
39    """
40
    # Class-level defaults for attributes that are presumably assigned on the
    # instance later in the compilation pipeline (not visible here).
    __signatures__ = None
    resulting_fused_function = None
    fused_func_assignment = None
    defaults_tuple = None
    decorators = None

    # Besides StatListNode's own children, tree traversals also visit these
    # three attributes as child nodes.
    child_attrs = StatListNode.child_attrs + [
        '__signatures__', 'resulting_fused_function', 'fused_func_assignment']
49
50    def __init__(self, node, env):
51        super(FusedCFuncDefNode, self).__init__(node.pos)
52
53        self.nodes = []
54        self.node = node
55
56        is_def = isinstance(self.node, DefNode)
57        if is_def:
58            # self.node.decorators = []
59            self.copy_def(env)
60        else:
61            self.copy_cdef(env)
62
63        # Perform some sanity checks. If anything fails, it's a bug
64        for n in self.nodes:
65            assert not n.entry.type.is_fused
66            assert not n.local_scope.return_type.is_fused
67            if node.return_type.is_fused:
68                assert not n.return_type.is_fused
69
70            if not is_def and n.cfunc_declarator.optional_arg_count:
71                assert n.type.op_arg_struct
72
73        node.entry.fused_cfunction = self
74        # Copy the nodes as AnalyseDeclarationsTransform will prepend
75        # self.py_func to self.stats, as we only want specialized
76        # CFuncDefNodes in self.nodes
77        self.stats = self.nodes[:]
78
    def copy_def(self, env):
        """
        Create a copy of the original def or lambda function for specialized
        versions.

        One deep copy of self.node is made per permutation of the fused base
        types used by its arguments. Each copy gets specialised argument and
        return types and its own local scope (which also appends it to
        self.nodes). It is stored in env.entries under its specialised entry's
        name. Afterwards the Python-level dispatch function is built and stored
        in self.py_func; the original node is kept as self.orig_py_func.
        """
        # Fused (possibly compound, e.g. floating[:]) argument types, passed
        # through PyrexTypes.unique.
        fused_compound_types = PyrexTypes.unique(
            [arg.type for arg in self.node.args if arg.type.is_fused])
        fused_types = self._get_fused_base_types(fused_compound_types)
        # Iterated as (cname, {fused type -> specific type}) pairs below.
        permutations = PyrexTypes.get_all_specialized_permutations(fused_types)

        self.fused_compound_types = fused_compound_types

        # The fused original must not be listed as a Python function of env.
        if self.node.entry in env.pyfunc_entries:
            env.pyfunc_entries.remove(self.node.entry)

        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)
            # keep signature object identity for special casing in DefNode.analyse_declarations()
            copied_node.entry.signature = self.node.entry.signature

            self._specialize_function_args(copied_node.args, fused_to_specific)
            copied_node.return_type = self.node.return_type.specialize(
                                                    fused_to_specific)

            copied_node.analyse_declarations(env)
            # copied_node.is_staticmethod = self.node.is_staticmethod
            # copied_node.is_classmethod = self.node.is_classmethod
            self.create_new_local_scope(copied_node, env, fused_to_specific)
            self.specialize_copied_def(copied_node, cname, self.node.entry,
                                       fused_to_specific, fused_compound_types)

            PyrexTypes.specialize_entry(copied_node.entry, cname)
            copied_node.entry.used = True
            env.entries[copied_node.entry.name] = copied_node.entry

            # Stop after the first copy whose fused type-check pruning issued
            # errors, so the same errors are not repeated for every copy.
            if not self.replace_fused_typechecks(copied_node):
                break

        self.orig_py_func = self.node
        self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)
119
    def copy_cdef(self, env):
        """
        Create a copy of the original c(p)def function for all specialized
        versions.

        Each deep copy gets a specialised CFuncType and entry. An existing
        entry in env.cfunc_entries with the same cname and resolved type
        (e.g. declared in a .pxd) is reused; otherwise the new entry replaces
        the original fused entry in env.cfunc_entries. If the original had a
        Python wrapper (cpdef), each copy's wrapper is specialised as well, and
        a fused dispatch function is built and stored in self.py_func.
        """
        permutations = self.node.type.get_all_specialized_permutations()
        # print 'Node %s has %d specializations:' % (self.node.entry.name,
        #                                            len(permutations))
        # import pprint; pprint.pprint([d for cname, d in permutations])

        # Prevent copying of the python function: detach it before deepcopy
        # and drop its entry from the scope's Python function list.
        self.orig_py_func = orig_py_func = self.node.py_func
        self.node.py_func = None
        if orig_py_func:
            env.pyfunc_entries.remove(orig_py_func.entry)

        fused_types = self.node.type.get_fused_types()
        self.fused_compound_types = fused_types

        # Entries created here that did not already exist in env.cfunc_entries.
        new_cfunc_entries = []
        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)

            # Make the types in our CFuncType specific.
            type = copied_node.type.specialize(fused_to_specific)
            entry = copied_node.entry
            type.specialize_entry(entry, cname)

            # Reuse existing Entries (e.g. from .pxd files).
            for i, orig_entry in enumerate(env.cfunc_entries):
                if entry.cname == orig_entry.cname and type.same_as_resolved_type(orig_entry.type):
                    copied_node.entry = env.cfunc_entries[i]
                    # Keep the reused entry's func_cname unless it has none.
                    if not copied_node.entry.func_cname:
                        copied_node.entry.func_cname = entry.func_cname
                    entry = copied_node.entry
                    type = entry.type
                    break
            else:
                # No matching entry found: this specialisation is new.
                new_cfunc_entries.append(entry)

            # Link node, type and entry to each other.
            copied_node.type = type
            entry.type, type.entry = type, entry

            entry.used = (entry.used or
                          self.node.entry.defined_in_pxd or
                          env.is_c_class_scope or
                          entry.is_cmethod)

            # The optional-argument struct is declared through the original
            # node's declarator, once per specialisation (fused_cname=cname).
            if self.node.cfunc_declarator.optional_arg_count:
                self.node.cfunc_declarator.declare_optional_arg_struct(
                                           type, env, fused_cname=cname)

            copied_node.return_type = type.return_type
            self.create_new_local_scope(copied_node, env, fused_to_specific)

            # Make the argument types in the CFuncDeclarator specific
            self._specialize_function_args(copied_node.cfunc_declarator.args,
                                           fused_to_specific)

            # If a cpdef, declare all specialized cpdefs (this
            # also calls analyse_declarations)
            copied_node.declare_cpdef_wrapper(env)
            if copied_node.py_func:
                # Specialised wrappers are reached through the fused
                # dispatcher, not listed as Python functions of env.
                env.pyfunc_entries.remove(copied_node.py_func.entry)

                self.specialize_copied_def(
                        copied_node.py_func, cname, self.node.entry.as_variable,
                        fused_to_specific, fused_types)

            # Stop after the first copy whose type-check pruning issued errors.
            if not self.replace_fused_typechecks(copied_node):
                break

        # replace old entry with new entries
        try:
            cindex = env.cfunc_entries.index(self.node.entry)
        except ValueError:
            env.cfunc_entries.extend(new_cfunc_entries)
        else:
            env.cfunc_entries[cindex:cindex+1] = new_cfunc_entries

        if orig_py_func:
            self.py_func = self.make_fused_cpdef(orig_py_func, env,
                                                 is_def=False)
        else:
            # Plain cdef function: no Python dispatcher (orig_py_func is falsy).
            self.py_func = orig_py_func
205
206    def _get_fused_base_types(self, fused_compound_types):
207        """
208        Get a list of unique basic fused types, from a list of
209        (possibly) compound fused types.
210        """
211        base_types = []
212        seen = set()
213        for fused_type in fused_compound_types:
214            fused_type.get_fused_types(result=base_types, seen=seen)
215        return base_types
216
217    def _specialize_function_args(self, args, fused_to_specific):
218        for arg in args:
219            if arg.type.is_fused:
220                arg.type = arg.type.specialize(fused_to_specific)
221                if arg.type.is_memoryviewslice:
222                    arg.type.validate_memslice_dtype(arg.pos)
223
224    def create_new_local_scope(self, node, env, f2s):
225        """
226        Create a new local scope for the copied node and append it to
227        self.nodes. A new local scope is needed because the arguments with the
228        fused types are already in the local scope, and we need the specialized
229        entries created after analyse_declarations on each specialized version
230        of the (CFunc)DefNode.
231        f2s is a dict mapping each fused type to its specialized version
232        """
233        node.create_local_scope(env)
234        node.local_scope.fused_to_specific = f2s
235
236        # This is copied from the original function, set it to false to
237        # stop recursion
238        node.has_fused_arguments = False
239        self.nodes.append(node)
240
241    def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types):
242        """Specialize the copy of a DefNode given the copied node,
243        the specialization cname and the original DefNode entry"""
244        fused_types = self._get_fused_base_types(fused_compound_types)
245        type_strings = [
246            PyrexTypes.specialization_signature_string(fused_type, f2s)
247                for fused_type in fused_types
248        ]
249
250        node.specialized_signature_string = '|'.join(type_strings)
251
252        node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
253                                        cname, node.entry.pymethdef_cname)
254        node.entry.doc = py_entry.doc
255        node.entry.doc_cname = py_entry.doc_cname
256
257    def replace_fused_typechecks(self, copied_node):
258        """
259        Branch-prune fused type checks like
260
261            if fused_t is int:
262                ...
263
264        Returns whether an error was issued and whether we should stop in
265        in order to prevent a flood of errors.
266        """
267        num_errors = Errors.num_errors
268        transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
269                                       copied_node.local_scope)
270        transform(copied_node)
271
272        if Errors.num_errors > num_errors:
273            return False
274
275        return True
276
277    def _fused_instance_checks(self, normal_types, pyx_code, env):
278        """
279        Generate Cython code for instance checks, matching an object to
280        specialized types.
281        """
282        for specialized_type in normal_types:
283            # all_numeric = all_numeric and specialized_type.is_numeric
284            pyx_code.context.update(
285                py_type_name=specialized_type.py_type_name(),
286                specialized_type_name=specialized_type.specialization_string,
287            )
288            pyx_code.put_chunk(
289                u"""
290                    if isinstance(arg, {{py_type_name}}):
291                        dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'; break
292                """)
293
294    def _dtype_name(self, dtype):
295        if dtype.is_typedef:
296            return '___pyx_%s' % dtype
297        return str(dtype).replace(' ', '_')
298
299    def _dtype_type(self, dtype):
300        if dtype.is_typedef:
301            return self._dtype_name(dtype)
302        return str(dtype)
303
304    def _sizeof_dtype(self, dtype):
305        if dtype.is_pyobject:
306            return 'sizeof(void *)'
307        else:
308            return "sizeof(%s)" % self._dtype_type(dtype)
309
310    def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
311        "Setup some common cases to match dtypes against specializations"
312        if pyx_code.indenter("if kind in b'iu':"):
313            pyx_code.putln("pass")
314            pyx_code.named_insertion_point("dtype_int")
315            pyx_code.dedent()
316
317        if pyx_code.indenter("elif kind == b'f':"):
318            pyx_code.putln("pass")
319            pyx_code.named_insertion_point("dtype_float")
320            pyx_code.dedent()
321
322        if pyx_code.indenter("elif kind == b'c':"):
323            pyx_code.putln("pass")
324            pyx_code.named_insertion_point("dtype_complex")
325            pyx_code.dedent()
326
327        if pyx_code.indenter("elif kind == b'O':"):
328            pyx_code.putln("pass")
329            pyx_code.named_insertion_point("dtype_object")
330            pyx_code.dedent()
331
    # Template lines for the generated dispatcher. `match` records the
    # matched specialisation string for the current fused argument in
    # dest_sig. `no_match` sets that slot to None, which the signature
    # matching loop in make_fused_cpdef treats as "any".
    match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
    no_match = "dest_sig[{{dest_sig_idx}}] = None"
334    def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types, pythran_types):
335        """
336        Match a numpy dtype object to the individual specializations.
337        """
338        self._buffer_check_numpy_dtype_setup_cases(pyx_code)
339
340        for specialized_type in pythran_types+specialized_buffer_types:
341            final_type = specialized_type
342            if specialized_type.is_pythran_expr:
343                specialized_type = specialized_type.org_buffer
344            dtype = specialized_type.dtype
345            pyx_code.context.update(
346                itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
347                signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
348                dtype=dtype,
349                specialized_type_name=final_type.specialization_string)
350
351            dtypes = [
352                (dtype.is_int, pyx_code.dtype_int),
353                (dtype.is_float, pyx_code.dtype_float),
354                (dtype.is_complex, pyx_code.dtype_complex)
355            ]
356
357            for dtype_category, codewriter in dtypes:
358                if dtype_category:
359                    cond = '{{itemsize_match}} and (<Py_ssize_t>arg.ndim) == %d' % (
360                                                    specialized_type.ndim,)
361                    if dtype.is_int:
362                        cond += ' and {{signed_match}}'
363
364                    if final_type.is_pythran_expr:
365                        cond += ' and arg_is_pythran_compatible'
366
367                    if codewriter.indenter("if %s:" % cond):
368                        #codewriter.putln("print 'buffer match found based on numpy dtype'")
369                        codewriter.putln(self.match)
370                        codewriter.putln("break")
371                        codewriter.dedent()
372
    def _buffer_parse_format_string_check(self, pyx_code, decl_code,
                                          specialized_type, env):
        """
        For each specialized type, try to coerce the object to a memoryview
        slice of that type. This means obtaining a buffer and parsing the
        format string.
        TODO: separate buffer acquisition from format parsing

        This method emits the check for one specialisation (the caller loops
        over the types). The generated code tries the coercion only when the
        itemsize is unknown (-1) or equals sizeof(dtype). On success it drops
        the temporary slice via __PYX_XDEC_MEMVIEW, records the match and
        breaks. On failure it clears the pending Python error.
        """
        dtype = specialized_type.dtype
        # For buffer (non-memoryview) types, assume direct/strided access on
        # every dimension. Memoryview slices already carry their axes.
        if specialized_type.is_buffer:
            axes = [('direct', 'strided')] * specialized_type.ndim
        else:
            axes = specialized_type.axes

        # Build the equivalent memoryview slice type and request the utility
        # code for its from-Python conversion function.
        memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
        memslice_type.create_from_py_utility_code(env)
        pyx_code.context.update(
            coerce_from_py_func=memslice_type.from_py_function,
            dtype=dtype)
        # Declare the conversion function so the generated fragment can call it.
        decl_code.putln(
            "{{memviewslice_cname}} {{coerce_from_py_func}}(object, int)")

        pyx_code.context.update(
            specialized_type_name=specialized_type.specialization_string,
            sizeof_dtype=self._sizeof_dtype(dtype))

        pyx_code.put_chunk(
            u"""
                # try {{dtype}}
                if itemsize == -1 or itemsize == {{sizeof_dtype}}:
                    memslice = {{coerce_from_py_func}}(arg, 0)
                    if memslice.memview:
                        __PYX_XDEC_MEMVIEW(&memslice, 1)
                        # print 'found a match for the buffer through format parsing'
                        %s
                        break
                    else:
                        __pyx_PyErr_Clear()
            """ % self.match)
412
    def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, env):
        """
        Generate Cython code to match objects to buffer specializations.
        First try to get a numpy dtype object and match it against the individual
        specializations. If that fails, try naively to coerce the object
        to each specialization, which obtains the buffer each time and tries
        to match the format string.

        Every successful check in the generated code ends in ``break``, so a
        numpy-dtype match skips the format-string checks emitted afterwards.
        """
        # The first thing to find a match in this loop breaks out of the loop
        # The dtype is taken from an ndarray argument directly, or from the
        # ndarray base of a memoryview argument; otherwise it is None.
        pyx_code.put_chunk(
            u"""
                """ + (u"arg_is_pythran_compatible = False" if pythran_types else u"") + u"""
                if ndarray is not None:
                    if isinstance(arg, ndarray):
                        dtype = arg.dtype
                        """ + (u"arg_is_pythran_compatible = True" if pythran_types else u"") + u"""
                    elif __pyx_memoryview_check(arg):
                        arg_base = arg.base
                        if isinstance(arg_base, ndarray):
                            dtype = arg_base.dtype
                        else:
                            dtype = None
                    else:
                        dtype = None

                    itemsize = -1
                    if dtype is not None:
                        itemsize = dtype.itemsize
                        kind = ord(dtype.kind)
                        dtype_signed = kind == 'i'
            """)
        # Move the writer two levels in, i.e. into the body of the generated
        # `if dtype is not None:` block, for the dtype-based checks below.
        pyx_code.indent(2)
        if pythran_types:
            # Mark the argument as Pythran-compatible only for native byte
            # order and a C-contiguous stride layout (checked innermost
            # dimension first).
            pyx_code.put_chunk(
                u"""
                        # Pythran only supports the endianness of the current compiler
                        byteorder = dtype.byteorder
                        if byteorder == "<" and not __Pyx_Is_Little_Endian():
                            arg_is_pythran_compatible = False
                        elif byteorder == ">" and __Pyx_Is_Little_Endian():
                            arg_is_pythran_compatible = False
                        if arg_is_pythran_compatible:
                            cur_stride = itemsize
                            shape = arg.shape
                            strides = arg.strides
                            for i in range(arg.ndim-1, -1, -1):
                                if (<Py_ssize_t>strides[i]) != cur_stride:
                                    arg_is_pythran_compatible = False
                                    break
                                cur_stride *= <Py_ssize_t> shape[i]
                            else:
                                arg_is_pythran_compatible = not (arg.flags.f_contiguous and (<Py_ssize_t>arg.ndim) > 1)
                """)
        pyx_code.named_insertion_point("numpy_dtype_checks")
        self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types)
        pyx_code.dedent(2)

        # Fallback emitted at the outer level: format-string parsing for each
        # plain buffer specialisation.
        for specialized_type in buffer_types:
            self._buffer_parse_format_string_check(
                    pyx_code, decl_code, specialized_type, env)
473
    def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types):
        """
        If we have any buffer specializations, write out some variable
        declarations and imports.

        pyx_code          writer for the dispatcher (uses its named insertion
                          points ``local_variable_declarations`` and ``imports``)
        decl_code         writer for the extern declarations
        all_buffer_types  every buffer type that the dispatcher checks against
        pythran_types     only its truthiness is used: when non-empty, the
                          Pythran helper locals are declared as well
        """
        # Minimal view of the memoryview slice struct plus the C helpers that
        # the generated checks call.
        decl_code.put_chunk(
            u"""
                ctypedef struct {{memviewslice_cname}}:
                    void *memview

                void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
                bint __pyx_memoryview_check(object)
            """)

        pyx_code.local_variable_declarations.put_chunk(
            u"""
                cdef {{memviewslice_cname}} memslice
                cdef Py_ssize_t itemsize
                cdef bint dtype_signed
                cdef char kind

                itemsize = -1
            """)

        if pythran_types:
            pyx_code.local_variable_declarations.put_chunk(u"""
                cdef bint arg_is_pythran_compatible
                cdef Py_ssize_t cur_stride
            """)

        # `ndarray` may be None; the generated checks test for that.
        pyx_code.imports.put_chunk(
            u"""
                cdef type ndarray
                ndarray = __Pyx_ImportNumPyArrayTypeIfAvailable()
            """)

        seen_typedefs = set()
        seen_int_dtypes = set()
        for buffer_type in all_buffer_types:
            dtype = buffer_type.dtype
            dtype_name = self._dtype_name(dtype)
            # Declare each typedef dtype once, as a ctypedef alias of its
            # resolved type with the original C name.
            if dtype.is_typedef:
                if dtype_name not in seen_typedefs:
                    seen_typedefs.add(dtype_name)
                    decl_code.putln(
                        'ctypedef %s %s "%s"' % (dtype.resolve(), dtype_name,
                                                 dtype.empty_declaration_code()))

            # One `<name>_is_signed` flag per distinct integer dtype: the type
            # is signed exactly when (type)-1 is not greater than zero.
            if buffer_type.dtype.is_int:
                if str(dtype) not in seen_int_dtypes:
                    seen_int_dtypes.add(str(dtype))
                    pyx_code.context.update(dtype_name=dtype_name,
                                            dtype_type=self._dtype_type(dtype))
                    pyx_code.local_variable_declarations.put_chunk(
                        u"""
                            cdef bint {{dtype_name}}_is_signed
                            {{dtype_name}}_is_signed = not (<{{dtype_type}}> -1 > 0)
                        """)
532
533    def _split_fused_types(self, arg):
534        """
535        Specialize fused types and split into normal types and buffer types.
536        """
537        specialized_types = PyrexTypes.get_specialized_types(arg.type)
538
539        # Prefer long over int, etc by sorting (see type classes in PyrexTypes.py)
540        specialized_types.sort()
541
542        seen_py_type_names = set()
543        normal_types, buffer_types, pythran_types = [], [], []
544        has_object_fallback = False
545        for specialized_type in specialized_types:
546            py_type_name = specialized_type.py_type_name()
547            if py_type_name:
548                if py_type_name in seen_py_type_names:
549                    continue
550                seen_py_type_names.add(py_type_name)
551                if py_type_name == 'object':
552                    has_object_fallback = True
553                else:
554                    normal_types.append(specialized_type)
555            elif specialized_type.is_pythran_expr:
556                pythran_types.append(specialized_type)
557            elif specialized_type.is_buffer or specialized_type.is_memoryviewslice:
558                buffer_types.append(specialized_type)
559
560        return normal_types, buffer_types, pythran_types, has_object_fallback
561
562    def _unpack_argument(self, pyx_code):
563        pyx_code.put_chunk(
564            u"""
565                # PROCESSING ARGUMENT {{arg_tuple_idx}}
566                if {{arg_tuple_idx}} < len(<tuple>args):
567                    arg = (<tuple>args)[{{arg_tuple_idx}}]
568                elif kwargs is not None and '{{arg.name}}' in <dict>kwargs:
569                    arg = (<dict>kwargs)['{{arg.name}}']
570                else:
571                {{if arg.default}}
572                    arg = (<tuple>defaults)[{{default_idx}}]
573                {{else}}
574                    {{if arg_tuple_idx < min_positional_args}}
575                        raise TypeError("Expected at least %d argument%s, got %d" % (
576                            {{min_positional_args}}, {{'"s"' if min_positional_args != 1 else '""'}}, len(<tuple>args)))
577                    {{else}}
578                        raise TypeError("Missing keyword-only argument: '%s'" % "{{arg.default}}")
579                    {{endif}}
580                {{endif}}
581            """)
582
583    def make_fused_cpdef(self, orig_py_func, env, is_def):
584        """
585        This creates the function that is indexable from Python and does
586        runtime dispatch based on the argument types. The function gets the
587        arg tuple and kwargs dict (or None) and the defaults tuple
588        as arguments from the Binding Fused Function's tp_call.
589        """
590        from . import TreeFragment, Code, UtilityCode
591
592        fused_types = self._get_fused_base_types([
593            arg.type for arg in self.node.args if arg.type.is_fused])
594
595        context = {
596            'memviewslice_cname': MemoryView.memviewslice_cname,
597            'func_args': self.node.args,
598            'n_fused': len(fused_types),
599            'min_positional_args':
600                self.node.num_required_args - self.node.num_required_kw_args
601                if is_def else
602                sum(1 for arg in self.node.args if arg.default is None),
603            'name': orig_py_func.entry.name,
604        }
605
606        pyx_code = Code.PyxCodeWriter(context=context)
607        decl_code = Code.PyxCodeWriter(context=context)
608        decl_code.put_chunk(
609            u"""
610                cdef extern from *:
611                    void __pyx_PyErr_Clear "PyErr_Clear" ()
612                    type __Pyx_ImportNumPyArrayTypeIfAvailable()
613                    int __Pyx_Is_Little_Endian()
614            """)
615        decl_code.indent()
616
617        pyx_code.put_chunk(
618            u"""
619                def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
620                    # FIXME: use a typed signature - currently fails badly because
621                    #        default arguments inherit the types we specify here!
622
623                    dest_sig = [None] * {{n_fused}}
624
625                    if kwargs is not None and not kwargs:
626                        kwargs = None
627
628                    cdef Py_ssize_t i
629
630                    # instance check body
631            """)
632
633        pyx_code.indent() # indent following code to function body
634        pyx_code.named_insertion_point("imports")
635        pyx_code.named_insertion_point("func_defs")
636        pyx_code.named_insertion_point("local_variable_declarations")
637
638        fused_index = 0
639        default_idx = 0
640        all_buffer_types = OrderedSet()
641        seen_fused_types = set()
642        for i, arg in enumerate(self.node.args):
643            if arg.type.is_fused:
644                arg_fused_types = arg.type.get_fused_types()
645                if len(arg_fused_types) > 1:
646                    raise NotImplementedError("Determination of more than one fused base "
647                                              "type per argument is not implemented.")
648                fused_type = arg_fused_types[0]
649
650            if arg.type.is_fused and fused_type not in seen_fused_types:
651                seen_fused_types.add(fused_type)
652
653                context.update(
654                    arg_tuple_idx=i,
655                    arg=arg,
656                    dest_sig_idx=fused_index,
657                    default_idx=default_idx,
658                )
659
660                normal_types, buffer_types, pythran_types, has_object_fallback = self._split_fused_types(arg)
661                self._unpack_argument(pyx_code)
662
663                # 'unrolled' loop, first match breaks out of it
664                if pyx_code.indenter("while 1:"):
665                    if normal_types:
666                        self._fused_instance_checks(normal_types, pyx_code, env)
667                    if buffer_types or pythran_types:
668                        env.use_utility_code(Code.UtilityCode.load_cached("IsLittleEndian", "ModuleSetupCode.c"))
669                        self._buffer_checks(buffer_types, pythran_types, pyx_code, decl_code, env)
670                    if has_object_fallback:
671                        pyx_code.context.update(specialized_type_name='object')
672                        pyx_code.putln(self.match)
673                    else:
674                        pyx_code.putln(self.no_match)
675                    pyx_code.putln("break")
676                    pyx_code.dedent()
677
678                fused_index += 1
679                all_buffer_types.update(buffer_types)
680                all_buffer_types.update(ty.org_buffer for ty in pythran_types)
681
682            if arg.default:
683                default_idx += 1
684
685        if all_buffer_types:
686            self._buffer_declarations(pyx_code, decl_code, all_buffer_types, pythran_types)
687            env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
688            env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))
689
690        pyx_code.put_chunk(
691            u"""
692                candidates = []
693                for sig in <dict>signatures:
694                    match_found = False
695                    src_sig = sig.strip('()').split('|')
696                    for i in range(len(dest_sig)):
697                        dst_type = dest_sig[i]
698                        if dst_type is not None:
699                            if src_sig[i] == dst_type:
700                                match_found = True
701                            else:
702                                match_found = False
703                                break
704
705                    if match_found:
706                        candidates.append(sig)
707
708                if not candidates:
709                    raise TypeError("No matching signature found")
710                elif len(candidates) > 1:
711                    raise TypeError("Function call with ambiguous argument types")
712                else:
713                    return (<dict>signatures)[candidates[0]]
714            """)
715
716        fragment_code = pyx_code.getvalue()
717        # print decl_code.getvalue()
718        # print fragment_code
719        from .Optimize import ConstantFolding
720        fragment = TreeFragment.TreeFragment(
721            fragment_code, level='module', pipeline=[ConstantFolding()])
722        ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
723        UtilityCode.declare_declarations_in_scope(
724            decl_code.getvalue(), env.global_scope())
725        ast.scope = env
726        # FIXME: for static methods of cdef classes, we build the wrong signature here: first arg becomes 'self'
727        ast.analyse_declarations(env)
728        py_func = ast.stats[-1]  # the DefNode
729        self.fragment_scope = ast.scope
730
731        if isinstance(self.node, DefNode):
732            py_func.specialized_cpdefs = self.nodes[:]
733        else:
734            py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
735
736        return py_func
737
738    def update_fused_defnode_entry(self, env):
739        copy_attributes = (
740            'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
741            'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
742            'scope'
743        )
744
745        entry = self.py_func.entry
746
747        for attr in copy_attributes:
748            setattr(entry, attr,
749                    getattr(self.orig_py_func.entry, attr))
750
751        self.py_func.name = self.orig_py_func.name
752        self.py_func.doc = self.orig_py_func.doc
753
754        env.entries.pop('__pyx_fused_cpdef', None)
755        if isinstance(self.node, DefNode):
756            env.entries[entry.name] = entry
757        else:
758            env.entries[entry.name].as_variable = entry
759
760        env.pyfunc_entries.append(entry)
761
762        self.py_func.entry.fused_cfunction = self
763        for node in self.nodes:
764            if isinstance(self.node, DefNode):
765                node.fused_py_func = self.py_func
766            else:
767                node.py_func.fused_py_func = self.py_func
768                node.entry.as_variable = entry
769
770        self.synthesize_defnodes()
771        self.stats.append(self.__signatures__)
772
    def analyse_expressions(self, env):
        """
        Analyse the expressions of the fused function and all its
        specializations.

        Default argument values are analysed only once, on the original
        function ``self.node``, and wrapped in ProxyNodes. These are stored
        in ``self.defaults``, with None for arguments that have no default.
        Every specialization then receives a CloneNode of each shared
        default, coerced to that specialization's argument type.

        When a Python-visible fused function exists (``self.py_func``), one
        defaults tuple and one code object are built. Both are shared,
        through CloneNodes, by the fused function and by every specialized
        PyCFunction.

        Returns self.
        """
        # Complex specialization types need their declaration utility code
        # registered in env.
        for fused_compound_type in self.fused_compound_types:
            for fused_type in fused_compound_type.get_fused_types():
                for specialization_type in fused_type.types:
                    if specialization_type.is_complex:
                        specialization_type.create_declaration_utility_code(env)

        # Analyse the Python-level fused function machinery.
        # NOTE(review): update_fused_defnode_entry() also appends
        # __signatures__ to self.stats. If that already happened, the node is
        # analysed again in the stats loop below. Presumably harmless
        # (analysis is expected to return the same node); confirm.
        if self.py_func:
            self.__signatures__ = self.__signatures__.analyse_expressions(env)
            self.py_func = self.py_func.analyse_expressions(env)
            self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
            self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)

        # One entry per argument of the original function: a ProxyNode around
        # the analysed default, or None when the argument has no default.
        self.defaults = defaults = []

        for arg in self.node.args:
            if arg.default:
                arg.default = arg.default.analyse_expressions(env)
                defaults.append(ProxyNode(arg.default))
            else:
                defaults.append(None)

        # Analyse each child statement in place. For function specializations,
        # replace their own defaults with clones of the shared values, coerced
        # to the specialization's argument type.
        for i, stat in enumerate(self.stats):
            stat = self.stats[i] = stat.analyse_expressions(env)
            if isinstance(stat, FuncDefNode):
                for arg, default in zip(stat.args, defaults):
                    if default is not None:
                        arg.default = CloneNode(default).coerce_to(arg.type, env)

        if self.py_func:
            # Build one defaults tuple shared by the fused function and all
            # specialized PyCFunctions. Its elements are clones of defaults
            # already analysed above, hence skip_children=True.
            # NOTE(review): `if default` relies on node objects being truthy;
            # `is not None` (as used above) would be more explicit.
            args = [CloneNode(default) for default in defaults if default]
            self.defaults_tuple = TupleNode(self.pos, args=args)
            self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True).coerce_to_pyobject(env)
            self.defaults_tuple = ProxyNode(self.defaults_tuple)
            # The first specialization's code object is shared by every
            # specialization and by the fused function.
            self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)

            fused_func = self.resulting_fused_function.arg
            fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
            fused_func.code_object = CloneNode(self.code_object)

            # Order matters: code_object is assigned before analyse_types() and
            # defaults_tuple after it. Presumably the shared defaults tuple must
            # not be re-analysed per specialization; TODO confirm before
            # reordering.
            for i, pycfunc in enumerate(self.specialized_pycfuncs):
                pycfunc.code_object = CloneNode(self.code_object)
                pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
                pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
        return self
822
823    def synthesize_defnodes(self):
824        """
825        Create the __signatures__ dict of PyCFunctionNode specializations.
826        """
827        if isinstance(self.nodes[0], CFuncDefNode):
828            nodes = [node.py_func for node in self.nodes]
829        else:
830            nodes = self.nodes
831
832        signatures = [StringEncoding.EncodedString(node.specialized_signature_string)
833                      for node in nodes]
834        keys = [ExprNodes.StringNode(node.pos, value=sig)
835                for node, sig in zip(nodes, signatures)]
836        values = [ExprNodes.PyCFunctionNode.from_defnode(node, binding=True)
837                  for node in nodes]
838
839        self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos, zip(keys, values))
840
841        self.specialized_pycfuncs = values
842        for pycfuncnode in values:
843            pycfuncnode.is_specialization = True
844
845    def generate_function_definitions(self, env, code):
846        if self.py_func:
847            self.py_func.pymethdef_required = True
848            self.fused_func_assignment.generate_function_definitions(env, code)
849
850        for stat in self.stats:
851            if isinstance(stat, FuncDefNode) and stat.entry.used:
852                code.mark_pos(stat.pos)
853                stat.generate_function_definitions(env, code)
854
    def generate_execution_code(self, code):
        """
        Generate the code that runs when the fused function definition is
        executed.

        Order of operations:
        1. Evaluate each shared default-argument value once (self.defaults).
        2. If a Python-visible fused function exists, evaluate the shared
           defaults tuple and code object.
        3. Generate code for every child statement. ExprNode statements (such
           as the ``__signatures__`` dict appended by
           ``update_fused_defnode_entry``) are evaluated; other statements
           are executed.
        4. If ``__signatures__`` exists, evaluate the fused function and
           store the dict in its ``__signatures__`` field. Then run the
           assignment of the fused function and dispose of the fused
           function, defaults tuple and code object temporaries.
        5. Dispose of the shared default values.
        """
        # Note: all def-function specializations are wrapped in PyCFunction
        # nodes in the self.__signatures__ DictNode.
        for default in self.defaults:
            if default is not None:
                default.generate_evaluation_code(code)

        if self.py_func:
            self.defaults_tuple.generate_evaluation_code(code)
            self.code_object.generate_evaluation_code(code)

        for stat in self.stats:
            code.mark_pos(stat.pos)
            if isinstance(stat, ExprNodes.ExprNode):
                stat.generate_evaluation_code(code)
            else:
                stat.generate_execution_code(code)

        if self.__signatures__:
            self.resulting_fused_function.generate_evaluation_code(code)

            # Store the signatures dict directly in the fused function
            # object's C-level __signatures__ field.
            code.putln(
                "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
                                    (self.resulting_fused_function.result(),
                                     self.__signatures__.result()))
            # The field assignment above takes over the dict's reference.
            code.put_giveref(self.__signatures__.result())
            self.__signatures__.generate_post_assignment_code(code)
            self.__signatures__.free_temps(code)

            self.fused_func_assignment.generate_execution_code(code)

            # Dispose of results
            self.resulting_fused_function.generate_disposal_code(code)
            self.resulting_fused_function.free_temps(code)
            self.defaults_tuple.generate_disposal_code(code)
            self.defaults_tuple.free_temps(code)
            self.code_object.generate_disposal_code(code)
            self.code_object.free_temps(code)

        # Release the shared default values evaluated at the start.
        for default in self.defaults:
            if default is not None:
                default.generate_disposal_code(code)
                default.free_temps(code)
898
899    def annotate(self, code):
900        for stat in self.stats:
901            stat.annotate(code)
902