1"""
2python _generate_pyx.py
3
4Generate Ufunc definition source files for scipy.special. Produces
5files '_ufuncs.c' and '_ufuncs_cxx.c' by first producing Cython.
6
7This will generate both calls to PyUFunc_FromFuncAndData and the
8required ufunc inner loops.
9
The function signatures are contained in 'functions.json'; the syntax
for a function signature is
12
13    <function>:       <name> ':' <input> '*' <output>
14                        '->' <retval> '*' <ignored_retval>
15    <input>:          <typecode>*
16    <output>:         <typecode>*
17    <retval>:         <typecode>?
18    <ignored_retval>: <typecode>?
19    <headers>:        <header_name> [',' <header_name>]*
20
21The input parameter types are denoted by single character type
22codes, according to
23
24   'f': 'float'
25   'd': 'double'
26   'g': 'long double'
27   'F': 'float complex'
28   'D': 'double complex'
29   'G': 'long double complex'
30   'i': 'int'
31   'l': 'long'
32   'v': 'void'
33
34If multiple kernel functions are given for a single ufunc, the one
35which is used is determined by the standard ufunc mechanism. Kernel
36functions that are listed first are also matched first against the
37ufunc input types, so functions listed earlier take precedence.
38
In addition, versions with cast variables, such as d->f, D->F and
i->l, are automatically generated.
41
42There should be either a single header that contains all of the kernel
43functions listed, or there should be one header for each kernel
44function. Cython pxd files are allowed in addition to .h files.
45
46Cython functions may use fused types, but the names in the list
47should be the specialized ones, such as 'somefunc[float]'.
48
Functions coming from C++ should have ``++`` appended to the name of
their header.
51
52Floating-point exceptions inside these Ufuncs are converted to
53special function errors --- which are separately controlled by the
54user, and off by default, as they are usually not especially useful
55for the user.
56
57
58The C++ module
59--------------
60In addition to ``_ufuncs`` module, a second module ``_ufuncs_cxx`` is
61generated. This module only exports function pointers that are to be
62used when constructing some of the ufuncs in ``_ufuncs``. The function
63pointers are exported via Cython's standard mechanism.
64
65This mainly avoids build issues --- Python distutils has no way to
66figure out what to do if you want to link both C++ and Fortran code in
67the same shared library.
68
69"""
70
71# -----------------------------------------------------------------------------
72# Extra code
73# -----------------------------------------------------------------------------
74
75UFUNCS_EXTRA_CODE_COMMON = """\
76# This file is automatically generated by _generate_pyx.py.
77# Do not edit manually!
78
79include "_ufuncs_extra_code_common.pxi"
80"""
81
82UFUNCS_EXTRA_CODE = """\
83include "_ufuncs_extra_code.pxi"
84"""
85
86UFUNCS_EXTRA_CODE_BOTTOM = """\
87#
88# Aliases
89#
90jn = jv
91"""
92
# Cython declarations template (presumably written out as the
# ``cython_special`` .pxd file -- confirm against the writer code). It
# declares the hand-written spherical Bessel wrappers, fused over
# ``double complex`` and ``double`` via ``number_t``.
CYTHON_SPECIAL_PXD = """\
# This file is automatically generated by _generate_pyx.py.
# Do not edit manually!

ctypedef fused number_t:
    double complex
    double

cpdef number_t spherical_jn(long n, number_t z, bint derivative=*) nogil
cpdef number_t spherical_yn(long n, number_t z, bint derivative=*) nogil
cpdef number_t spherical_in(long n, number_t z, bint derivative=*) nogil
cpdef number_t spherical_kn(long n, number_t z, bint derivative=*) nogil
"""
106
# Template for the ``cython_special`` .pyx module, including its module
# docstring (the escaped triple quotes below become the docstring
# delimiters in the generated file). ``FUNCLIST`` is a placeholder,
# presumably substituted with the per-function signature listing by
# code outside this view -- NOTE(review): confirm where it is replaced.
CYTHON_SPECIAL_PYX = """\
# This file is automatically generated by _generate_pyx.py.
# Do not edit manually!
\"\"\"
.. highlight:: cython

Cython API for special functions
================================

Scalar, typed versions of many of the functions in ``scipy.special``
can be accessed directly from Cython; the complete list is given
below. Functions are overloaded using Cython fused types so their
names match their Python counterpart. The module follows the following
conventions:

- If a function's Python counterpart returns multiple values, then the
  function returns its outputs via pointers in the final arguments.
- If a function's Python counterpart returns a single value, then the
  function's output is returned directly.

The module is usable from Cython via::

    cimport scipy.special.cython_special

Error handling
--------------

Functions can indicate an error by returning ``nan``; however they
cannot emit warnings like their counterparts in ``scipy.special``.

Available functions
-------------------

FUNCLIST

Custom functions
----------------

Some functions in ``scipy.special`` which are not ufuncs have custom
Cython wrappers.

Spherical Bessel functions
~~~~~~~~~~~~~~~~~~~~~~~~~~

The optional ``derivative`` boolean argument is replaced with an
optional Cython ``bint``, leading to the following signatures.

- :py:func:`~scipy.special.spherical_jn`::

        double complex spherical_jn(long, double complex)
        double complex spherical_jn(long, double complex, bint)
        double spherical_jn(long, double)
        double spherical_jn(long, double, bint)

- :py:func:`~scipy.special.spherical_yn`::

        double complex spherical_yn(long, double complex)
        double complex spherical_yn(long, double complex, bint)
        double spherical_yn(long, double)
        double spherical_yn(long, double, bint)

- :py:func:`~scipy.special.spherical_in`::

        double complex spherical_in(long, double complex)
        double complex spherical_in(long, double complex, bint)
        double spherical_in(long, double)
        double spherical_in(long, double, bint)

- :py:func:`~scipy.special.spherical_kn`::

        double complex spherical_kn(long, double complex)
        double complex spherical_kn(long, double complex, bint)
        double spherical_kn(long, double)
        double spherical_kn(long, double, bint)

\"\"\"

include "_cython_special.pxi"
include "_cython_special_custom.pxi"
"""
187
# Template for the generated type stubs of the ufunc module. ``{ALL}``
# and ``{STUBS}`` are placeholders filled in by code outside this view --
# NOTE(review): presumably via ``str.format``; if so, any literal brace
# added to this template must be doubled.
STUBS = """\
# This file is automatically generated by _generate_pyx.py.
# Do not edit manually!

from typing import Any, Dict

import numpy as np

__all__ = [
    'geterr',
    'seterr',
    'errstate',
    {ALL}
]

def geterr() -> Dict[str, str]: ...
def seterr(**kwargs: str) -> Dict[str, str]: ...

class errstate:
    def __init__(self, **kargs: str) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(
        self,
        exc_type: Any,  # Unused
        exc_value: Any,  # Unused
        traceback: Any,  # Unused
    ) -> None: ...

{STUBS}

"""
219
220
221# -----------------------------------------------------------------------------
222# Code generation
223# -----------------------------------------------------------------------------
224
225import itertools
226import json
227import os
228import optparse
229import re
230import textwrap
231from typing import List
232
233import numpy
234
235
# Absolute path of the directory that contains this script.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# Ufunc docstrings; Ufunc.__init__ looks each one up with
# ``add_newdocs.get(name)``. Imported by bare module name, so the sibling
# ``add_newdocs`` module must be importable at this point --
# NOTE(review): presumably the script runs with its own directory on
# sys.path; confirm.
add_newdocs = __import__('add_newdocs')
239
# Type code -> Cython spelling of the type (used in generated .pyx code).
CY_TYPES = dict(zip(
    'fdgFDGilv',
    [
        'float', 'double', 'long double',
        'float complex', 'double complex', 'long double complex',
        'int', 'long', 'void',
    ],
))

# Type code -> NumPy C spelling of the type.
C_TYPES = dict(zip(
    'fdgFDGilv',
    [
        'npy_float', 'npy_double', 'npy_longdouble',
        'npy_cfloat', 'npy_cdouble', 'npy_clongdouble',
        'npy_int', 'npy_long', 'void',
    ],
))

