1# -*- coding: utf-8; mode: cython -*-
2# distutils: language = c++
3
4from cpython.object cimport PyObject
5from libcpp cimport bool
6from libcpp.string cimport string
7from libcpp.vector cimport vector
8cimport numpy as cnp
9
10from anyode_numpy cimport PyOdeSys
11from gsl_odeiv2_cxx cimport styp_from_name, fpes as _fpes
12from gsl_odeiv2_anyode cimport simple_adaptive, simple_predefined
13
14import numpy as np
15
ctypedef PyOdeSys[double, int] PyOdeSys_t

# Initialise the NumPy C-API before any cnp.ndarray buffer is used.
cnp.import_array()

# Stepper names for which a Jacobian callback is mandatory; ``adaptive`` and
# ``predefined`` raise ValueError up front when ``jac`` is None for these.
requires_jac = ('rk1imp', 'rk2imp', 'rk4imp', 'bsimp', 'msbdf')
# All stepper names known to this module: the ones above plus those that do
# not require a Jacobian callback here.
steppers = requires_jac + ('rk2', 'rk4', 'rkf45', 'rkck', 'rk8pd', 'msadams')
# The C++ ``fpes`` mapping re-keyed with native ``str`` (its keys arrive as
# bytes); presumably floating-point-exception flag values -- TODO confirm.
fpes = dict((str(name.decode('utf-8')), value) for name, value in dict(_fpes).items())
23
cdef dict get_last_info(PyOdeSys_t * odesys, success=True):
    """Collect diagnostics of the most recent integration into a Python dict.

    Entries are gathered from the four ``odesys.current_info`` maps (integer,
    double, vector-of-double and vector-of-int).  Their keys arrive as bytes
    and are decoded to native ``str``; vector values become numpy arrays.
    The counters ``nfev``/``njev`` and the caller-supplied ``success`` flag are
    written last, so they win over any identically named map entry.
    """
    info = {}
    for key, val in dict(odesys.current_info.nfo_int).items():
        info[str(key.decode('utf-8'))] = val
    for key, val in dict(odesys.current_info.nfo_dbl).items():
        info[str(key.decode('utf-8'))] = val
    for key, val in dict(odesys.current_info.nfo_vecdbl).items():
        info[str(key.decode('utf-8'))] = np.array(val, dtype=np.float64)
    for key, val in dict(odesys.current_info.nfo_vecint).items():
        info[str(key.decode('utf-8'))] = np.array(val, dtype=int)
    # Counters straight from the system object, then the outcome flag.
    info['nfev'] = odesys.nfev
    info['njev'] = odesys.njev
    info['success'] = success
    return info
33
def adaptive(rhs, jac, cnp.ndarray[cnp.float64_t, mode='c'] y0, double x0, double xend, double atol,
             double rtol, str method='bsimp', long int nsteps=500, double dx0=0.0, double dx_min=0.0,
             double dx_max=0.0, int autorestart=0, bool return_on_error=False, cb_kwargs=None,
             bool record_rhs_xvals=False, bool record_jac_xvals=False, bool record_order=False,
             bool record_fpe=False, dx0cb=None, dx_max_cb=None):
    """Integrate from ``x0`` to ``xend``, returning every point the stepper visits.

    Parameters
    ----------
    rhs : callable
        Right-hand-side callback, handed to ``PyOdeSys``.
    jac : callable or None
        Jacobian callback; mandatory for the methods in ``requires_jac``.
    y0 : C-contiguous float64 ndarray
        Initial state; its length defines the number of equations ``ny``.
    x0, xend : float
        Start and end of the integration interval.
    atol, rtol : float
        Absolute / relative tolerances (echoed back in ``info``).
    method : str
        Stepper name, matched case-insensitively (see ``steppers``).
    nsteps, dx0, dx_min, dx_max, autorestart, return_on_error
        Passed straight through to ``simple_adaptive``.
    cb_kwargs, dx0cb, dx_max_cb
        Passed to the ``PyOdeSys`` constructor.
    record_rhs_xvals, record_jac_xvals, record_order, record_fpe : bool
        Set the corresponding ``record_*`` flags on the system object;
        presumably the recorded data is reported via ``info`` -- TODO confirm.

    Returns
    -------
    xout : ndarray, shape (n,)
    yout : ndarray, shape (n, ny)
    info : dict
        Output of ``get_last_info`` plus ``atol`` and ``rtol``.
        ``info['success']`` is False only when ``return_on_error`` is set and
        the last output point is not ``xend``.

    Raises
    ------
    ValueError
        If ``method`` requires a Jacobian and ``jac`` is None, or if ``y0``
        contains NaN.
    """
    cdef:
        int ny = y0.shape[y0.ndim - 1]
        int mlower=-1, mupper=-1, nquads=0, nroots=0, nnz=-1
        PyOdeSys_t * odesys

    # The stepper is looked up case-insensitively below, so the Jacobian
    # requirement must be checked on the same normalised name; otherwise e.g.
    # 'BSIMP' with jac=None would slip through.
    method_name = method.lower()
    if method_name in requires_jac and jac is None:
        raise ValueError("Method requires explicit jacobian callback")
    if np.isnan(y0).any():
        raise ValueError("NaN found in y0")

    odesys = new PyOdeSys_t(ny, <PyObject *>rhs, <PyObject *>jac, NULL, NULL, NULL,
                          <PyObject *>cb_kwargs, mlower, mupper, nquads, nroots, <PyObject *>dx0cb, <PyObject *>dx_max_cb, nnz)
    odesys.record_rhs_xvals = record_rhs_xvals
    odesys.record_jac_xvals = record_jac_xvals
    odesys.record_order = record_order
    odesys.record_fpe = record_fpe

    try:
        xout, yout = map(np.asarray, simple_adaptive[PyOdeSys_t](
            odesys, atol, rtol, styp_from_name(method_name.encode('UTF-8')),
            &y0[0], x0, xend, nsteps, dx0, dx_min, dx_max, autorestart, return_on_error))
        # Only when errors are tolerated can the run stop short of xend.
        info = get_last_info(odesys, False if return_on_error and xout[-1] != xend else True)
        info['atol'], info['rtol'] = atol, rtol
        # The state comes back flattened; restore one row per output point.
        return xout, yout.reshape((xout.size, ny)), info
    finally:
        # Release the heap-allocated system even if integration raised.
        del odesys
