1""" 2python _generate_pyx.py 3 4Generate Ufunc definition source files for scipy.special. Produces 5files '_ufuncs.c' and '_ufuncs_cxx.c' by first producing Cython. 6 7This will generate both calls to PyUFunc_FromFuncAndData and the 8required ufunc inner loops. 9 10The function signatures are contained in 'functions.json', the syntax 11for a function signature is 12 13 <function>: <name> ':' <input> '*' <output> 14 '->' <retval> '*' <ignored_retval> 15 <input>: <typecode>* 16 <output>: <typecode>* 17 <retval>: <typecode>? 18 <ignored_retval>: <typecode>? 19 <headers>: <header_name> [',' <header_name>]* 20 21The input parameter types are denoted by single-character type 22codes, according to 23 24 'f': 'float' 25 'd': 'double' 26 'g': 'long double' 27 'F': 'float complex' 28 'D': 'double complex' 29 'G': 'long double complex' 30 'i': 'int' 31 'l': 'long' 32 'v': 'void' 33 34If multiple kernel functions are given for a single ufunc, the one 35which is used is determined by the standard ufunc mechanism. Kernel 36functions that are listed first are also matched first against the 37ufunc input types, so functions listed earlier take precedence. 38 39In addition, versions with cast variables, such as d->f, D->F and 40i->d are automatically generated. 41 42There should be either a single header that contains all of the kernel 43functions listed, or there should be one header for each kernel 44function. Cython pxd files are allowed in addition to .h files. 45 46Cython functions may use fused types, but the names in the list 47should be the specialized ones, such as 'somefunc[float]'. 48 49Functions coming from C++ should have ``++`` appended to the name of 50the header. 51 52Floating-point exceptions inside these Ufuncs are converted to 53special function errors --- which are separately controlled by the 54user, and off by default, as they are usually not especially useful 55for the user. 
56 57 58The C++ module 59-------------- 60In addition to ``_ufuncs`` module, a second module ``_ufuncs_cxx`` is 61generated. This module only exports function pointers that are to be 62used when constructing some of the ufuncs in ``_ufuncs``. The function 63pointers are exported via Cython's standard mechanism. 64 65This mainly avoids build issues --- Python distutils has no way to 66figure out what to do if you want to link both C++ and Fortran code in 67the same shared library. 68 69""" 70 71# ----------------------------------------------------------------------------- 72# Extra code 73# ----------------------------------------------------------------------------- 74 75UFUNCS_EXTRA_CODE_COMMON = """\ 76# This file is automatically generated by _generate_pyx.py. 77# Do not edit manually! 78 79include "_ufuncs_extra_code_common.pxi" 80""" 81 82UFUNCS_EXTRA_CODE = """\ 83include "_ufuncs_extra_code.pxi" 84""" 85 86UFUNCS_EXTRA_CODE_BOTTOM = """\ 87# 88# Aliases 89# 90jn = jv 91""" 92 93CYTHON_SPECIAL_PXD = """\ 94# This file is automatically generated by _generate_pyx.py. 95# Do not edit manually! 96 97ctypedef fused number_t: 98 double complex 99 double 100 101cpdef number_t spherical_jn(long n, number_t z, bint derivative=*) nogil 102cpdef number_t spherical_yn(long n, number_t z, bint derivative=*) nogil 103cpdef number_t spherical_in(long n, number_t z, bint derivative=*) nogil 104cpdef number_t spherical_kn(long n, number_t z, bint derivative=*) nogil 105""" 106 107CYTHON_SPECIAL_PYX = """\ 108# This file is automatically generated by _generate_pyx.py. 109# Do not edit manually! 110\"\"\" 111.. highlight:: cython 112 113Cython API for special functions 114================================ 115 116Scalar, typed versions of many of the functions in ``scipy.special`` 117can be accessed directly from Cython; the complete list is given 118below. Functions are overloaded using Cython fused types so their 119names match their Python counterpart. 
The module follows the following 120conventions: 121 122- If a function's Python counterpart returns multiple values, then the 123 function returns its outputs via pointers in the final arguments. 124- If a function's Python counterpart returns a single value, then the 125 function's output is returned directly. 126 127The module is usable from Cython via:: 128 129 cimport scipy.special.cython_special 130 131Error handling 132-------------- 133 134Functions can indicate an error by returning ``nan``; however they 135cannot emit warnings like their counterparts in ``scipy.special``. 136 137Available functions 138------------------- 139 140FUNCLIST 141 142Custom functions 143---------------- 144 145Some functions in ``scipy.special`` which are not ufuncs have custom 146Cython wrappers. 147 148Spherical Bessel functions 149~~~~~~~~~~~~~~~~~~~~~~~~~~ 150 151The optional ``derivative`` boolean argument is replaced with an 152optional Cython ``bint``, leading to the following signatures. 153 154- :py:func:`~scipy.special.spherical_jn`:: 155 156 double complex spherical_jn(long, double complex) 157 double complex spherical_jn(long, double complex, bint) 158 double spherical_jn(long, double) 159 double spherical_jn(long, double, bint) 160 161- :py:func:`~scipy.special.spherical_yn`:: 162 163 double complex spherical_yn(long, double complex) 164 double complex spherical_yn(long, double complex, bint) 165 double spherical_yn(long, double) 166 double spherical_yn(long, double, bint) 167 168- :py:func:`~scipy.special.spherical_in`:: 169 170 double complex spherical_in(long, double complex) 171 double complex spherical_in(long, double complex, bint) 172 double spherical_in(long, double) 173 double spherical_in(long, double, bint) 174 175- :py:func:`~scipy.special.spherical_kn`:: 176 177 double complex spherical_kn(long, double complex) 178 double complex spherical_kn(long, double complex, bint) 179 double spherical_kn(long, double) 180 double spherical_kn(long, double, bint) 
181 182\"\"\" 183 184include "_cython_special.pxi" 185include "_cython_special_custom.pxi" 186""" 187 188STUBS = """\ 189# This file is automatically generated by _generate_pyx.py. 190# Do not edit manually! 191 192from typing import Any, Dict 193 194import numpy as np 195 196__all__ = [ 197 'geterr', 198 'seterr', 199 'errstate', 200 {ALL} 201] 202 203def geterr() -> Dict[str, str]: ... 204def seterr(**kwargs: str) -> Dict[str, str]: ... 205 206class errstate: 207 def __init__(self, **kargs: str) -> None: ... 208 def __enter__(self) -> None: ... 209 def __exit__( 210 self, 211 exc_type: Any, # Unused 212 exc_value: Any, # Unused 213 traceback: Any, # Unused 214 ) -> None: ... 215 216{STUBS} 217 218""" 219 220 221# ----------------------------------------------------------------------------- 222# Code generation 223# ----------------------------------------------------------------------------- 224 225import itertools 226import json 227import os 228import optparse 229import re 230import textwrap 231from typing import List 232 233import numpy 234 235 236BASE_DIR = os.path.abspath(os.path.dirname(__file__)) 237 238add_newdocs = __import__('add_newdocs') 239 240CY_TYPES = { 241 'f': 'float', 242 'd': 'double', 243 'g': 'long double', 244 'F': 'float complex', 245 'D': 'double complex', 246 'G': 'long double complex', 247 'i': 'int', 248 'l': 'long', 249 'v': 'void', 250} 251 252C_TYPES = { 253 'f': 'npy_float', 254 'd': 'npy_double', 255 'g': 'npy_longdouble', 256 'F': 'npy_cfloat', 257 'D': 'npy_cdouble', 258 'G': 'npy_clongdouble', 259 'i': 'npy_int', 260 'l': 'npy_long', 261 'v': 'void', 262} 263 264TYPE_NAMES = { 265 'f': 'NPY_FLOAT', 266 'd': 'NPY_DOUBLE', 267 'g': 'NPY_LONGDOUBLE', 268 'F': 'NPY_CFLOAT', 269 'D': 'NPY_CDOUBLE', 270 'G': 'NPY_CLONGDOUBLE', 271 'i': 'NPY_INT', 272 'l': 'NPY_LONG', 273} 274 275CYTHON_SPECIAL_BENCHFUNCS = { 276 'airy': ['d*dddd', 'D*DDDD'], 277 'beta': ['dd'], 278 'erf': ['d', 'D'], 279 'exprel': ['d'], 280 'gamma': ['d', 'D'], 281 
'jv': ['dd', 'dD'], 282 'loggamma': ['D'], 283 'logit': ['d'], 284 'psi': ['d', 'D'], 285} 286 287 288def underscore(arg): 289 return arg.replace(" ", "_") 290 291 292def cast_order(c): 293 return ['ilfdgFDG'.index(x) for x in c] 294 295 296# These downcasts will cause the function to return NaNs, unless the 297# values happen to coincide exactly. 298DANGEROUS_DOWNCAST = set([ 299 ('F', 'i'), ('F', 'l'), ('F', 'f'), ('F', 'd'), ('F', 'g'), 300 ('D', 'i'), ('D', 'l'), ('D', 'f'), ('D', 'd'), ('D', 'g'), 301 ('G', 'i'), ('G', 'l'), ('G', 'f'), ('G', 'd'), ('G', 'g'), 302 ('f', 'i'), ('f', 'l'), 303 ('d', 'i'), ('d', 'l'), 304 ('g', 'i'), ('g', 'l'), 305 ('l', 'i'), 306]) 307 308NAN_VALUE = { 309 'f': 'NPY_NAN', 310 'd': 'NPY_NAN', 311 'g': 'NPY_NAN', 312 'F': 'NPY_NAN', 313 'D': 'NPY_NAN', 314 'G': 'NPY_NAN', 315 'i': '0xbad0bad0', 316 'l': '0xbad0bad0', 317} 318 319 320def generate_loop(func_inputs, func_outputs, func_retval, 321 ufunc_inputs, ufunc_outputs): 322 """ 323 Generate a UFunc loop function that calls a function given as its 324 data parameter with the specified input and output arguments and 325 return value. 326 327 This function can be passed to PyUFunc_FromFuncAndData. 328 329 Parameters 330 ---------- 331 func_inputs, func_outputs, func_retval : str 332 Signature of the function to call, given as type codes of the 333 input, output and return value arguments. These 1-character 334 codes are given according to the CY_TYPES and TYPE_NAMES 335 lists above. 336 337 The corresponding C function signature to be called is: 338 339 retval func(intype1 iv1, intype2 iv2, ..., outtype1 *ov1, ...); 340 341 If len(ufunc_outputs) == len(func_outputs)+1, the return value 342 is treated as the first output argument. Otherwise, the return 343 value is ignored. 344 345 ufunc_inputs, ufunc_outputs : str 346 Ufunc input and output signature. 347 348 This does not have to exactly match the function signature, 349 as long as the type casts work out on the C level. 
350 351 Returns 352 ------- 353 loop_name 354 Name of the generated loop function. 355 loop_body 356 Generated C code for the loop. 357 358 """ 359 if len(func_inputs) != len(ufunc_inputs): 360 raise ValueError("Function and ufunc have different number of inputs") 361 362 if len(func_outputs) != len(ufunc_outputs) and not ( 363 func_retval != "v" and len(func_outputs)+1 == len(ufunc_outputs)): 364 raise ValueError("Function retval and ufunc outputs don't match") 365 366 name = "loop_%s_%s_%s_As_%s_%s" % ( 367 func_retval, func_inputs, func_outputs, ufunc_inputs, ufunc_outputs 368 ) 369 body = "cdef void %s(char **args, np.npy_intp *dims, np.npy_intp *steps, void *data) nogil:\n" % name 370 body += " cdef np.npy_intp i, n = dims[0]\n" 371 body += " cdef void *func = (<void**>data)[0]\n" 372 body += " cdef char *func_name = <char*>(<void**>data)[1]\n" 373 374 for j in range(len(ufunc_inputs)): 375 body += " cdef char *ip%d = args[%d]\n" % (j, j) 376 for j in range(len(ufunc_outputs)): 377 body += " cdef char *op%d = args[%d]\n" % (j, j + len(ufunc_inputs)) 378 379 ftypes = [] 380 fvars = [] 381 outtypecodes = [] 382 for j in range(len(func_inputs)): 383 ftypes.append(CY_TYPES[func_inputs[j]]) 384 fvars.append("<%s>(<%s*>ip%d)[0]" % ( 385 CY_TYPES[func_inputs[j]], 386 CY_TYPES[ufunc_inputs[j]], j)) 387 388 if len(func_outputs)+1 == len(ufunc_outputs): 389 func_joff = 1 390 outtypecodes.append(func_retval) 391 body += " cdef %s ov0\n" % (CY_TYPES[func_retval],) 392 else: 393 func_joff = 0 394 395 for j, outtype in enumerate(func_outputs): 396 body += " cdef %s ov%d\n" % (CY_TYPES[outtype], j+func_joff) 397 ftypes.append("%s *" % CY_TYPES[outtype]) 398 fvars.append("&ov%d" % (j+func_joff)) 399 outtypecodes.append(outtype) 400 401 body += " for i in range(n):\n" 402 if len(func_outputs)+1 == len(ufunc_outputs): 403 rv = "ov0 = " 404 else: 405 rv = "" 406 407 funcall = " %s(<%s(*)(%s) nogil>func)(%s)\n" % ( 408 rv, CY_TYPES[func_retval], ", ".join(ftypes), ", 
".join(fvars)) 409 410 # Cast-check inputs and call function 411 input_checks = [] 412 for j in range(len(func_inputs)): 413 if (ufunc_inputs[j], func_inputs[j]) in DANGEROUS_DOWNCAST: 414 chk = "<%s>(<%s*>ip%d)[0] == (<%s*>ip%d)[0]" % ( 415 CY_TYPES[func_inputs[j]], CY_TYPES[ufunc_inputs[j]], j, 416 CY_TYPES[ufunc_inputs[j]], j) 417 input_checks.append(chk) 418 419 if input_checks: 420 body += " if %s:\n" % (" and ".join(input_checks)) 421 body += " " + funcall 422 body += " else:\n" 423 body += " sf_error.error(func_name, sf_error.DOMAIN, \"invalid input argument\")\n" 424 for j, outtype in enumerate(outtypecodes): 425 body += " ov%d = <%s>%s\n" % ( 426 j, CY_TYPES[outtype], NAN_VALUE[outtype]) 427 else: 428 body += funcall 429 430 # Assign and cast-check output values 431 for j, (outtype, fouttype) in enumerate(zip(ufunc_outputs, outtypecodes)): 432 if (fouttype, outtype) in DANGEROUS_DOWNCAST: 433 body += " if ov%d == <%s>ov%d:\n" % (j, CY_TYPES[outtype], j) 434 body += " (<%s *>op%d)[0] = <%s>ov%d\n" % ( 435 CY_TYPES[outtype], j, CY_TYPES[outtype], j) 436 body += " else:\n" 437 body += " sf_error.error(func_name, sf_error.DOMAIN, \"invalid output\")\n" 438 body += " (<%s *>op%d)[0] = <%s>%s\n" % ( 439 CY_TYPES[outtype], j, CY_TYPES[outtype], NAN_VALUE[outtype]) 440 else: 441 body += " (<%s *>op%d)[0] = <%s>ov%d\n" % ( 442 CY_TYPES[outtype], j, CY_TYPES[outtype], j) 443 for j in range(len(ufunc_inputs)): 444 body += " ip%d += steps[%d]\n" % (j, j) 445 for j in range(len(ufunc_outputs)): 446 body += " op%d += steps[%d]\n" % (j, j + len(ufunc_inputs)) 447 448 body += " sf_error.check_fpe(func_name)\n" 449 450 return name, body 451 452 453def generate_fused_type(codes): 454 """ 455 Generate name of and cython code for a fused type. 456 457 Parameters 458 ---------- 459 typecodes : str 460 Valid inputs to CY_TYPES (i.e. f, d, g, ...). 
461 462 """ 463 cytypes = [CY_TYPES[x] for x in codes] 464 name = codes + "_number_t" 465 declaration = ["ctypedef fused " + name + ":"] 466 for cytype in cytypes: 467 declaration.append(" " + cytype) 468 declaration = "\n".join(declaration) 469 return name, declaration 470 471 472def generate_bench(name, codes): 473 tab = " "*4 474 top, middle, end = [], [], [] 475 476 tmp = codes.split("*") 477 if len(tmp) > 1: 478 incodes = tmp[0] 479 outcodes = tmp[1] 480 else: 481 incodes = tmp[0] 482 outcodes = "" 483 484 inargs, inargs_and_types = [], [] 485 for n, code in enumerate(incodes): 486 arg = "x{}".format(n) 487 inargs.append(arg) 488 inargs_and_types.append("{} {}".format(CY_TYPES[code], arg)) 489 line = "def {{}}(int N, {}):".format(", ".join(inargs_and_types)) 490 top.append(line) 491 top.append(tab + "cdef int n") 492 493 outargs = [] 494 for n, code in enumerate(outcodes): 495 arg = "y{}".format(n) 496 outargs.append("&{}".format(arg)) 497 line = "cdef {} {}".format(CY_TYPES[code], arg) 498 middle.append(tab + line) 499 500 end.append(tab + "for n in range(N):") 501 end.append(2*tab + "{}({})") 502 pyfunc = "_bench_{}_{}_{}".format(name, incodes, "py") 503 cyfunc = "_bench_{}_{}_{}".format(name, incodes, "cy") 504 pytemplate = "\n".join(top + end) 505 cytemplate = "\n".join(top + middle + end) 506 pybench = pytemplate.format(pyfunc, "_ufuncs." 
+ name, ", ".join(inargs)) 507 cybench = cytemplate.format(cyfunc, name, ", ".join(inargs + outargs)) 508 return pybench, cybench 509 510 511def generate_doc(name, specs): 512 tab = " "*4 513 doc = ["- :py:func:`~scipy.special.{}`::\n".format(name)] 514 for spec in specs: 515 incodes, outcodes = spec.split("->") 516 incodes = incodes.split("*") 517 intypes = [CY_TYPES[x] for x in incodes[0]] 518 if len(incodes) > 1: 519 types = [f"{CY_TYPES[x]} *" for x in incodes[1]] 520 intypes.extend(types) 521 outtype = CY_TYPES[outcodes] 522 line = "{} {}({})".format(outtype, name, ", ".join(intypes)) 523 doc.append(2*tab + line) 524 doc[-1] = "{}\n".format(doc[-1]) 525 doc = "\n".join(doc) 526 return doc 527 528 529def npy_cdouble_from_double_complex(var): 530 """Cast a Cython double complex to a NumPy cdouble.""" 531 res = "_complexstuff.npy_cdouble_from_double_complex({})".format(var) 532 return res 533 534 535def double_complex_from_npy_cdouble(var): 536 """Cast a NumPy cdouble to a Cython double complex.""" 537 res = "_complexstuff.double_complex_from_npy_cdouble({})".format(var) 538 return res 539 540 541def iter_variants(inputs, outputs): 542 """ 543 Generate variants of UFunc signatures, by changing variable types, 544 within the limitation that the corresponding C types casts still 545 work out. 546 547 This does not generate all possibilities, just the ones required 548 for the ufunc to work properly with the most common data types. 549 550 Parameters 551 ---------- 552 inputs, outputs : str 553 UFunc input and output signature strings 554 555 Yields 556 ------ 557 new_input, new_output : str 558 Modified input and output strings. 559 Also the original input/output pair is yielded. 
560 561 """ 562 maps = [ 563 # always use long instead of int (more common type on 64-bit) 564 ('i', 'l'), 565 ] 566 567 # float32-preserving signatures 568 if not ('i' in inputs or 'l' in inputs): 569 # Don't add float32 versions of ufuncs with integer arguments, as this 570 # can lead to incorrect dtype selection if the integer arguments are 571 # arrays, but float arguments are scalars. 572 # For instance sph_harm(0,[0],0,0).dtype == complex64 573 # This may be a NumPy bug, but we need to work around it. 574 # cf. gh-4895, https://github.com/numpy/numpy/issues/5895 575 maps = maps + [(a + 'dD', b + 'fF') for a, b in maps] 576 577 # do the replacements 578 for src, dst in maps: 579 new_inputs = inputs 580 new_outputs = outputs 581 for a, b in zip(src, dst): 582 new_inputs = new_inputs.replace(a, b) 583 new_outputs = new_outputs.replace(a, b) 584 yield new_inputs, new_outputs 585 586 587class Func: 588 """ 589 Base class for Ufunc and FusedFunc. 590 591 """ 592 def __init__(self, name, signatures): 593 self.name = name 594 self.signatures = [] 595 self.function_name_overrides = {} 596 597 for header in signatures.keys(): 598 for name, sig in signatures[header].items(): 599 inarg, outarg, ret = self._parse_signature(sig) 600 self.signatures.append((name, inarg, outarg, ret, header)) 601 602 def _parse_signature(self, sig): 603 m = re.match(r"\s*([fdgFDGil]*)\s*\*\s*([fdgFDGil]*)\s*->\s*([*fdgFDGil]*)\s*$", sig) 604 if m: 605 inarg, outarg, ret = [x.strip() for x in m.groups()] 606 if ret.count('*') > 1: 607 raise ValueError("{}: Invalid signature: {}".format(self.name, sig)) 608 return inarg, outarg, ret 609 m = re.match(r"\s*([fdgFDGil]*)\s*->\s*([fdgFDGil]?)\s*$", sig) 610 if m: 611 inarg, ret = [x.strip() for x in m.groups()] 612 return inarg, "", ret 613 raise ValueError("{}: Invalid signature: {}".format(self.name, sig)) 614 615 def get_prototypes(self, nptypes_for_h=False): 616 prototypes = [] 617 for func_name, inarg, outarg, ret, header in self.signatures: 
618 ret = ret.replace('*', '') 619 c_args = ([C_TYPES[x] for x in inarg] 620 + [C_TYPES[x] + ' *' for x in outarg]) 621 cy_args = ([CY_TYPES[x] for x in inarg] 622 + [CY_TYPES[x] + ' *' for x in outarg]) 623 c_proto = "%s (*)(%s)" % (C_TYPES[ret], ", ".join(c_args)) 624 if header.endswith("h") and nptypes_for_h: 625 cy_proto = c_proto + "nogil" 626 else: 627 cy_proto = "%s (*)(%s) nogil" % (CY_TYPES[ret], ", ".join(cy_args)) 628 prototypes.append((func_name, c_proto, cy_proto, header)) 629 return prototypes 630 631 def cython_func_name(self, c_name, specialized=False, prefix="_func_", 632 override=True): 633 # act on function name overrides 634 if override and c_name in self.function_name_overrides: 635 c_name = self.function_name_overrides[c_name] 636 prefix = "" 637 638 # support fused types 639 m = re.match(r'^(.*?)(\[.*\])$', c_name) 640 if m: 641 c_base_name, fused_part = m.groups() 642 else: 643 c_base_name, fused_part = c_name, "" 644 if specialized: 645 return "%s%s%s" % (prefix, c_base_name, fused_part.replace(' ', '_')) 646 else: 647 return "%s%s" % (prefix, c_base_name,) 648 649 650class Ufunc(Func): 651 """ 652 Ufunc signature, restricted format suitable for special functions. 653 654 Parameters 655 ---------- 656 name 657 Name of the ufunc to create 658 signature 659 String of form 'func: fff*ff->f, func2: ddd->*i' describing 660 the C-level functions and types of their input arguments 661 and return values. 
662 663 The syntax is 'function_name: inputparams*outputparams->output_retval*ignored_retval' 664 665 Attributes 666 ---------- 667 name : str 668 Python name for the Ufunc 669 signatures : list of (func_name, inarg_spec, outarg_spec, ret_spec, header_name) 670 List of parsed signatures 671 doc : str 672 Docstring, obtained from add_newdocs 673 function_name_overrides : dict of str->str 674 Overrides for the function names in signatures 675 676 """ 677 def __init__(self, name, signatures): 678 super().__init__(name, signatures) 679 self.doc = add_newdocs.get(name) 680 if self.doc is None: 681 raise ValueError("No docstring for ufunc %r" % name) 682 self.doc = textwrap.dedent(self.doc).strip() 683 684 def _get_signatures_and_loops(self, all_loops): 685 inarg_num = None 686 outarg_num = None 687 688 seen = set() 689 variants = [] 690 691 def add_variant(func_name, inarg, outarg, ret, inp, outp): 692 if inp in seen: 693 return 694 seen.add(inp) 695 696 sig = (func_name, inp, outp) 697 if "v" in outp: 698 raise ValueError("%s: void signature %r" % (self.name, sig)) 699 if len(inp) != inarg_num or len(outp) != outarg_num: 700 raise ValueError("%s: signature %r does not have %d/%d input/output args" % ( 701 self.name, sig, 702 inarg_num, outarg_num)) 703 704 loop_name, loop = generate_loop(inarg, outarg, ret, inp, outp) 705 all_loops[loop_name] = loop 706 variants.append((func_name, loop_name, inp, outp)) 707 708 # First add base variants 709 for func_name, inarg, outarg, ret, header in self.signatures: 710 outp = re.sub(r'\*.*', '', ret) + outarg 711 ret = ret.replace('*', '') 712 if inarg_num is None: 713 inarg_num = len(inarg) 714 outarg_num = len(outp) 715 716 inp, outp = list(iter_variants(inarg, outp))[0] 717 add_variant(func_name, inarg, outarg, ret, inp, outp) 718 719 # Then the supplementary ones 720 for func_name, inarg, outarg, ret, header in self.signatures: 721 outp = re.sub(r'\*.*', '', ret) + outarg 722 ret = ret.replace('*', '') 723 for inp, outp in 
iter_variants(inarg, outp): 724 add_variant(func_name, inarg, outarg, ret, inp, outp) 725 726 # Then sort variants to input argument cast order 727 # -- the sort is stable, so functions earlier in the signature list 728 # are still preferred 729 variants.sort(key=lambda v: cast_order(v[2])) 730 731 return variants, inarg_num, outarg_num 732 733 def generate(self, all_loops): 734 toplevel = "" 735 736 variants, inarg_num, outarg_num = self._get_signatures_and_loops( 737 all_loops) 738 739 loops = [] 740 funcs = [] 741 types = [] 742 743 for func_name, loop_name, inputs, outputs in variants: 744 for x in inputs: 745 types.append(TYPE_NAMES[x]) 746 for x in outputs: 747 types.append(TYPE_NAMES[x]) 748 loops.append(loop_name) 749 funcs.append(func_name) 750 751 toplevel += "cdef np.PyUFuncGenericFunction ufunc_%s_loops[%d]\n" % (self.name, len(loops)) 752 toplevel += "cdef void *ufunc_%s_ptr[%d]\n" % (self.name, 2*len(funcs)) 753 toplevel += "cdef void *ufunc_%s_data[%d]\n" % (self.name, len(funcs)) 754 toplevel += "cdef char ufunc_%s_types[%d]\n" % (self.name, len(types)) 755 toplevel += 'cdef char *ufunc_%s_doc = (\n "%s")\n' % ( 756 self.name, 757 self.doc.replace("\\", "\\\\").replace('"', '\\"').replace('\n', '\\n\"\n "') 758 ) 759 760 for j, function in enumerate(loops): 761 toplevel += "ufunc_%s_loops[%d] = <np.PyUFuncGenericFunction>%s\n" % (self.name, j, function) 762 for j, type in enumerate(types): 763 toplevel += "ufunc_%s_types[%d] = <char>%s\n" % (self.name, j, type) 764 for j, func in enumerate(funcs): 765 toplevel += "ufunc_%s_ptr[2*%d] = <void*>%s\n" % (self.name, j, 766 self.cython_func_name(func, specialized=True)) 767 toplevel += "ufunc_%s_ptr[2*%d+1] = <void*>(<char*>\"%s\")\n" % (self.name, j, 768 self.name) 769 for j, func in enumerate(funcs): 770 toplevel += "ufunc_%s_data[%d] = &ufunc_%s_ptr[2*%d]\n" % ( 771 self.name, j, self.name, j) 772 773 toplevel += ('@ = np.PyUFunc_FromFuncAndData(ufunc_@_loops, ' 774 'ufunc_@_data, ufunc_@_types, %d, 
%d, %d, 0, ' 775 '"@", ufunc_@_doc, 0)\n' % (len(types)/(inarg_num+outarg_num), 776 inarg_num, outarg_num) 777 ).replace('@', self.name) 778 779 return toplevel 780 781 782class FusedFunc(Func): 783 """ 784 Generate code for a fused-type special function that can be 785 cimported in Cython. 786 787 """ 788 def __init__(self, name, signatures): 789 super().__init__(name, signatures) 790 self.doc = "See the documentation for scipy.special." + self.name 791 # "codes" are the keys for CY_TYPES 792 self.incodes, self.outcodes = self._get_codes() 793 self.fused_types = set() 794 self.intypes, infused_types = self._get_types(self.incodes) 795 self.fused_types.update(infused_types) 796 self.outtypes, outfused_types = self._get_types(self.outcodes) 797 self.fused_types.update(outfused_types) 798 self.invars, self.outvars = self._get_vars() 799 800 def _get_codes(self): 801 inarg_num, outarg_num = None, None 802 all_inp, all_outp = [], [] 803 for _, inarg, outarg, ret, _ in self.signatures: 804 outp = re.sub(r'\*.*', '', ret) + outarg 805 if inarg_num is None: 806 inarg_num = len(inarg) 807 outarg_num = len(outp) 808 inp, outp = list(iter_variants(inarg, outp))[0] 809 all_inp.append(inp) 810 all_outp.append(outp) 811 812 incodes = [] 813 for n in range(inarg_num): 814 codes = unique([x[n] for x in all_inp]) 815 codes.sort() 816 incodes.append(''.join(codes)) 817 outcodes = [] 818 for n in range(outarg_num): 819 codes = unique([x[n] for x in all_outp]) 820 codes.sort() 821 outcodes.append(''.join(codes)) 822 823 return tuple(incodes), tuple(outcodes) 824 825 def _get_types(self, codes): 826 all_types = [] 827 fused_types = set() 828 for code in codes: 829 if len(code) == 1: 830 # It's not a fused type 831 all_types.append((CY_TYPES[code], code)) 832 else: 833 # It's a fused type 834 fused_type, dec = generate_fused_type(code) 835 fused_types.add(dec) 836 all_types.append((fused_type, code)) 837 return all_types, fused_types 838 839 def _get_vars(self): 840 invars = 
["x{}".format(n) for n in range(len(self.intypes))] 841 outvars = ["y{}".format(n) for n in range(len(self.outtypes))] 842 return invars, outvars 843 844 def _get_conditional(self, types, codes, adverb): 845 """Generate an if/elif/else clause that selects a specialization of 846 fused types. 847 848 """ 849 clauses = [] 850 seen = set() 851 for (typ, typcode), code in zip(types, codes): 852 if len(typcode) == 1: 853 continue 854 if typ not in seen: 855 clauses.append(f"{typ} is {underscore(CY_TYPES[code])}") 856 seen.add(typ) 857 if clauses and adverb != "else": 858 line = "{} {}:".format(adverb, " and ".join(clauses)) 859 elif clauses and adverb == "else": 860 line = "else:" 861 else: 862 line = None 863 return line 864 865 def _get_incallvars(self, intypes, c): 866 """Generate pure input variables to a specialization, 867 i.e., variables that aren't used to return a value. 868 869 """ 870 incallvars = [] 871 for n, intype in enumerate(intypes): 872 var = self.invars[n] 873 if c and intype == "double complex": 874 var = npy_cdouble_from_double_complex(var) 875 incallvars.append(var) 876 return incallvars 877 878 def _get_outcallvars(self, outtypes, c): 879 """Generate output variables to a specialization, 880 i.e., pointers that are used to return values. 
881 882 """ 883 outcallvars, tmpvars, casts = [], [], [] 884 # If there are more out variables than out types, we want the 885 # tail of the out variables 886 start = len(self.outvars) - len(outtypes) 887 outvars = self.outvars[start:] 888 for n, (var, outtype) in enumerate(zip(outvars, outtypes)): 889 if c and outtype == "double complex": 890 tmp = "tmp{}".format(n) 891 tmpvars.append(tmp) 892 outcallvars.append("&{}".format(tmp)) 893 tmpcast = double_complex_from_npy_cdouble(tmp) 894 casts.append("{}[0] = {}".format(var, tmpcast)) 895 else: 896 outcallvars.append("{}".format(var)) 897 return outcallvars, tmpvars, casts 898 899 def _get_nan_decs(self): 900 """Set all variables to nan for specializations of fused types for 901 which don't have signatures. 902 903 """ 904 # Set non fused-type variables to nan 905 tab = " "*4 906 fused_types, lines = [], [tab + "else:"] 907 seen = set() 908 for outvar, outtype, code in zip(self.outvars, self.outtypes, 909 self.outcodes): 910 if len(code) == 1: 911 line = "{}[0] = {}".format(outvar, NAN_VALUE[code]) 912 lines.append(2*tab + line) 913 else: 914 fused_type = outtype 915 name, _ = fused_type 916 if name not in seen: 917 fused_types.append(fused_type) 918 seen.add(name) 919 if not fused_types: 920 return lines 921 922 # Set fused-type variables to nan 923 all_codes = tuple([codes for _unused, codes in fused_types]) 924 925 codelens = [len(x) for x in all_codes] 926 last = numpy.prod(codelens) - 1 927 for m, codes in enumerate(itertools.product(*all_codes)): 928 fused_codes, decs = [], [] 929 for n, fused_type in enumerate(fused_types): 930 code = codes[n] 931 fused_codes.append(underscore(CY_TYPES[code])) 932 for nn, outvar in enumerate(self.outvars): 933 if self.outtypes[nn] == fused_type: 934 line = "{}[0] = {}".format(outvar, NAN_VALUE[code]) 935 decs.append(line) 936 if m == 0: 937 adverb = "if" 938 elif m == last: 939 adverb = "else" 940 else: 941 adverb = "elif" 942 cond = self._get_conditional(fused_types, codes, 
adverb) 943 lines.append(2*tab + cond) 944 lines.extend([3*tab + x for x in decs]) 945 return lines 946 947 def _get_tmp_decs(self, all_tmpvars): 948 """Generate the declarations of any necessary temporary 949 variables. 950 951 """ 952 tab = " "*4 953 tmpvars = list(all_tmpvars) 954 tmpvars.sort() 955 tmpdecs = [tab + "cdef npy_cdouble {}".format(tmpvar) 956 for tmpvar in tmpvars] 957 return tmpdecs 958 959 def _get_python_wrap(self): 960 """Generate a Python wrapper for functions which pass their 961 arguments as pointers. 962 963 """ 964 tab = " "*4 965 body, callvars = [], [] 966 for (intype, _), invar in zip(self.intypes, self.invars): 967 callvars.append("{} {}".format(intype, invar)) 968 line = "def _{}_pywrap({}):".format(self.name, ", ".join(callvars)) 969 body.append(line) 970 for (outtype, _), outvar in zip(self.outtypes, self.outvars): 971 line = "cdef {} {}".format(outtype, outvar) 972 body.append(tab + line) 973 addr_outvars = [f"&{x}" for x in self.outvars] 974 line = "{}({}, {})".format(self.name, ", ".join(self.invars), 975 ", ".join(addr_outvars)) 976 body.append(tab + line) 977 line = "return {}".format(", ".join(self.outvars)) 978 body.append(tab + line) 979 body = "\n".join(body) 980 return body 981 982 def _get_common(self, signum, sig): 983 """Generate code common to all the _generate_* methods.""" 984 tab = " "*4 985 func_name, incodes, outcodes, retcode, header = sig 986 # Convert ints to longs; cf. 
iter_variants() 987 incodes = incodes.replace('i', 'l') 988 outcodes = outcodes.replace('i', 'l') 989 retcode = retcode.replace('i', 'l') 990 991 if header.endswith("h"): 992 c = True 993 else: 994 c = False 995 if header.endswith("++"): 996 cpp = True 997 else: 998 cpp = False 999 1000 intypes = [CY_TYPES[x] for x in incodes] 1001 outtypes = [CY_TYPES[x] for x in outcodes] 1002 retcode = re.sub(r'\*.*', '', retcode) 1003 if not retcode: 1004 retcode = 'v' 1005 rettype = CY_TYPES[retcode] 1006 1007 if cpp: 1008 # Functions from _ufuncs_cxx are exported as a void* 1009 # pointers; cast them to the correct types 1010 func_name = "scipy.special._ufuncs_cxx._export_{}".format(func_name) 1011 func_name = "(<{}(*)({}) nogil>{})"\ 1012 .format(rettype, ", ".join(intypes + outtypes), func_name) 1013 else: 1014 func_name = self.cython_func_name(func_name, specialized=True) 1015 1016 if signum == 0: 1017 adverb = "if" 1018 else: 1019 adverb = "elif" 1020 cond = self._get_conditional(self.intypes, incodes, adverb) 1021 if cond: 1022 lines = [tab + cond] 1023 sp = 2*tab 1024 else: 1025 lines = [] 1026 sp = tab 1027 1028 return func_name, incodes, outcodes, retcode, \ 1029 intypes, outtypes, rettype, c, lines, sp 1030 1031 def _generate_from_return_and_no_outargs(self): 1032 tab = " "*4 1033 specs, body = [], [] 1034 for signum, sig in enumerate(self.signatures): 1035 func_name, incodes, outcodes, retcode, intypes, outtypes, \ 1036 rettype, c, lines, sp = self._get_common(signum, sig) 1037 body.extend(lines) 1038 1039 # Generate the call to the specialized function 1040 callvars = self._get_incallvars(intypes, c) 1041 call = "{}({})".format(func_name, ", ".join(callvars)) 1042 if c and rettype == "double complex": 1043 call = double_complex_from_npy_cdouble(call) 1044 line = sp + "return {}".format(call) 1045 body.append(line) 1046 sig = "{}->{}".format(incodes, retcode) 1047 specs.append(sig) 1048 1049 if len(specs) > 1: 1050 # Return nan for signatures without a 
specialization 1051 body.append(tab + "else:") 1052 outtype, outcodes = self.outtypes[0] 1053 last = len(outcodes) - 1 1054 if len(outcodes) == 1: 1055 line = "return {}".format(NAN_VALUE[outcodes]) 1056 body.append(2*tab + line) 1057 else: 1058 for n, code in enumerate(outcodes): 1059 if n == 0: 1060 adverb = "if" 1061 elif n == last: 1062 adverb = "else" 1063 else: 1064 adverb = "elif" 1065 cond = self._get_conditional(self.outtypes, code, adverb) 1066 body.append(2*tab + cond) 1067 line = "return {}".format(NAN_VALUE[code]) 1068 body.append(3*tab + line) 1069 1070 # Generate the head of the function 1071 callvars, head = [], [] 1072 for n, (intype, _) in enumerate(self.intypes): 1073 callvars.append("{} {}".format(intype, self.invars[n])) 1074 (outtype, _) = self.outtypes[0] 1075 dec = "cpdef {} {}({}) nogil".format(outtype, self.name, ", ".join(callvars)) 1076 head.append(dec + ":") 1077 head.append(tab + '"""{}"""'.format(self.doc)) 1078 1079 src = "\n".join(head + body) 1080 return dec, src, specs 1081 1082 def _generate_from_outargs_and_no_return(self): 1083 tab = " "*4 1084 all_tmpvars = set() 1085 specs, body = [], [] 1086 for signum, sig in enumerate(self.signatures): 1087 func_name, incodes, outcodes, retcode, intypes, outtypes, \ 1088 rettype, c, lines, sp = self._get_common(signum, sig) 1089 body.extend(lines) 1090 1091 # Generate the call to the specialized function 1092 callvars = self._get_incallvars(intypes, c) 1093 outcallvars, tmpvars, casts = self._get_outcallvars(outtypes, c) 1094 callvars.extend(outcallvars) 1095 all_tmpvars.update(tmpvars) 1096 1097 call = "{}({})".format(func_name, ", ".join(callvars)) 1098 body.append(sp + call) 1099 body.extend([sp + x for x in casts]) 1100 if len(outcodes) == 1: 1101 sig = "{}->{}".format(incodes, outcodes) 1102 specs.append(sig) 1103 else: 1104 sig = "{}*{}->v".format(incodes, outcodes) 1105 specs.append(sig) 1106 1107 if len(specs) > 1: 1108 lines = self._get_nan_decs() 1109 body.extend(lines) 1110 1111 
if len(self.outvars) == 1: 1112 line = "return {}[0]".format(self.outvars[0]) 1113 body.append(tab + line) 1114 1115 # Generate the head of the function 1116 callvars, head = [], [] 1117 for invar, (intype, _) in zip(self.invars, self.intypes): 1118 callvars.append("{} {}".format(intype, invar)) 1119 if len(self.outvars) > 1: 1120 for outvar, (outtype, _) in zip(self.outvars, self.outtypes): 1121 callvars.append("{} *{}".format(outtype, outvar)) 1122 if len(self.outvars) == 1: 1123 outtype, _ = self.outtypes[0] 1124 dec = "cpdef {} {}({}) nogil".format(outtype, self.name, ", ".join(callvars)) 1125 else: 1126 dec = "cdef void {}({}) nogil".format(self.name, ", ".join(callvars)) 1127 head.append(dec + ":") 1128 head.append(tab + '"""{}"""'.format(self.doc)) 1129 if len(self.outvars) == 1: 1130 outvar = self.outvars[0] 1131 outtype, _ = self.outtypes[0] 1132 line = "cdef {} {}".format(outtype, outvar) 1133 head.append(tab + line) 1134 head.extend(self._get_tmp_decs(all_tmpvars)) 1135 1136 src = "\n".join(head + body) 1137 return dec, src, specs 1138 1139 def _generate_from_outargs_and_return(self): 1140 tab = " "*4 1141 all_tmpvars = set() 1142 specs, body = [], [] 1143 for signum, sig in enumerate(self.signatures): 1144 func_name, incodes, outcodes, retcode, intypes, outtypes, \ 1145 rettype, c, lines, sp = self._get_common(signum, sig) 1146 body.extend(lines) 1147 1148 # Generate the call to the specialized function 1149 callvars = self._get_incallvars(intypes, c) 1150 outcallvars, tmpvars, casts = self._get_outcallvars(outtypes, c) 1151 callvars.extend(outcallvars) 1152 all_tmpvars.update(tmpvars) 1153 call = "{}({})".format(func_name, ", ".join(callvars)) 1154 if c and rettype == "double complex": 1155 call = double_complex_from_npy_cdouble(call) 1156 call = "{}[0] = {}".format(self.outvars[0], call) 1157 body.append(sp + call) 1158 body.extend([sp + x for x in casts]) 1159 sig = "{}*{}->v".format(incodes, outcodes + retcode) 1160 specs.append(sig) 1161 1162 if 
len(specs) > 1: 1163 lines = self._get_nan_decs() 1164 body.extend(lines) 1165 1166 # Generate the head of the function 1167 callvars, head = [], [] 1168 for invar, (intype, _) in zip(self.invars, self.intypes): 1169 callvars.append("{} {}".format(intype, invar)) 1170 for outvar, (outtype, _) in zip(self.outvars, self.outtypes): 1171 callvars.append("{} *{}".format(outtype, outvar)) 1172 dec = "cdef void {}({}) nogil".format(self.name, ", ".join(callvars)) 1173 head.append(dec + ":") 1174 head.append(tab + '"""{}"""'.format(self.doc)) 1175 head.extend(self._get_tmp_decs(all_tmpvars)) 1176 1177 src = "\n".join(head + body) 1178 return dec, src, specs 1179 1180 def generate(self): 1181 _, _, outcodes, retcode, _ = self.signatures[0] 1182 retcode = re.sub(r'\*.*', '', retcode) 1183 if not retcode: 1184 retcode = 'v' 1185 1186 if len(outcodes) == 0 and retcode != 'v': 1187 dec, src, specs = self._generate_from_return_and_no_outargs() 1188 elif len(outcodes) > 0 and retcode == 'v': 1189 dec, src, specs = self._generate_from_outargs_and_no_return() 1190 elif len(outcodes) > 0 and retcode != 'v': 1191 dec, src, specs = self._generate_from_outargs_and_return() 1192 else: 1193 raise ValueError("Invalid signature") 1194 1195 if len(self.outvars) > 1: 1196 wrap = self._get_python_wrap() 1197 else: 1198 wrap = None 1199 1200 return dec, src, specs, self.fused_types, wrap 1201 1202 1203def get_declaration(ufunc, c_name, c_proto, cy_proto, header, 1204 proto_h_filename): 1205 """ 1206 Construct a Cython declaration of a function coming either from a 1207 pxd or a header file. Do sufficient tricks to enable compile-time 1208 type checking against the signature expected by the ufunc. 
1209 1210 """ 1211 defs = [] 1212 defs_h = [] 1213 1214 var_name = c_name.replace('[', '_').replace(']', '_').replace(' ', '_') 1215 1216 if header.endswith('.pxd'): 1217 defs.append("from .%s cimport %s as %s" % ( 1218 header[:-4], ufunc.cython_func_name(c_name, prefix=""), 1219 ufunc.cython_func_name(c_name))) 1220 1221 # check function signature at compile time 1222 proto_name = '_proto_%s_t' % var_name 1223 defs.append("ctypedef %s" % (cy_proto.replace('(*)', proto_name))) 1224 defs.append("cdef %s *%s_var = &%s" % ( 1225 proto_name, proto_name, ufunc.cython_func_name(c_name, specialized=True))) 1226 else: 1227 # redeclare the function, so that the assumed 1228 # signature is checked at compile time 1229 new_name = "%s \"%s\"" % (ufunc.cython_func_name(c_name), c_name) 1230 defs.append("cdef extern from \"%s\":" % proto_h_filename) 1231 defs.append(" cdef %s" % (cy_proto.replace('(*)', new_name))) 1232 defs_h.append("#include \"%s\"" % header) 1233 defs_h.append("%s;" % (c_proto.replace('(*)', c_name))) 1234 1235 return defs, defs_h, var_name 1236 1237 1238def generate_ufuncs(fn_prefix, cxx_fn_prefix, ufuncs): 1239 filename = fn_prefix + ".pyx" 1240 proto_h_filename = fn_prefix + '_defs.h' 1241 1242 cxx_proto_h_filename = cxx_fn_prefix + '_defs.h' 1243 cxx_pyx_filename = cxx_fn_prefix + ".pyx" 1244 cxx_pxd_filename = cxx_fn_prefix + ".pxd" 1245 1246 toplevel = "" 1247 1248 # for _ufuncs* 1249 defs = [] 1250 defs_h = [] 1251 all_loops = {} 1252 1253 # for _ufuncs_cxx* 1254 cxx_defs = [] 1255 cxx_pxd_defs = [ 1256 "from . 
cimport sf_error", 1257 "cdef void _set_action(sf_error.sf_error_t, sf_error.sf_action_t) nogil" 1258 ] 1259 cxx_defs_h = [] 1260 1261 ufuncs.sort(key=lambda u: u.name) 1262 1263 for ufunc in ufuncs: 1264 # generate function declaration and type checking snippets 1265 cfuncs = ufunc.get_prototypes() 1266 for c_name, c_proto, cy_proto, header in cfuncs: 1267 if header.endswith('++'): 1268 header = header[:-2] 1269 1270 # for the CXX module 1271 item_defs, item_defs_h, var_name = get_declaration(ufunc, c_name, c_proto, cy_proto, 1272 header, cxx_proto_h_filename) 1273 cxx_defs.extend(item_defs) 1274 cxx_defs_h.extend(item_defs_h) 1275 1276 cxx_defs.append("cdef void *_export_%s = <void*>%s" % ( 1277 var_name, ufunc.cython_func_name(c_name, specialized=True, override=False))) 1278 cxx_pxd_defs.append("cdef void *_export_%s" % (var_name,)) 1279 1280 # let cython grab the function pointer from the c++ shared library 1281 ufunc.function_name_overrides[c_name] = "scipy.special._ufuncs_cxx._export_" + var_name 1282 else: 1283 # usual case 1284 item_defs, item_defs_h, _ = get_declaration(ufunc, c_name, c_proto, cy_proto, header, 1285 proto_h_filename) 1286 defs.extend(item_defs) 1287 defs_h.extend(item_defs_h) 1288 1289 # ufunc creation code snippet 1290 t = ufunc.generate(all_loops) 1291 toplevel += t + "\n" 1292 1293 # Produce output 1294 toplevel = "\n".join(sorted(all_loops.values()) + defs + [toplevel]) 1295 # Generate an `__all__` for the module 1296 all_ufuncs = ( 1297 [ 1298 "'{}'".format(ufunc.name) 1299 for ufunc in ufuncs if not ufunc.name.startswith('_') 1300 ] 1301 + ["'geterr'", "'seterr'", "'errstate'", "'jn'"] 1302 ) 1303 module_all = '__all__ = [{}]'.format(', '.join(all_ufuncs)) 1304 1305 with open(filename, 'w') as f: 1306 f.write(UFUNCS_EXTRA_CODE_COMMON) 1307 f.write(UFUNCS_EXTRA_CODE) 1308 f.write(module_all) 1309 f.write("\n") 1310 f.write(toplevel) 1311 f.write(UFUNCS_EXTRA_CODE_BOTTOM) 1312 1313 defs_h = unique(defs_h) 1314 with 
open(proto_h_filename, 'w') as f: 1315 f.write("#ifndef UFUNCS_PROTO_H\n#define UFUNCS_PROTO_H 1\n") 1316 f.write("\n".join(defs_h)) 1317 f.write("\n#endif\n") 1318 1319 cxx_defs_h = unique(cxx_defs_h) 1320 with open(cxx_proto_h_filename, 'w') as f: 1321 f.write("#ifndef UFUNCS_PROTO_H\n#define UFUNCS_PROTO_H 1\n") 1322 f.write("\n".join(cxx_defs_h)) 1323 f.write("\n#endif\n") 1324 1325 with open(cxx_pyx_filename, 'w') as f: 1326 f.write(UFUNCS_EXTRA_CODE_COMMON) 1327 f.write("\n") 1328 f.write("\n".join(cxx_defs)) 1329 f.write("\n# distutils: language = c++\n") 1330 1331 with open(cxx_pxd_filename, 'w') as f: 1332 f.write("\n".join(cxx_pxd_defs)) 1333 1334 1335def generate_fused_funcs(modname, ufunc_fn_prefix, fused_funcs): 1336 pxdfile = modname + ".pxd" 1337 pyxfile = modname + ".pyx" 1338 proto_h_filename = ufunc_fn_prefix + '_defs.h' 1339 1340 sources = [] 1341 declarations = [] 1342 # Code for benchmarks 1343 bench_aux = [] 1344 fused_types = set() 1345 # Parameters for the tests 1346 doc = [] 1347 defs = [] 1348 1349 for func in fused_funcs: 1350 if func.name.startswith("_"): 1351 # Don't try to deal with functions that have extra layers 1352 # of wrappers. 
1353 continue 1354 1355 # Get the function declaration for the .pxd and the source 1356 # code for the .pyx 1357 dec, src, specs, func_fused_types, wrap = func.generate() 1358 declarations.append(dec) 1359 sources.append(src) 1360 if wrap: 1361 sources.append(wrap) 1362 fused_types.update(func_fused_types) 1363 1364 # Declare the specializations 1365 cfuncs = func.get_prototypes(nptypes_for_h=True) 1366 for c_name, c_proto, cy_proto, header in cfuncs: 1367 if header.endswith('++'): 1368 # We grab the c++ functions from the c++ module 1369 continue 1370 item_defs, _, _ = get_declaration(func, c_name, c_proto, 1371 cy_proto, header, 1372 proto_h_filename) 1373 defs.extend(item_defs) 1374 1375 # Add a line to the documentation 1376 doc.append(generate_doc(func.name, specs)) 1377 1378 # Generate code for benchmarks 1379 if func.name in CYTHON_SPECIAL_BENCHFUNCS: 1380 for codes in CYTHON_SPECIAL_BENCHFUNCS[func.name]: 1381 pybench, cybench = generate_bench(func.name, codes) 1382 bench_aux.extend([pybench, cybench]) 1383 1384 fused_types = list(fused_types) 1385 fused_types.sort() 1386 1387 with open(pxdfile, 'w') as f: 1388 f.write(CYTHON_SPECIAL_PXD) 1389 f.write("\n") 1390 f.write("\n\n".join(fused_types)) 1391 f.write("\n\n") 1392 f.write("\n".join(declarations)) 1393 with open(pyxfile, 'w') as f: 1394 header = CYTHON_SPECIAL_PYX 1395 header = header.replace("FUNCLIST", "\n".join(doc)) 1396 f.write(header) 1397 f.write("\n") 1398 f.write("\n".join(defs)) 1399 f.write("\n\n") 1400 f.write("\n\n".join(sources)) 1401 f.write("\n\n") 1402 f.write("\n\n".join(bench_aux)) 1403 1404 1405def generate_ufuncs_type_stubs(module_name: str, ufuncs: List[Ufunc]): 1406 stubs, module_all = [], [] 1407 for ufunc in ufuncs: 1408 stubs.append(f'{ufunc.name}: np.ufunc') 1409 if not ufunc.name.startswith('_'): 1410 module_all.append(f"'{ufunc.name}'") 1411 # jn is an alias for jv. 
1412 module_all.append("'jn'") 1413 stubs.append('jn: np.ufunc') 1414 module_all.sort() 1415 stubs.sort() 1416 1417 contents = STUBS.format( 1418 ALL=',\n '.join(module_all), 1419 STUBS='\n'.join(stubs), 1420 ) 1421 1422 stubs_file = f'{module_name}.pyi' 1423 with open(stubs_file, 'w') as f: 1424 f.write(contents) 1425 1426 1427def unique(lst): 1428 """ 1429 Return a list without repeated entries (first occurrence is kept), 1430 preserving order. 1431 """ 1432 seen = set() 1433 new_lst = [] 1434 for item in lst: 1435 if item in seen: 1436 continue 1437 seen.add(item) 1438 new_lst.append(item) 1439 return new_lst 1440 1441 1442def all_newer(src_files, dst_files): 1443 from distutils.dep_util import newer 1444 return all(os.path.exists(dst) and newer(dst, src) 1445 for dst in dst_files for src in src_files) 1446 1447 1448def main(): 1449 p = optparse.OptionParser(usage=(__doc__ or '').strip()) 1450 options, args = p.parse_args() 1451 if len(args) != 0: 1452 p.error('invalid number of arguments') 1453 1454 pwd = os.path.dirname(__file__) 1455 src_files = (os.path.abspath(__file__), 1456 os.path.abspath(os.path.join(pwd, 'functions.json')), 1457 os.path.abspath(os.path.join(pwd, 'add_newdocs.py'))) 1458 dst_files = ('_ufuncs.pyx', 1459 '_ufuncs_defs.h', 1460 '_ufuncs_cxx.pyx', 1461 '_ufuncs_cxx.pxd', 1462 '_ufuncs_cxx_defs.h', 1463 '_ufuncs.pyi', 1464 'cython_special.pyx', 1465 'cython_special.pxd') 1466 1467 os.chdir(BASE_DIR) 1468 1469 if all_newer(src_files, dst_files): 1470 print("scipy/special/_generate_pyx.py: all files up-to-date") 1471 return 1472 1473 ufuncs, fused_funcs = [], [] 1474 with open('functions.json') as data: 1475 functions = json.load(data) 1476 for f, sig in functions.items(): 1477 ufuncs.append(Ufunc(f, sig)) 1478 fused_funcs.append(FusedFunc(f, sig)) 1479 generate_ufuncs("_ufuncs", "_ufuncs_cxx", ufuncs) 1480 generate_ufuncs_type_stubs("_ufuncs", ufuncs) 1481 generate_fused_funcs("cython_special", "_ufuncs", fused_funcs) 1482 1483 1484if 
__name__ == "__main__": 1485 main() 1486