# Type code -> NumPy type-number constant used in ufunc type tables.
# There is deliberately no entry for 'v' (void).
TYPE_NAMES = dict(zip(
    'fdgFDGil',
    [
        'NPY_FLOAT', 'NPY_DOUBLE', 'NPY_LONGDOUBLE',
        'NPY_CFLOAT', 'NPY_CDOUBLE', 'NPY_CLONGDOUBLE',
        'NPY_INT', 'NPY_LONG',
    ],
))
274
# Functions to benchmark, mapped to signature strings of the form
# 'inputs' or 'inputs*outputs' (the format generate_bench splits on '*').
CYTHON_SPECIAL_BENCHFUNCS = dict(
    airy=['d*dddd', 'D*DDDD'],
    beta=['dd'],
    erf=['d', 'D'],
    exprel=['d'],
    gamma=['d', 'D'],
    jv=['dd', 'dD'],
    loggamma=['D'],
    logit=['d'],
    psi=['d', 'D'],
)
286
287
def underscore(arg):
    """Return *arg* with every space turned into an underscore."""
    pieces = arg.split(" ")
    return "_".join(pieces)
290
291
def cast_order(c):
    """Map each type code in *c* to its rank in the 'ilfdgFDG' cast order.

    Raises ValueError for a code that is not part of that order.
    """
    position_of = 'ilfdgFDG'.index
    return list(map(position_of, c))
294
295
# Casts (from_code, to_code) that turn values into NaNs unless the value
# happens to survive the conversion exactly: complex -> anything real or
# integer, real -> integer, and long -> int.
DANGEROUS_DOWNCAST = (
    {(src, dst) for src in 'FDG' for dst in 'ilfdg'}
    | {(src, dst) for src in 'fdg' for dst in 'il'}
    | {('l', 'i')}
)

# Sentinel written to an output of the given type code when a cast check
# fails: NPY_NAN for floating/complex codes, 0xbad0bad0 for integers.
NAN_VALUE = dict.fromkeys('fdgFDG', 'NPY_NAN')
NAN_VALUE.update(i='0xbad0bad0', l='0xbad0bad0')
318
319
def generate_loop(func_inputs, func_outputs, func_retval,
                  ufunc_inputs, ufunc_outputs):
    """
    Generate a UFunc loop function that calls a function given as its
    data parameter with the specified input and output arguments and
    return value.

    This function can be passed to PyUFunc_FromFuncAndData.

    Parameters
    ----------
    func_inputs, func_outputs, func_retval : str
        Signature of the function to call, given as type codes of the
        input, output and return value arguments. These 1-character
        codes are keys of the CY_TYPES dictionary above.

        The corresponding C function signature to be called is:

            retval func(intype1 iv1, intype2 iv2, ..., outtype1 *ov1, ...);

        If len(ufunc_outputs) == len(func_outputs)+1, the return value
        is treated as the first output argument. Otherwise, the return
        value is ignored.

    ufunc_inputs, ufunc_outputs : str
        Ufunc input and output signature.

        This does not have to exactly match the function signature,
        as long as the type casts work out on the C level.

    Returns
    -------
    loop_name
        Name of the generated loop function.
    loop_body
        Generated Cython code for the loop.

    Raises
    ------
    ValueError
        If the function and the ufunc have a different number of inputs,
        or if the ufunc outputs match neither the function's outputs nor
        the function's outputs preceded by a non-void return value.

    Notes
    -----
    The generated loop reads ``data`` as a two-element ``void*`` array
    holding the kernel function pointer and its name; the name is passed
    to ``sf_error`` calls. Inputs and outputs whose cast appears in
    DANGEROUS_DOWNCAST are compared before/after the cast; on mismatch a
    DOMAIN error is reported and the output gets its NAN_VALUE sentinel.

    """
    if len(func_inputs) != len(ufunc_inputs):
        raise ValueError("Function and ufunc have different number of inputs")

    # Either the output counts agree (return value ignored) or the ufunc
    # has one extra output that the non-void return value fills.
    if len(func_outputs) != len(ufunc_outputs) and not (
            func_retval != "v" and len(func_outputs)+1 == len(ufunc_outputs)):
        raise ValueError("Function retval and ufunc outputs don't match")

    # The name encodes both signatures, so the same combination always
    # yields the same loop name.
    name = "loop_%s_%s_%s_As_%s_%s" % (
        func_retval, func_inputs, func_outputs, ufunc_inputs, ufunc_outputs
        )
    body = "cdef void %s(char **args, np.npy_intp *dims, np.npy_intp *steps, void *data) nogil:\n" % name
    body += "    cdef np.npy_intp i, n = dims[0]\n"
    # data -> [kernel function pointer, kernel name (char*)]
    body += "    cdef void *func = (<void**>data)[0]\n"
    body += "    cdef char *func_name = <char*>(<void**>data)[1]\n"

    # One running char* per ufunc input (ipN) and output (opN) array.
    for j in range(len(ufunc_inputs)):
        body += "    cdef char *ip%d = args[%d]\n" % (j, j)
    for j in range(len(ufunc_outputs)):
        body += "    cdef char *op%d = args[%d]\n" % (j, j + len(ufunc_inputs))

    # ftypes/fvars: parameter types and argument expressions of the kernel
    # call; outtypecodes[k] is the type code of local variable ov<k>.
    ftypes = []
    fvars = []
    outtypecodes = []
    for j in range(len(func_inputs)):
        ftypes.append(CY_TYPES[func_inputs[j]])
        # Read the element with the ufunc's type, cast to the kernel's type.
        fvars.append("<%s>(<%s*>ip%d)[0]" % (
            CY_TYPES[func_inputs[j]],
            CY_TYPES[ufunc_inputs[j]], j))

    # When the return value is used it lives in ov0 and the pointer
    # outputs are shifted to ov1, ov2, ...
    if len(func_outputs)+1 == len(ufunc_outputs):
        func_joff = 1
        outtypecodes.append(func_retval)
        body += "    cdef %s ov0\n" % (CY_TYPES[func_retval],)
    else:
        func_joff = 0

    for j, outtype in enumerate(func_outputs):
        body += "    cdef %s ov%d\n" % (CY_TYPES[outtype], j+func_joff)
        ftypes.append("%s *" % CY_TYPES[outtype])
        fvars.append("&ov%d" % (j+func_joff))
        outtypecodes.append(outtype)

    body += "    for i in range(n):\n"
    if len(func_outputs)+1 == len(ufunc_outputs):
        rv = "ov0 = "
    else:
        rv = ""

    # Call through func after casting it to the kernel's exact
    # function-pointer type.
    funcall = "        %s(<%s(*)(%s) nogil>func)(%s)\n" % (
        rv, CY_TYPES[func_retval], ", ".join(ftypes), ", ".join(fvars))

    # Cast-check inputs and call function
    input_checks = []
    for j in range(len(func_inputs)):
        if (ufunc_inputs[j], func_inputs[j]) in DANGEROUS_DOWNCAST:
            # Value must be unchanged by the ufunc -> kernel type cast.
            chk = "<%s>(<%s*>ip%d)[0] == (<%s*>ip%d)[0]" % (
                CY_TYPES[func_inputs[j]], CY_TYPES[ufunc_inputs[j]], j,
                CY_TYPES[ufunc_inputs[j]], j)
            input_checks.append(chk)

    if input_checks:
        body += "        if %s:\n" % (" and ".join(input_checks))
        body += "    " + funcall
        body += "        else:\n"
        # Skip the call: report a domain error and fill every local
        # output with its NAN_VALUE sentinel.
        body += "            sf_error.error(func_name, sf_error.DOMAIN, \"invalid input argument\")\n"
        for j, outtype in enumerate(outtypecodes):
            body += "            ov%d = <%s>%s\n" % (
                j, CY_TYPES[outtype], NAN_VALUE[outtype])
    else:
        body += funcall

    # Assign and cast-check output values
    for j, (outtype, fouttype) in enumerate(zip(ufunc_outputs, outtypecodes)):
        if (fouttype, outtype) in DANGEROUS_DOWNCAST:
            body += "        if ov%d == <%s>ov%d:\n" % (j, CY_TYPES[outtype], j)
            body += "            (<%s *>op%d)[0] = <%s>ov%d\n" % (
                CY_TYPES[outtype], j, CY_TYPES[outtype], j)
            body += "        else:\n"
            body += "            sf_error.error(func_name, sf_error.DOMAIN, \"invalid output\")\n"
            body += "            (<%s *>op%d)[0] = <%s>%s\n" % (
                CY_TYPES[outtype], j, CY_TYPES[outtype], NAN_VALUE[outtype])
        else:
            body += "        (<%s *>op%d)[0] = <%s>ov%d\n" % (
                CY_TYPES[outtype], j, CY_TYPES[outtype], j)
    # Advance every array pointer by its step for the next element.
    for j in range(len(ufunc_inputs)):
        body += "        ip%d += steps[%d]\n" % (j, j)
    for j in range(len(ufunc_outputs)):
        body += "        op%d += steps[%d]\n" % (j, j + len(ufunc_inputs))

    # After the loop, hand floating-point exception state to sf_error
    # (per the module docstring, FP exceptions become special-function
    # errors).
    body += "    sf_error.check_fpe(func_name)\n"

    return name, body