65
66
def predefined(rhs, jac,
               cnp.ndarray[cnp.float64_t, mode='c'] y0,
               cnp.ndarray[cnp.float64_t, ndim=1] xout,
               double atol, double rtol, str method='bsimp', int nsteps=500, double dx0=0.0,
               double dx_min=0.0, double dx_max=0.0, int autorestart=0,
               bool return_on_error=False, cb_kwargs=None, bool record_rhs_xvals=False,
               bool record_jac_xvals=False, bool record_order=False, bool record_fpe=False,
               dx0cb=None, dx_max_cb=None):
    """Integrate and report the solution at the caller-supplied points ``xout``.

    Parameters
    ----------
    rhs : callable
        Right-hand-side callback, handed to ``PyOdeSys``.
    jac : callable or None
        Jacobian callback; mandatory for the methods in ``requires_jac``.
    y0 : C-contiguous float64 ndarray
        Initial state; its length defines the number of equations ``ny``.
    xout : 1-D float64 ndarray
        Output points.  Any memory layout is accepted; a contiguous copy is
        made only if needed.
    atol, rtol : float
        Absolute / relative tolerances (echoed back in ``info``).
    method : str
        Stepper name, matched case-insensitively (see ``steppers``).
    nsteps, dx0, dx_min, dx_max, autorestart, return_on_error
        Passed straight through to ``simple_predefined``.
    cb_kwargs, dx0cb, dx_max_cb
        Passed to the ``PyOdeSys`` constructor.
    record_rhs_xvals, record_jac_xvals, record_order, record_fpe : bool
        Set the corresponding ``record_*`` flags on the system object.

    Returns
    -------
    yout : ndarray, shape (len(xout), ny)
        Allocated with ``np.empty``.  NOTE(review): rows at index >=
        ``info['nreached']`` may be left unwritten when ``return_on_error``
        stops early -- confirm against ``simple_predefined``.
    info : dict
        Output of ``get_last_info`` plus ``nreached``, ``atol`` and ``rtol``.
        ``info['success']`` is False only when ``return_on_error`` is set and
        fewer than ``len(xout)`` points were reached.

    Raises
    ------
    ValueError
        If ``method`` requires a Jacobian and ``jac`` is None, or if ``y0``
        contains NaN.
    """
    cdef:
        int ny = y0.shape[y0.ndim - 1]
        # simple_predefined gets only a raw pointer plus a length, so a strided
        # view (e.g. x[::2]) would be read incorrectly.  ascontiguousarray is
        # a no-op for input that is already C-contiguous.
        cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] xout_c = np.ascontiguousarray(xout)
        cnp.ndarray[cnp.float64_t, ndim=2] yout = np.empty((xout_c.size, ny))
        int nreached
        PyOdeSys_t * odesys
        int mlower=-1, mupper=-1, nquads=0, nroots=0, nnz=-1

    # Normalise once so the Jacobian check matches the (case-insensitive)
    # stepper lookup below.
    method_name = method.lower()
    if method_name in requires_jac and jac is None:
        raise ValueError("Method requires explicit jacobian callback")
    if np.isnan(y0).any():
        raise ValueError("NaN found in y0")
    odesys = new PyOdeSys_t(ny, <PyObject *>rhs, <PyObject *>jac, NULL, NULL, NULL, <PyObject *>cb_kwargs,
                          mlower, mupper, nquads, nroots, <PyObject *>dx0cb, <PyObject *>dx_max_cb, nnz)
    odesys.record_rhs_xvals = record_rhs_xvals
    odesys.record_jac_xvals = record_jac_xvals
    odesys.record_order = record_order
    odesys.record_fpe = record_fpe
    try:
        nreached = simple_predefined[PyOdeSys_t](odesys, atol, rtol, styp_from_name(method_name.encode('UTF-8')), &y0[0],
                                               xout_c.size, &xout_c[0], <double *>yout.data, nsteps,
                                               dx0, dx_min, dx_max, autorestart, return_on_error)
        info = get_last_info(odesys, success=False if return_on_error and nreached < xout_c.size else True)
        info['nreached'] = nreached
        info['atol'], info['rtol'] = atol, rtol
        # yout was allocated with shape (len(xout), ny); no reshape needed.
        return yout, info
    finally:
        # Release the heap-allocated system even if integration raised.
        del odesys
102