1'''Generate Cython PYX wrappers for Boost stats distributions.'''
2
3from typing import NamedTuple
4from warnings import warn
5from textwrap import dedent
6from shutil import copyfile
7import pathlib
8import os
9
10from gen_func_defs_pxd import (  # type: ignore
11    _gen_func_defs_pxd)
12from _info import (  # type: ignore
13    _x_funcs, _no_x_funcs, _klass_mapper)
14
15class _MethodDef(NamedTuple):
16    ufunc_name: str
17    num_inputs: int
18    boost_func_name: str
19
20
def _ufunc_gen(scipy_dist: str, types: list, ctor_args: tuple,
               filename: str, boost_dist: str, x_funcs: list,
               no_x_funcs: list):
    '''
    Write a Cython PYX file defining one ufunc per distribution method.

    We need methods defined for each rv_continuous/_discrete internal method:
        i.e.: _pdf, _cdf, etc.
    Some of these methods take constructor arguments and 1 extra argument,
        e.g.: _pdf(x, *ctor_args), _ppf(q, *ctor_args)
    while some of the methods take only constructor arguments:
        e.g.: _stats(*ctor_args)

    Parameters
    ----------
    scipy_dist : str
        Distribution name used to build the ufunc names,
        ``_{scipy_dist}_{func}``.
    types : list of str
        NumPy type-number names to generate loops for.  Each entry must be
        one of 'NPY_FLOAT', 'NPY_DOUBLE', 'NPY_LONGDOUBLE' or 'NPY_FLOAT16'.
    ctor_args : tuple of str
        Names of the distribution constructor arguments; every generated
        ufunc takes one input per constructor argument.
    filename : str
        Path of the PYX file to write; opened in 'w' mode.
    boost_dist : str
        Name of the Boost distribution class, e.g. 'beta_distribution'.
        The header name is the text before the first '_distribution'.
    x_funcs : list of str
        Functions that take one extra leading input in addition to the
        constructor arguments.
    no_x_funcs : list of str
        Functions that take only the constructor arguments.

    Raises
    ------
    KeyError
        If `types` contains a name with no known C type.
    ValueError
        If any requested ufunc would have zero inputs.
    '''
    # Resolve the C type for every requested NumPy type up front, so an
    # unsupported entry fails before any output file is created.
    ctype_by_npy_type = {
        'NPY_LONGDOUBLE': 'longdouble',
        'NPY_DOUBLE': 'double',
        'NPY_FLOAT': 'float',
        'NPY_FLOAT16': 'npy_half',
    }
    c_types = [ctype_by_npy_type[T] for T in types]

    num_ctor_args = len(ctor_args)
    methods = [_MethodDef(
        ufunc_name=f'_{scipy_dist}_{x_func}',
        num_inputs=num_ctor_args+1,  # +1 for the x argument
        # PDF for the beta distribution has a custom wrapper:
        boost_func_name=x_func if boost_dist != 'beta_distribution'
        else 'pdf_beta' if x_func == 'pdf' else x_func,
    ) for x_func in x_funcs]
    methods += [_MethodDef(
        ufunc_name=f'_{scipy_dist}_{func}',
        num_inputs=num_ctor_args,
        boost_func_name=func,
    ) for func in no_x_funcs]

    # Identify potential ufunc issues:
    no_input_methods = [m for m in methods if m.num_inputs == 0]
    if no_input_methods:
        raise ValueError("ufuncs must have >0 arguments! "
                         f"Cannot construct these ufuncs: {no_input_methods}")

    if 'NPY_FLOAT16' in types:
        warn('Boost stats NPY_FLOAT16 ufunc generation is not '
             'currently supported!')

    boost_hdr_name = boost_dist.split('_distribution')[0]
    # Sorted so that the generated file is reproducible between runs.
    unique_num_inputs = sorted({m.num_inputs for m in methods})
    has_NPY_LONGDOUBLE = 'NPY_LONGDOUBLE' in types
    # Continuation used for multi-line interpolations in the header; the
    # 16 spaces match the indentation of the placeholders in the template
    # below so that `dedent` still strips a uniform prefix.
    line_joiner = ',\n    ' + ' '*12
    num_types = len(types)
    loop_fun = 'PyUFunc_T'
    func_defs_cimports = line_joiner.join(
        f"boost_{m.boost_func_name}{num_ctor_args}" for m in methods)
    # Same continuation without the leading comma: one typedef per line.
    nontype_params = line_joiner[1:].join(
        f'ctypedef int NINPUTS{n} "{n}"' for n in unique_num_inputs)

    with open(filename, 'w') as fp:
        boost_hdr = f'boost/math/distributions/{boost_hdr_name}.hpp'
        fp.write(dedent(f'''\
            # distutils: language = c++
            # cython: language_level=3

            # This file was generated by stats/_boost/include/code_gen.py
            # All modifications to this file will be overwritten.

            from numpy cimport (
                import_array,
                import_ufunc,
                PyUFunc_FromFuncAndData,
                PyUFuncGenericFunction,
                PyUFunc_None,
                {line_joiner.join(types)}
            )
            from templated_pyufunc cimport PyUFunc_T
            from func_defs cimport (
                {func_defs_cimports},
            )
            cdef extern from "{boost_hdr}" namespace "boost::math" nogil:
                cdef cppclass {boost_dist} nogil:
                    pass

            # Workaround for Cython's lack of non-type template parameter
            # support
            cdef extern from * nogil:
                {nontype_params}

            _DUMMY = ""
            import_array()
            import_ufunc()
            '''))

        if has_NPY_LONGDOUBLE:
            fp.write('ctypedef long double longdouble\n\n')

        # Generate ufuncs for each method
        for ii, m in enumerate(methods):
            num_args = m.num_inputs + 1  # inputs plus the one output arg
            fp.write(dedent(f'''
                cdef PyUFuncGenericFunction loop_func{ii}[{num_types}]
                cdef void* func{ii}[1*{num_types}]
                cdef char types{ii}[{num_args}*{num_types}]
                '''))

            boost_fun = f'boost_{m.boost_func_name}{num_ctor_args}'
            for jj, (T, ctype) in enumerate(zip(types, c_types)):
                # Template arguments: the distribution class followed by
                # 1 + num_ctor_args copies of the C type.
                type_str = ", ".join([ctype]*(1+num_ctor_args))
                boost_tmpl = f'{boost_dist}, {type_str}'
                N = m.num_inputs
                fp.write(
                    f'loop_func{ii}[{jj}] = '
                    f'<PyUFuncGenericFunction>{loop_fun}[{ctype}, NINPUTS{N}]\n'
                    f'func{ii}[{jj}] = <void*>{boost_fun}[{boost_tmpl}]\n')
                # Every input and the output share the same type number.
                fp.writelines(
                    f'types{ii}[{tidx}+{jj}*{num_args}] = {T}\n'
                    for tidx in range(num_args))

            # Decide from the input count, not the Boost function name: the
            # beta pdf is renamed to 'pdf_beta' above and would otherwise be
            # documented without its `x` argument.
            takes_x = m.num_inputs > num_ctor_args
            arg_names = ('x', *ctor_args) if takes_x else tuple(ctor_args)
            arg_list_str = ', '.join(arg_names)
            fp.write(dedent(f'''
                {m.ufunc_name} = PyUFunc_FromFuncAndData(
                    loop_func{ii},
                    func{ii},
                    types{ii},
                    {num_types},  # number of supported input types
                    {m.num_inputs},  # number of input args
                    1,  # number of output args
                    PyUFunc_None,  # `identity` element, never mind this
                    "{m.ufunc_name}",  # function name
                    ("{m.ufunc_name}({arg_list_str}) -> computes "
                     "{m.boost_func_name} of {scipy_dist} distribution"),
                    0  # unused
                )
                '''))