451
452
def generate_fused_type(codes):
    """
    Generate the name of, and the Cython code for, a fused type.

    Parameters
    ----------
    codes : str
        Type codes of the specializations, in order; each must be a key
        of CY_TYPES (i.e. f, d, g, ...).

    Returns
    -------
    name : str
        Name of the fused type, ``codes + "_number_t"`` (for example
        ``'dD_number_t'``).
    declaration : str
        ``ctypedef fused`` declaration listing one specialization per
        code, each on its own indented line.

    """
    # Resolve the codes first so an unknown code fails with KeyError
    # before anything else is built.
    specializations = [CY_TYPES[code] for code in codes]
    name = codes + "_number_t"
    lines = ["ctypedef fused " + name + ":"]
    lines.extend("    " + cytype for cytype in specializations)
    return name, "\n".join(lines)
470
471
def generate_bench(name, codes):
    """
    Build Cython benchmark source for one signature of *name*.

    *codes* is ``'inputs'`` or ``'inputs*outputs'`` (type codes). Returns
    ``(pybench, cybench)``. ``pybench`` is a ``def`` looping N times over
    ``_ufuncs.<name>(x0, ...)``. ``cybench`` loops over ``<name>(x0, ...,
    &y0, ...)`` with one ``cdef`` local per output code.
    """
    indent = " " * 4

    parts = codes.split("*")
    incodes = parts[0]
    outcodes = parts[1] if len(parts) > 1 else ""

    inargs = ["x{}".format(k) for k in range(len(incodes))]
    typed_inargs = ", ".join(
        "{} {}".format(CY_TYPES[code], arg) for code, arg in zip(incodes, inargs)
    )
    # The doubled braces leave a '{}' slot for the benchmark's name.
    header = [
        "def {{}}(int N, {}):".format(typed_inargs),
        indent + "cdef int n",
    ]

    outargs = []
    declarations = []
    for k, code in enumerate(outcodes):
        var = "y{}".format(k)
        outargs.append("&" + var)
        declarations.append(indent + "cdef {} {}".format(CY_TYPES[code], var))

    loop = [indent + "for n in range(N):", 2 * indent + "{}({})"]

    pytemplate = "\n".join(header + loop)
    cytemplate = "\n".join(header + declarations + loop)
    pybench = pytemplate.format(
        "_bench_{}_{}_py".format(name, incodes),
        "_ufuncs." + name,
        ", ".join(inargs),
    )
    cybench = cytemplate.format(
        "_bench_{}_{}_cy".format(name, incodes),
        name,
        ", ".join(inargs + outargs),
    )
    return pybench, cybench
509
510
def generate_doc(name, specs):
    """
    Build the reST list entry documenting the Cython signatures of *name*.

    Each spec is ``'inputs[*pointer_outputs]->ret'`` in type codes. Pointer
    outputs are rendered as ``<type> *`` parameters after the inputs. The
    entry ends with an extra newline.
    """
    indent = " " * 8
    lines = ["- :py:func:`~scipy.special.{}`::\n".format(name)]
    for spec in specs:
        lhs, rhs = spec.split("->")
        arg_groups = lhs.split("*")
        params = [CY_TYPES[code] for code in arg_groups[0]]
        if len(arg_groups) > 1:
            params += [f"{CY_TYPES[code]} *" for code in arg_groups[1]]
        rettype = CY_TYPES[rhs]
        lines.append(indent + "{} {}({})".format(rettype, name, ", ".join(params)))
    lines[-1] += "\n"
    return "\n".join(lines)
527
528
def npy_cdouble_from_double_complex(var):
    """Return Cython source converting *var* (double complex) to npy_cdouble."""
    return f"_complexstuff.npy_cdouble_from_double_complex({var})"
533
534
def double_complex_from_npy_cdouble(var):
    """Return Cython source converting *var* (npy_cdouble) to double complex."""
    return f"_complexstuff.double_complex_from_npy_cdouble({var})"
539
540
def iter_variants(inputs, outputs):
    """
    Generate variants of UFunc signatures, by changing variable types,
    within the limitation that the corresponding C types casts still
    work out.

    This does not generate all possibilities, just the ones required
    for the ufunc to work properly with the most common data types.

    Parameters
    ----------
    inputs, outputs : str
        UFunc input and output signature strings

    Yields
    ------
    new_input, new_output : str
        Modified input and output strings. The first pair is always the
        given signature with every int code ('i') replaced by long ('l').
        It equals the original pair only when no 'i' occurs. If the
        inputs contain no integer codes ('i' or 'l'), a second pair
        follows that additionally maps double ('d') to float ('f') and
        double complex ('D') to float complex ('F').

    """
    # always use long instead of int (more common type on 64-bit)
    tables = [str.maketrans('i', 'l')]

    # float32-preserving signatures
    if not ('i' in inputs or 'l' in inputs):
        # Don't add float32 versions of ufuncs with integer arguments, as this
        # can lead to incorrect dtype selection if the integer arguments are
        # arrays, but float arguments are scalars.
        # For instance sph_harm(0,[0],0,0).dtype == complex64
        # This may be a NumPy bug, but we need to work around it.
        # cf. gh-4895, https://github.com/numpy/numpy/issues/5895
        tables.append(str.maketrans('idD', 'lfF'))

    # The mapping is applied to every character at once; no target
    # code is also a source code, so this matches sequential replacement.
    for table in tables:
        yield inputs.translate(table), outputs.translate(table)
585
586
class Func:
    """
    Base class for Ufunc and FusedFunc.

    Parameters
    ----------
    name : str
        Python-level name of the function.
    signatures : dict
        Mapping ``{header_name: {c_function_name: signature_string}}``.
        Each signature string is parsed by ``_parse_signature``.

    Attributes
    ----------
    name : str
        Python-level name of the function.
    signatures : list of (func_name, inarg, outarg, ret, header)
        Parsed signatures, in the iteration order of ``signatures``.
    function_name_overrides : dict of str -> str
        Replacement C names consulted by ``cython_func_name``.

    """
    def __init__(self, name, signatures):
        self.name = name
        self.signatures = []
        self.function_name_overrides = {}

        for header, functions in signatures.items():
            for func_name, sig in functions.items():
                inarg, outarg, ret = self._parse_signature(sig)
                self.signatures.append((func_name, inarg, outarg, ret, header))

    def _parse_signature(self, sig):
        """
        Split a signature string into ``(inarg, outarg, ret)`` type codes.

        Accepted forms are ``'<inputs>*<outputs>-><ret>'`` and
        ``'<inputs>-><ret>'``. ``ret`` holds at most one type code,
        optionally marked with ``'*'`` (e.g. ``'*i'``). Ufunc drops
        everything from the ``'*'`` onwards when building its outputs, so
        a marked return value is called but not exposed.

        Raises
        ------
        ValueError
            If ``sig`` matches neither form, or ``ret`` contains more than
            one ``'*'`` or more than one type code. Several type codes
            cannot be mapped to a single C return type.
        """
        m = re.match(r"\s*([fdgFDGil]*)\s*\*\s*([fdgFDGil]*)\s*->\s*([*fdgFDGil]*)\s*$", sig)
        if m:
            inarg, outarg, ret = [x.strip() for x in m.groups()]
            # Reject e.g. 'dd' or 'd*i' here; they would otherwise fail
            # later with an opaque KeyError on the C/Cython type tables.
            if ret.count('*') > 1 or len(ret.replace('*', '')) > 1:
                raise ValueError("{}: Invalid signature: {}".format(self.name, sig))
            return inarg, outarg, ret
        m = re.match(r"\s*([fdgFDGil]*)\s*->\s*([fdgFDGil]?)\s*$", sig)
        if m:
            inarg, ret = [x.strip() for x in m.groups()]
            return inarg, "", ret
        raise ValueError("{}: Invalid signature: {}".format(self.name, sig))

    def get_prototypes(self, nptypes_for_h=False):
        """
        Return ``(func_name, c_proto, cy_proto, header)`` for every signature.

        ``c_proto`` is the function-pointer type spelled with NumPy C type
        names (C_TYPES). ``cy_proto`` is the ``nogil`` function-pointer type
        spelled with Cython names (CY_TYPES). When ``nptypes_for_h`` is true
        and the header name ends in ``h``, ``cy_proto`` reuses the NumPy C
        spelling instead.
        """
        prototypes = []
        for func_name, inarg, outarg, ret, header in self.signatures:
            # A '*' only marks the return value as ignored; the C return
            # type is the remaining code.
            ret = ret.replace('*', '')
            c_args = ([C_TYPES[x] for x in inarg]
                      + [C_TYPES[x] + ' *' for x in outarg])
            cy_args = ([CY_TYPES[x] for x in inarg]
                       + [CY_TYPES[x] + ' *' for x in outarg])
            c_proto = "%s (*)(%s)" % (C_TYPES[ret], ", ".join(c_args))
            if header.endswith("h") and nptypes_for_h:
                cy_proto = c_proto + " nogil"
            else:
                cy_proto = "%s (*)(%s) nogil" % (CY_TYPES[ret], ", ".join(cy_args))
            prototypes.append((func_name, c_proto, cy_proto, header))
        return prototypes

    def cython_func_name(self, c_name, specialized=False, prefix="_func_",
                         override=True):
        """
        Return the Cython-level name for the C function ``c_name``.

        If ``override`` is true and ``c_name`` has an entry in
        ``function_name_overrides``, that replacement is used without
        ``prefix``. Otherwise ``prefix`` is prepended. A trailing fused
        specialization such as ``'[double complex]'`` is kept, with spaces
        turned into underscores, only when ``specialized`` is true.
        Otherwise it is stripped.
        """
        # act on function name overrides
        if override and c_name in self.function_name_overrides:
            c_name = self.function_name_overrides[c_name]
            prefix = ""

        # support fused types
        m = re.match(r'^(.*?)(\[.*\])$', c_name)
        if m:
            c_base_name, fused_part = m.groups()
        else:
            c_base_name, fused_part = c_name, ""
        if specialized:
            return "%s%s%s" % (prefix, c_base_name, fused_part.replace(' ', '_'))
        else:
            return "%s%s" % (prefix, c_base_name,)
648
649
class Ufunc(Func):
    """
    Ufunc signature, restricted format suitable for special functions.

    Parameters
    ----------
    name : str
        Name of the ufunc to create.
    signatures : dict
        Mapping ``{header_name: {c_function_name: signature}}`` of the
        C-level kernels implementing the ufunc. Each signature has the
        form ``'inputparams*outputparams->retval'`` or
        ``'inputparams->retval'``, e.g. ``'fff*ff->f'`` or ``'d*dd->*i'``.
        A plain ``retval`` code becomes the first ufunc output. A
        ``retval`` prefixed with ``'*'`` is the C return type but is
        ignored.

    Attributes
    ----------
    name : str
        Python name for the Ufunc
    signatures : list of (func_name, inarg_spec, outarg_spec, ret_spec, header_name)
        List of parsed signatures
    doc : str
        Docstring, obtained from add_newdocs
    function_name_overrides : dict of str->str
        Overrides for the function names in signatures

    Raises
    ------
    ValueError
        If add_newdocs has no docstring for ``name`` (or a signature is
        invalid).

    """
    def __init__(self, name, signatures):
        super().__init__(name, signatures)
        self.doc = add_newdocs.get(name)
        if self.doc is None:
            raise ValueError("No docstring for ufunc %r" % name)
        self.doc = textwrap.dedent(self.doc).strip()

    def _get_signatures_and_loops(self, all_loops):
        """
        Collect the ufunc loop variants and generate their loop code.

        Generated loop bodies are stored in ``all_loops`` (loop name ->
        code). Returns ``(variants, inarg_num, outarg_num)``. ``variants``
        is a list of ``(func_name, loop_name, ufunc_inputs, ufunc_outputs)``
        sorted by input cast order. The counts are taken from the first
        signature; every variant must match them.
        """
        inarg_num = None
        outarg_num = None

        seen = set()
        variants = []

        def add_variant(func_name, inarg, outarg, ret, inp, outp):
            # The first kernel registered for an input signature wins.
            if inp in seen:
                return
            seen.add(inp)

            sig = (func_name, inp, outp)
            if "v" in outp:
                raise ValueError("%s: void signature %r" % (self.name, sig))
            if len(inp) != inarg_num or len(outp) != outarg_num:
                raise ValueError("%s: signature %r does not have %d/%d input/output args" % (
                    self.name, sig,
                    inarg_num, outarg_num))

            loop_name, loop = generate_loop(inarg, outarg, ret, inp, outp)
            all_loops[loop_name] = loop
            variants.append((func_name, loop_name, inp, outp))

        # First add base variants
        for func_name, inarg, outarg, ret, _header in self.signatures:
            # A used return value is the first ufunc output; a '*'-marked
            # one is dropped from the outputs.
            outp = re.sub(r'\*.*', '', ret) + outarg
            ret = ret.replace('*', '')
            if inarg_num is None:
                inarg_num = len(inarg)
                outarg_num = len(outp)

            inp, outp = next(iter_variants(inarg, outp))
            add_variant(func_name, inarg, outarg, ret, inp, outp)

        # Then the supplementary ones
        for func_name, inarg, outarg, ret, _header in self.signatures:
            outp = re.sub(r'\*.*', '', ret) + outarg
            ret = ret.replace('*', '')
            for inp, variant_outp in iter_variants(inarg, outp):
                add_variant(func_name, inarg, outarg, ret, inp, variant_outp)

        # Then sort variants to input argument cast order
        # -- the sort is stable, so functions earlier in the signature list
        #    are still preferred
        variants.sort(key=lambda v: cast_order(v[2]))

        return variants, inarg_num, outarg_num

    def generate(self, all_loops):
        """
        Return the Cython code that builds this ufunc.

        The code fills the loop, data, type and docstring tables and
        creates the ufunc with ``np.PyUFunc_FromFuncAndData``. The loop
        bodies it refers to are added to ``all_loops``.
        """
        toplevel = ""

        variants, inarg_num, outarg_num = self._get_signatures_and_loops(
                all_loops)

        loops = []
        funcs = []
        types = []

        for func_name, loop_name, inputs, outputs in variants:
            for x in inputs:
                types.append(TYPE_NAMES[x])
            for x in outputs:
                types.append(TYPE_NAMES[x])
            loops.append(loop_name)
            funcs.append(func_name)

        toplevel += "cdef np.PyUFuncGenericFunction ufunc_%s_loops[%d]\n" % (self.name, len(loops))
        # Two pointers per kernel: the function itself and its name.
        toplevel += "cdef void *ufunc_%s_ptr[%d]\n" % (self.name, 2*len(funcs))
        toplevel += "cdef void *ufunc_%s_data[%d]\n" % (self.name, len(funcs))
        toplevel += "cdef char ufunc_%s_types[%d]\n" % (self.name, len(types))
        # Escape the docstring into a Cython string literal, one source
        # line per docstring line.
        toplevel += 'cdef char *ufunc_%s_doc = (\n    "%s")\n' % (
            self.name,
            self.doc.replace("\\", "\\\\").replace('"', '\\"').replace('\n', '\\n\"\n    "')
            )

        for j, loop_name in enumerate(loops):
            toplevel += "ufunc_%s_loops[%d] = <np.PyUFuncGenericFunction>%s\n" % (self.name, j, loop_name)
        for j, type_name in enumerate(types):
            toplevel += "ufunc_%s_types[%d] = <char>%s\n" % (self.name, j, type_name)
        for j, func in enumerate(funcs):
            toplevel += "ufunc_%s_ptr[2*%d] = <void*>%s\n" % (self.name, j,
                                                              self.cython_func_name(func, specialized=True))
            toplevel += "ufunc_%s_ptr[2*%d+1] = <void*>(<char*>\"%s\")\n" % (self.name, j,
                                                                             self.name)
        for j in range(len(funcs)):
            toplevel += "ufunc_%s_data[%d] = &ufunc_%s_ptr[2*%d]\n" % (
                self.name, j, self.name, j)

        # ntypes is the number of loops; every variant has exactly
        # inarg_num + outarg_num type entries, so this equals
        # len(types) / (inarg_num + outarg_num) but stays an integer.
        toplevel += ('@ = np.PyUFunc_FromFuncAndData(ufunc_@_loops, '
                     'ufunc_@_data, ufunc_@_types, %d, %d, %d, 0, '
                     '"@", ufunc_@_doc, 0)\n' % (len(loops),
                                                 inarg_num, outarg_num)
                     ).replace('@', self.name)

        return toplevel