149
150
if __name__ == '__main__':
    # Work relative to the directory two levels above this script: the
    # sources are read from its 'include' folder and written to 'src'.
    boost_root = pathlib.Path(__file__).resolve().parent.parent
    include_dir = boost_root / 'include'
    out_dir = boost_root / 'src'

    # Make sure the output directory exists before writing anything.
    out_dir.mkdir(parents=True, exist_ok=True)

    # The generated PYX files cimport from templated_pyufunc, so place a
    # copy of that PXD alongside them to satisfy Cython's PXD lookup.
    pxd_name = 'templated_pyufunc.pxd'
    copyfile(include_dir / pxd_name, out_dir / pxd_name)

    # First the shared function-definition PXD ...
    _gen_func_defs_pxd(
        f'{out_dir}/func_defs.pxd',
        x_funcs=_x_funcs,
        no_x_funcs=_no_x_funcs,
    )

    # ... then one PYX wrapper per entry of the class mapper.  Keys are
    # the Boost name prefixes; values carry the SciPy name and the
    # constructor argument names.
    float_types = ('NPY_FLOAT', 'NPY_DOUBLE', 'NPY_LONGDOUBLE')
    for boost_name, info in _klass_mapper.items():
        _ufunc_gen(
            scipy_dist=info.scipy_name,
            types=list(float_types),
            ctor_args=info.ctor_args,
            filename=f'{out_dir}/{info.scipy_name}_ufunc.pyx',
            boost_dist=f'{boost_name}_distribution',
            x_funcs=_x_funcs,
            no_x_funcs=_no_x_funcs,
        )
178