780
781
782class FusedFunc(Func):
783    """
784    Generate code for a fused-type special function that can be
785    cimported in Cython.
786
787    """
788    def __init__(self, name, signatures):
789        super().__init__(name, signatures)
790        self.doc = "See the documentation for scipy.special." + self.name
791        # "codes" are the keys for CY_TYPES
792        self.incodes, self.outcodes = self._get_codes()
793        self.fused_types = set()
794        self.intypes, infused_types = self._get_types(self.incodes)
795        self.fused_types.update(infused_types)
796        self.outtypes, outfused_types = self._get_types(self.outcodes)
797        self.fused_types.update(outfused_types)
798        self.invars, self.outvars = self._get_vars()
799
800    def _get_codes(self):
801        inarg_num, outarg_num = None, None
802        all_inp, all_outp = [], []
803        for _, inarg, outarg, ret, _ in self.signatures:
804            outp = re.sub(r'\*.*', '', ret) + outarg
805            if inarg_num is None:
806                inarg_num = len(inarg)
807                outarg_num = len(outp)
808            inp, outp = list(iter_variants(inarg, outp))[0]
809            all_inp.append(inp)
810            all_outp.append(outp)
811
812        incodes = []
813        for n in range(inarg_num):
814            codes = unique([x[n] for x in all_inp])
815            codes.sort()
816            incodes.append(''.join(codes))
817        outcodes = []
818        for n in range(outarg_num):
819            codes = unique([x[n] for x in all_outp])
820            codes.sort()
821            outcodes.append(''.join(codes))
822
823        return tuple(incodes), tuple(outcodes)
824
825    def _get_types(self, codes):
826        all_types = []
827        fused_types = set()
828        for code in codes:
829            if len(code) == 1:
830                # It's not a fused type
831                all_types.append((CY_TYPES[code], code))
832            else:
833                # It's a fused type
834                fused_type, dec = generate_fused_type(code)
835                fused_types.add(dec)
836                all_types.append((fused_type, code))
837        return all_types, fused_types
838
839    def _get_vars(self):
840        invars = ["x{}".format(n) for n in range(len(self.intypes))]
841        outvars = ["y{}".format(n) for n in range(len(self.outtypes))]
842        return invars, outvars
843
844    def _get_conditional(self, types, codes, adverb):
845        """Generate an if/elif/else clause that selects a specialization of
846        fused types.
847
848        """
849        clauses = []
850        seen = set()
851        for (typ, typcode), code in zip(types, codes):
852            if len(typcode) == 1:
853                continue
854            if typ not in seen:
855                clauses.append(f"{typ} is {underscore(CY_TYPES[code])}")
856                seen.add(typ)
857        if clauses and adverb != "else":
858            line = "{} {}:".format(adverb, " and ".join(clauses))
859        elif clauses and adverb == "else":
860            line = "else:"
861        else:
862            line = None
863        return line
864
865    def _get_incallvars(self, intypes, c):
866        """Generate pure input variables to a specialization,
867        i.e., variables that aren't used to return a value.
868
869        """
870        incallvars = []
871        for n, intype in enumerate(intypes):
872            var = self.invars[n]
873            if c and intype == "double complex":
874                var = npy_cdouble_from_double_complex(var)
875            incallvars.append(var)
876        return incallvars
877
878    def _get_outcallvars(self, outtypes, c):
879        """Generate output variables to a specialization,
880        i.e., pointers that are used to return values.
881
882        """
883        outcallvars, tmpvars, casts = [], [], []
884        # If there are more out variables than out types, we want the
885        # tail of the out variables
886        start = len(self.outvars) - len(outtypes)
887        outvars = self.outvars[start:]
888        for n, (var, outtype) in enumerate(zip(outvars, outtypes)):
889            if c and outtype == "double complex":
890                tmp = "tmp{}".format(n)
891                tmpvars.append(tmp)
892                outcallvars.append("&{}".format(tmp))
893                tmpcast = double_complex_from_npy_cdouble(tmp)
894                casts.append("{}[0] = {}".format(var, tmpcast))
895            else:
896                outcallvars.append("{}".format(var))
897        return outcallvars, tmpvars, casts
898
    def _get_nan_decs(self):
        """Generate the trailing ``else:`` branch that fills every output
        with NaN for fused-type specializations that have no signature.

        Outputs whose type is not fused (a single code) get one assignment
        of ``NAN_VALUE[code]``. If any output type is fused, a nested
        if/elif/else chain is emitted with one branch per combination of the
        fused output types' codes; each branch assigns the ``NAN_VALUE``
        entry for the concrete code to every output of that fused type.

        Returns a list of source lines, already indented.

        """
        # Set non fused-type variables to nan
        tab = " "*4
        fused_types, lines = [], [tab + "else:"]
        seen = set()
        for outvar, outtype, code in zip(self.outvars, self.outtypes,
                                         self.outcodes):
            # A single code means this output's type is not fused.
            if len(code) == 1:
                line = "{}[0] = {}".format(outvar, NAN_VALUE[code])
                lines.append(2*tab + line)
            else:
                fused_type = outtype
                name, _ = fused_type
                # Record each distinct fused type once, in first-seen order.
                if name not in seen:
                    fused_types.append(fused_type)
                    seen.add(name)
        if not fused_types:
            return lines

        # Set fused-type variables to nan
        all_codes = tuple([codes for _unused, codes in fused_types])

        codelens = [len(x) for x in all_codes]
        # Index of the final combination, which is emitted as ``else:``.
        last = numpy.prod(codelens) - 1
        for m, codes in enumerate(itertools.product(*all_codes)):
            # NOTE: fused_codes is built but not used below.
            fused_codes, decs = [], []
            for n, fused_type in enumerate(fused_types):
                code = codes[n]
                fused_codes.append(underscore(CY_TYPES[code]))
                # Assign NaN to every output declared with this fused type.
                for nn, outvar in enumerate(self.outvars):
                    if self.outtypes[nn] == fused_type:
                        line = "{}[0] = {}".format(outvar, NAN_VALUE[code])
                        decs.append(line)
            if m == 0:
                adverb = "if"
            elif m == last:
                adverb = "else"
            else:
                adverb = "elif"
            cond = self._get_conditional(fused_types, codes, adverb)
            lines.append(2*tab + cond)
            lines.extend([3*tab + x for x in decs])
        return lines
946
947    def _get_tmp_decs(self, all_tmpvars):
948        """Generate the declarations of any necessary temporary
949        variables.
950
951        """
952        tab = " "*4
953        tmpvars = list(all_tmpvars)
954        tmpvars.sort()
955        tmpdecs = [tab + "cdef npy_cdouble {}".format(tmpvar)
956                   for tmpvar in tmpvars]
957        return tmpdecs
958
959    def _get_python_wrap(self):
960        """Generate a Python wrapper for functions which pass their
961        arguments as pointers.
962
963        """
964        tab = " "*4
965        body, callvars = [], []
966        for (intype, _), invar in zip(self.intypes, self.invars):
967            callvars.append("{} {}".format(intype, invar))
968        line = "def _{}_pywrap({}):".format(self.name, ", ".join(callvars))
969        body.append(line)
970        for (outtype, _), outvar in zip(self.outtypes, self.outvars):
971            line = "cdef {} {}".format(outtype, outvar)
972            body.append(tab + line)
973        addr_outvars = [f"&{x}" for x in self.outvars]
974        line = "{}({}, {})".format(self.name, ", ".join(self.invars),
975                                   ", ".join(addr_outvars))
976        body.append(tab + line)
977        line = "return {}".format(", ".join(self.outvars))
978        body.append(tab + line)
979        body = "\n".join(body)
980        return body
981
982    def _get_common(self, signum, sig):
983        """Generate code common to all the _generate_* methods."""
984        tab = " "*4
985        func_name, incodes, outcodes, retcode, header = sig
986        # Convert ints to longs; cf. iter_variants()
987        incodes = incodes.replace('i', 'l')
988        outcodes = outcodes.replace('i', 'l')
989        retcode = retcode.replace('i', 'l')
990
991        if header.endswith("h"):
992            c = True
993        else:
994            c = False
995        if header.endswith("++"):
996            cpp = True
997        else:
998            cpp = False
999
1000        intypes = [CY_TYPES[x] for x in incodes]
1001        outtypes = [CY_TYPES[x] for x in outcodes]
1002        retcode = re.sub(r'\*.*', '', retcode)
1003        if not retcode:
1004            retcode = 'v'
1005        rettype = CY_TYPES[retcode]
1006
1007        if cpp:
1008            # Functions from _ufuncs_cxx are exported as a void*
1009            # pointers; cast them to the correct types
1010            func_name = "scipy.special._ufuncs_cxx._export_{}".format(func_name)
1011            func_name = "(<{}(*)({}) nogil>{})"\
1012                    .format(rettype, ", ".join(intypes + outtypes), func_name)
1013        else:
1014            func_name = self.cython_func_name(func_name, specialized=True)
1015
1016        if signum == 0:
1017            adverb = "if"
1018        else:
1019            adverb = "elif"
1020        cond = self._get_conditional(self.intypes, incodes, adverb)
1021        if cond:
1022            lines = [tab + cond]
1023            sp = 2*tab
1024        else:
1025            lines = []
1026            sp = tab
1027
1028        return func_name, incodes, outcodes, retcode, \
1029            intypes, outtypes, rettype, c, lines, sp
1030
    def _generate_from_return_and_no_outargs(self):
        """Generate a fused function whose kernels return their result
        directly and take no output pointers.

        Emits a ``cpdef`` function returning the first output type. Each
        signature contributes one (possibly conditional) ``return`` of the
        specialized kernel call. With more than one signature, a final
        ``else`` branch returns ``NAN_VALUE`` for unmatched specializations.

        Returns
        -------
        dec : str
            The function declaration, without the trailing colon.
        src : str
            The full Cython source of the function.
        specs : list of str
            ``"<incodes>-><retcode>"``, one per signature.

        """
        tab = " "*4
        specs, body = [], []
        for signum, sig in enumerate(self.signatures):
            func_name, incodes, outcodes, retcode, intypes, outtypes, \
                rettype, c, lines, sp = self._get_common(signum, sig)
            body.extend(lines)

            # Generate the call to the specialized function
            callvars = self._get_incallvars(intypes, c)
            call = "{}({})".format(func_name, ", ".join(callvars))
            # C kernels return npy_cdouble; convert back to double complex.
            if c and rettype == "double complex":
                call = double_complex_from_npy_cdouble(call)
            line = sp + "return {}".format(call)
            body.append(line)
            sig = "{}->{}".format(incodes, retcode)
            specs.append(sig)

        if len(specs) > 1:
            # Return nan for signatures without a specialization
            body.append(tab + "else:")
            # Codes of the return type; more than one means it is fused.
            outtype, outcodes = self.outtypes[0]
            last = len(outcodes) - 1
            if len(outcodes) == 1:
                line = "return {}".format(NAN_VALUE[outcodes])
                body.append(2*tab + line)
            else:
                # One branch per concrete return type, each returning its NaN.
                for n, code in enumerate(outcodes):
                    if n == 0:
                        adverb = "if"
                    elif n == last:
                        adverb = "else"
                    else:
                        adverb = "elif"
                    cond = self._get_conditional(self.outtypes, code, adverb)
                    body.append(2*tab + cond)
                    line = "return {}".format(NAN_VALUE[code])
                    body.append(3*tab + line)

        # Generate the head of the function
        callvars, head = [], []
        for n, (intype, _) in enumerate(self.intypes):
            callvars.append("{} {}".format(intype, self.invars[n]))
        (outtype, _) = self.outtypes[0]
        dec = "cpdef {} {}({}) nogil".format(outtype, self.name, ", ".join(callvars))
        head.append(dec + ":")
        head.append(tab + '"""{}"""'.format(self.doc))

        src = "\n".join(head + body)
        return dec, src, specs
1081
    def _generate_from_outargs_and_no_return(self):
        """Generate a fused function whose kernels write their results
        through output pointers and return nothing.

        With a single output, emits a ``cpdef`` function that declares the
        output as a local and returns ``y0[0]``. With several outputs, emits
        a ``cdef void`` function taking one pointer per output. For C
        kernels, complex outputs go through ``npy_cdouble`` temporaries that
        are converted back after the call. With more than one signature,
        NaN assignments (``_get_nan_decs``) cover unmatched specializations.

        Returns ``(dec, src, specs)``. Each ``specs`` entry is
        ``"<in>-><out>"`` when the signature has one output and
        ``"<in>*<outs>->v"`` otherwise.

        NOTE(review): in the single-output case ``y0`` is declared as a
        local value (``cdef <type> y0``), yet the kernel call passes ``y0``
        without ``&`` and the function returns ``y0[0]``. Confirm that this
        path is actually exercised and compiles.

        """
        tab = " "*4
        all_tmpvars = set()
        specs, body = [], []
        for signum, sig in enumerate(self.signatures):
            func_name, incodes, outcodes, retcode, intypes, outtypes, \
                rettype, c, lines, sp = self._get_common(signum, sig)
            body.extend(lines)

            # Generate the call to the specialized function
            callvars = self._get_incallvars(intypes, c)
            outcallvars, tmpvars, casts = self._get_outcallvars(outtypes, c)
            callvars.extend(outcallvars)
            all_tmpvars.update(tmpvars)

            call = "{}({})".format(func_name, ", ".join(callvars))
            body.append(sp + call)
            # Copy any npy_cdouble temporaries back into the outputs.
            body.extend([sp + x for x in casts])
            if len(outcodes) == 1:
                sig = "{}->{}".format(incodes, outcodes)
                specs.append(sig)
            else:
                sig = "{}*{}->v".format(incodes, outcodes)
                specs.append(sig)

        if len(specs) > 1:
            lines = self._get_nan_decs()
            body.extend(lines)

        if len(self.outvars) == 1:
            line = "return {}[0]".format(self.outvars[0])
            body.append(tab + line)

        # Generate the head of the function
        callvars, head = [], []
        for invar, (intype, _) in zip(self.invars, self.intypes):
            callvars.append("{} {}".format(intype, invar))
        # Outputs are pointer parameters only when there is more than one.
        if len(self.outvars) > 1:
            for outvar, (outtype, _) in zip(self.outvars, self.outtypes):
                callvars.append("{} *{}".format(outtype, outvar))
        if len(self.outvars) == 1:
            outtype, _ = self.outtypes[0]
            dec = "cpdef {} {}({}) nogil".format(outtype, self.name, ", ".join(callvars))
        else:
            dec = "cdef void {}({}) nogil".format(self.name, ", ".join(callvars))
        head.append(dec + ":")
        head.append(tab + '"""{}"""'.format(self.doc))
        if len(self.outvars) == 1:
            outvar = self.outvars[0]
            outtype, _ = self.outtypes[0]
            line = "cdef {} {}".format(outtype, outvar)
            head.append(tab + line)
        head.extend(self._get_tmp_decs(all_tmpvars))

        src = "\n".join(head + body)
        return dec, src, specs
1138
    def _generate_from_outargs_and_return(self):
        """Generate a fused function whose kernels both return a value and
        write further results through output pointers.

        Emits a ``cdef void`` function taking one pointer per output. The
        kernel's return value is stored in ``self.outvars[0]``, and the
        remaining (trailing) outputs are passed to the kernel as pointers.
        For C kernels, complex outputs go through ``npy_cdouble``
        temporaries. With more than one signature, NaN assignments
        (``_get_nan_decs``) cover unmatched specializations.

        Returns ``(dec, src, specs)``. Each ``specs`` entry is
        ``"<in>*<outs><ret>->v"``.

        """
        tab = " "*4
        all_tmpvars = set()
        specs, body = [], []
        for signum, sig in enumerate(self.signatures):
            func_name, incodes, outcodes, retcode, intypes, outtypes, \
                rettype, c, lines, sp = self._get_common(signum, sig)
            body.extend(lines)

            # Generate the call to the specialized function
            callvars = self._get_incallvars(intypes, c)
            outcallvars, tmpvars, casts = self._get_outcallvars(outtypes, c)
            callvars.extend(outcallvars)
            all_tmpvars.update(tmpvars)
            call = "{}({})".format(func_name, ", ".join(callvars))
            # C kernels return npy_cdouble; convert back to double complex.
            if c and rettype == "double complex":
                call = double_complex_from_npy_cdouble(call)
            # The kernel's return value goes into the first output slot.
            call = "{}[0] = {}".format(self.outvars[0], call)
            body.append(sp + call)
            body.extend([sp + x for x in casts])
            sig = "{}*{}->v".format(incodes, outcodes + retcode)
            specs.append(sig)

        if len(specs) > 1:
            lines = self._get_nan_decs()
            body.extend(lines)

        # Generate the head of the function
        callvars, head = [], []
        for invar, (intype, _) in zip(self.invars, self.intypes):
            callvars.append("{} {}".format(intype, invar))
        for outvar, (outtype, _) in zip(self.outvars, self.outtypes):
            callvars.append("{} *{}".format(outtype, outvar))
        dec = "cdef void {}({}) nogil".format(self.name, ", ".join(callvars))
        head.append(dec + ":")
        head.append(tab + '"""{}"""'.format(self.doc))
        head.extend(self._get_tmp_decs(all_tmpvars))

        src = "\n".join(head + body)
        return dec, src, specs
1179
1180    def generate(self):
1181        _, _, outcodes, retcode, _ = self.signatures[0]
1182        retcode = re.sub(r'\*.*', '', retcode)
1183        if not retcode:
1184            retcode = 'v'
1185
1186        if len(outcodes) == 0 and retcode != 'v':
1187            dec, src, specs = self._generate_from_return_and_no_outargs()
1188        elif len(outcodes) > 0 and retcode == 'v':
1189            dec, src, specs = self._generate_from_outargs_and_no_return()
1190        elif len(outcodes) > 0 and retcode != 'v':
1191            dec, src, specs = self._generate_from_outargs_and_return()
1192        else:
1193            raise ValueError("Invalid signature")
1194
1195        if len(self.outvars) > 1:
1196            wrap = self._get_python_wrap()
1197        else:
1198            wrap = None
1199
1200        return dec, src, specs, self.fused_types, wrap
1201
1202
def get_declaration(ufunc, c_name, c_proto, cy_proto, header,
                    proto_h_filename):
    """
    Build the Cython declaration for one kernel, which comes either from a
    Cython ``.pxd`` or from a C header.

    For a ``.pxd`` source, the kernel is cimported and its address is bound
    to a variable typed with the expected prototype. For a header source,
    the kernel is redeclared through ``proto_h_filename`` with the expected
    prototype. In both cases, Cython checks the kernel's signature against
    the ufunc's expectations at compile time.

    Returns ``(defs, defs_h, var_name)``: the Cython lines, the C prototype
    header lines (empty for ``.pxd`` sources), and ``c_name`` with ``[``,
    ``]`` and spaces replaced by underscores.

    """
    var_name = c_name.translate(str.maketrans("[] ", "___"))

    if not header.endswith('.pxd'):
        # redeclare the function, so that the assumed
        # signature is checked at compile time
        new_name = f'{ufunc.cython_func_name(c_name)} "{c_name}"'
        defs = [
            f'cdef extern from "{proto_h_filename}":',
            "    cdef " + cy_proto.replace('(*)', new_name),
        ]
        defs_h = [
            f'#include "{header}"',
            c_proto.replace('(*)', c_name) + ";",
        ]
        return defs, defs_h, var_name

    module = header[:-4]
    imported = ufunc.cython_func_name(c_name, prefix="")
    local = ufunc.cython_func_name(c_name)
    # check function signature at compile time
    proto_name = f"_proto_{var_name}_t"
    defs = [
        f"from .{module} cimport {imported} as {local}",
        "ctypedef " + cy_proto.replace('(*)', proto_name),
    ]
    specialized = ufunc.cython_func_name(c_name, specialized=True)
    defs.append(f"cdef {proto_name} *{proto_name}_var = &{specialized}")
    return defs, [], var_name
1236
1237
def generate_ufuncs(fn_prefix, cxx_fn_prefix, ufuncs):
    """Write the Cython sources and C headers defining the ufuncs.

    Files written (relative to the current directory):

    - ``<fn_prefix>.pyx``: inner loops, kernel declarations, the ufunc
      creation code and ``__all__``.
    - ``<fn_prefix>_defs.h``: C prototypes of the header-based kernels.
    - ``<cxx_fn_prefix>.pyx`` / ``.pxd`` / ``_defs.h``: the companion module
      for kernels whose header name ends in ``++``. It exports each kernel
      as a ``void*`` named ``_export_<var_name>``, which the main module
      then references through ``ufunc.function_name_overrides``.

    Parameters
    ----------
    fn_prefix : str
        Basename (without extension) of the main ufunc module.
    cxx_fn_prefix : str
        Basename (without extension) of the C++ companion module.
    ufuncs : list of Ufunc
        The ufuncs to generate. Note: sorted in place by name.

    """
    filename = fn_prefix + ".pyx"
    proto_h_filename = fn_prefix + '_defs.h'

    cxx_proto_h_filename = cxx_fn_prefix + '_defs.h'
    cxx_pyx_filename = cxx_fn_prefix + ".pyx"
    cxx_pxd_filename = cxx_fn_prefix + ".pxd"

    # Per-ufunc creation snippets; joined once instead of repeated `+=`.
    toplevel_parts = []

    # for _ufuncs*
    defs = []
    defs_h = []
    all_loops = {}

    # for _ufuncs_cxx*
    cxx_defs = []
    cxx_pxd_defs = [
        "from . cimport sf_error",
        "cdef void _set_action(sf_error.sf_error_t, sf_error.sf_action_t) nogil"
    ]
    cxx_defs_h = []

    ufuncs.sort(key=lambda u: u.name)

    for ufunc in ufuncs:
        # generate function declaration and type checking snippets
        cfuncs = ufunc.get_prototypes()
        for c_name, c_proto, cy_proto, header in cfuncs:
            if header.endswith('++'):
                header = header[:-2]

                # for the CXX module
                item_defs, item_defs_h, var_name = get_declaration(
                    ufunc, c_name, c_proto, cy_proto, header,
                    cxx_proto_h_filename)
                cxx_defs.extend(item_defs)
                cxx_defs_h.extend(item_defs_h)

                cxx_defs.append("cdef void *_export_%s = <void*>%s" % (
                    var_name, ufunc.cython_func_name(c_name, specialized=True, override=False)))
                cxx_pxd_defs.append("cdef void *_export_%s" % (var_name,))

                # let cython grab the function pointer from the c++ shared library
                ufunc.function_name_overrides[c_name] = "scipy.special._ufuncs_cxx._export_" + var_name
            else:
                # usual case
                item_defs, item_defs_h, _ = get_declaration(
                    ufunc, c_name, c_proto, cy_proto, header,
                    proto_h_filename)
                defs.extend(item_defs)
                defs_h.extend(item_defs_h)

        # ufunc creation code snippet
        toplevel_parts.append(ufunc.generate(all_loops) + "\n")

    # Produce output
    toplevel = "".join(toplevel_parts)
    toplevel = "\n".join(sorted(all_loops.values()) + defs + [toplevel])
    # Generate an `__all__` for the module
    all_ufuncs = (
        [
            "'{}'".format(ufunc.name)
            for ufunc in ufuncs if not ufunc.name.startswith('_')
        ]
        + ["'geterr'", "'seterr'", "'errstate'", "'jn'"]
    )
    module_all = '__all__ = [{}]'.format(', '.join(all_ufuncs))

    with open(filename, 'w') as f:
        f.write(UFUNCS_EXTRA_CODE_COMMON)
        f.write(UFUNCS_EXTRA_CODE)
        f.write(module_all)
        f.write("\n")
        f.write(toplevel)
        f.write(UFUNCS_EXTRA_CODE_BOTTOM)

    defs_h = unique(defs_h)
    with open(proto_h_filename, 'w') as f:
        f.write("#ifndef UFUNCS_PROTO_H\n#define UFUNCS_PROTO_H 1\n")
        f.write("\n".join(defs_h))
        f.write("\n#endif\n")

    cxx_defs_h = unique(cxx_defs_h)
    with open(cxx_proto_h_filename, 'w') as f:
        # Use a guard distinct from the main header's, so that both headers
        # can be included in one translation unit without the second one
        # being silently skipped.
        f.write("#ifndef UFUNCS_CXX_PROTO_H\n#define UFUNCS_CXX_PROTO_H 1\n")
        f.write("\n".join(cxx_defs_h))
        f.write("\n#endif\n")

    with open(cxx_pyx_filename, 'w') as f:
        # Cython's cythonize() only honours "# distutils:" directives in the
        # leading comment block of a file, so this must be written first.
        f.write("# distutils: language = c++\n")
        f.write(UFUNCS_EXTRA_CODE_COMMON)
        f.write("\n")
        f.write("\n".join(cxx_defs))
        f.write("\n")

    with open(cxx_pxd_filename, 'w') as f:
        f.write("\n".join(cxx_pxd_defs))
1333
1334
def generate_fused_funcs(modname, ufunc_fn_prefix, fused_funcs):
    """Write ``<modname>.pxd`` and ``<modname>.pyx`` (cython_special).

    The pxd holds ``CYTHON_SPECIAL_PXD``, the sorted fused-type definitions
    and every function declaration. The pyx holds ``CYTHON_SPECIAL_PYX``
    (with ``FUNCLIST`` replaced by the per-function docs), the kernel
    declarations, the function sources and the benchmark helpers. Functions
    whose name starts with ``_`` are skipped. C++ kernels are not
    redeclared here; they are taken from the C++ ufunc module.

    """
    pxdfile = f"{modname}.pxd"
    pyxfile = f"{modname}.pyx"
    proto_h_filename = f"{ufunc_fn_prefix}_defs.h"

    declarations, sources, defs = [], [], []
    # Parameters for the tests, and code for benchmarks
    doc, bench_aux = [], []
    collected_fused_types = set()

    for func in fused_funcs:
        # Functions with extra layers of wrappers are not handled.
        if func.name.startswith("_"):
            continue

        # Declaration for the .pxd and source code for the .pyx
        dec, src, specs, func_fused_types, wrap = func.generate()
        declarations.append(dec)
        sources.append(src)
        if wrap:
            sources.append(wrap)
        collected_fused_types.update(func_fused_types)

        # Declare the specializations (C++ ones come from the C++ module)
        for c_name, c_proto, cy_proto, header in func.get_prototypes(nptypes_for_h=True):
            if header.endswith('++'):
                continue
            item_defs, _, _ = get_declaration(func, c_name, c_proto, cy_proto,
                                              header, proto_h_filename)
            defs.extend(item_defs)

        doc.append(generate_doc(func.name, specs))

        for codes in CYTHON_SPECIAL_BENCHFUNCS.get(func.name, ()):
            pybench, cybench = generate_bench(func.name, codes)
            bench_aux.extend([pybench, cybench])

    with open(pxdfile, 'w') as f:
        f.write(CYTHON_SPECIAL_PXD)
        f.write("\n")
        f.write("\n\n".join(sorted(collected_fused_types)))
        f.write("\n\n")
        f.write("\n".join(declarations))

    with open(pyxfile, 'w') as f:
        f.write(CYTHON_SPECIAL_PYX.replace("FUNCLIST", "\n".join(doc)))
        f.write("\n")
        f.write("\n".join(defs))
        f.write("\n\n")
        f.write("\n\n".join(sources))
        f.write("\n\n")
        f.write("\n\n".join(bench_aux))
1403
1404
def generate_ufuncs_type_stubs(module_name: str, ufuncs: List[Ufunc]):
    """Write ``<module_name>.pyi``, which declares every ufunc as ``np.ufunc``.

    ``__all__`` lists the public (non-underscore) ufuncs plus the ``jn``
    alias. Both the stubs and ``__all__`` are sorted.

    """
    names = [ufunc.name for ufunc in ufuncs]
    # jn is an alias for jv.
    stubs = sorted([f'{name}: np.ufunc' for name in names] + ['jn: np.ufunc'])
    module_all = sorted(
        [f"'{name}'" for name in names if not name.startswith('_')] + ["'jn'"]
    )

    contents = STUBS.format(
        ALL=',\n    '.join(module_all),
        STUBS='\n'.join(stubs),
    )
    with open(f'{module_name}.pyi', 'w') as f:
        f.write(contents)
1425
1426
def unique(lst):
    """
    Return the items of ``lst`` without repeats, keeping the first
    occurrence of each and preserving the original order. The items must
    be hashable.
    """
    # Dict keys keep insertion order, so fromkeys acts as an ordered set.
    return list(dict.fromkeys(lst))
1440
1441
def all_newer(src_files, dst_files):
    """Return True if every destination file exists and is newer than every
    existing source file.

    A source file that does not exist is ignored, matching the previous
    ``distutils.dep_util.newer`` based check. This avoids ``distutils``,
    which was removed from the standard library in Python 3.12 (PEP 632).

    Parameters
    ----------
    src_files : iterable of str
        Input paths the generated files depend on.
    dst_files : iterable of str
        Generated output paths.

    """
    src_mtimes = [os.stat(src).st_mtime for src in src_files
                  if os.path.exists(src)]
    for dst in dst_files:
        if not os.path.exists(dst):
            return False
        dst_mtime = os.stat(dst).st_mtime
        # The destination must be strictly newer than each source.
        if any(src_mtime >= dst_mtime for src_mtime in src_mtimes):
            return False
    return True
1446
1447
def main():
    """Command-line entry point: regenerate the scipy.special sources.

    Takes no positional arguments. Does nothing when every generated file
    is already newer than this script, ``functions.json`` and
    ``add_newdocs.py``.

    """
    parser = optparse.OptionParser(usage=(__doc__ or '').strip())
    _options, positional = parser.parse_args()
    if positional:
        parser.error('invalid number of arguments')

    here = os.path.dirname(__file__)
    src_files = tuple(
        os.path.abspath(path)
        for path in (__file__,
                     os.path.join(here, 'functions.json'),
                     os.path.join(here, 'add_newdocs.py'))
    )
    dst_files = (
        '_ufuncs.pyx',
        '_ufuncs_defs.h',
        '_ufuncs_cxx.pyx',
        '_ufuncs_cxx.pxd',
        '_ufuncs_cxx_defs.h',
        '_ufuncs.pyi',
        'cython_special.pyx',
        'cython_special.pxd',
    )

    os.chdir(BASE_DIR)

    if all_newer(src_files, dst_files):
        print("scipy/special/_generate_pyx.py: all files up-to-date")
        return

    with open('functions.json') as data:
        functions = json.load(data)

    ufuncs, fused_funcs = [], []
    for name, signatures in functions.items():
        ufuncs.append(Ufunc(name, signatures))
        fused_funcs.append(FusedFunc(name, signatures))

    generate_ufuncs("_ufuncs", "_ufuncs_cxx", ufuncs)
    generate_ufuncs_type_stubs("_ufuncs", ufuncs)
    generate_fused_funcs("cython_special", "_ufuncs", fused_funcs)


if __name__ == "__main__":
    main()
1486