# -*- coding: utf-8; mode: cython -*-
# distutils: language = c++
"""Python entry points for the ``simple_adaptive`` / ``simple_predefined``
drivers from ``gsl_odeiv2_anyode``, using ``PyOdeSys`` (from
``anyode_numpy``) to forward right-hand-side and Jacobian evaluations to
Python callbacks.
"""

from cpython.object cimport PyObject
from libcpp cimport bool
from libcpp.string cimport string
from libcpp.vector cimport vector
cimport numpy as cnp

from anyode_numpy cimport PyOdeSys
from gsl_odeiv2_cxx cimport styp_from_name, fpes as _fpes
from gsl_odeiv2_anyode cimport simple_adaptive, simple_predefined

import numpy as np

ctypedef PyOdeSys[double, int] PyOdeSys_t

cnp.import_array()  # Numpy C-API initialization

# Stepper names (lower case) that cannot run without a user-supplied ``jac``.
requires_jac = ('rk1imp', 'rk2imp', 'rk4imp', 'bsimp', 'msbdf')
# Known stepper names; ``method`` is lower-cased before lookup.
steppers = requires_jac + ('rk2', 'rk4', 'rkf45', 'rkck', 'rk8pd', 'msadams')
# ``_fpes`` from gsl_odeiv2_cxx re-keyed by ``str`` instead of ``bytes``
# (presumably floating-point-exception flag values — TODO confirm).
fpes = {str(k.decode('utf-8')): v for k, v in dict(_fpes).items()}


cdef dict get_last_info(PyOdeSys_t * odesys, success=True):
    """Collect statistics of the most recent integration into a dict.

    Merges the ``nfo_int``, ``nfo_dbl``, ``nfo_vecdbl`` and ``nfo_vecint``
    maps of ``odesys.current_info`` (byte keys decoded to ``str``; vector
    values converted to float64 / int numpy arrays), then adds ``nfev``,
    ``njev`` and the caller-supplied ``success`` flag.
    """
    info = {str(k.decode('utf-8')): v for k, v in dict(odesys.current_info.nfo_int).items()}
    info.update({str(k.decode('utf-8')): v for k, v in dict(odesys.current_info.nfo_dbl).items()})
    info.update({str(k.decode('utf-8')): np.array(v, dtype=np.float64)
                 for k, v in dict(odesys.current_info.nfo_vecdbl).items()})
    info.update({str(k.decode('utf-8')): np.array(v, dtype=int)
                 for k, v in dict(odesys.current_info.nfo_vecint).items()})
    info['nfev'] = odesys.nfev
    info['njev'] = odesys.njev
    info['success'] = success
    return info


cdef str _checked_method(str method, jac, y0):
    """Validate the arguments shared by the public drivers.

    Returns ``method`` lower-cased (the form handed to ``styp_from_name``).

    Raises ValueError if the stepper needs a Jacobian but ``jac`` is None,
    or if ``y0`` contains NaN.
    """
    cdef str name = method.lower()
    # The stepper is resolved from the lower-cased name, so the Jacobian
    # requirement must be checked on that same name; otherwise e.g. 'BSIMP'
    # with jac=None would slip past this guard.
    if name in requires_jac and jac is None:
        raise ValueError("Method requires explicit jacobian callback")
    if np.isnan(y0).any():
        raise ValueError("NaN found in y0")
    return name


def adaptive(rhs, jac, cnp.ndarray[cnp.float64_t, mode='c'] y0, double x0, double xend, double atol,
             double rtol, str method='bsimp', long int nsteps=500, double dx0=0.0, double dx_min=0.0,
             double dx_max=0.0, int autorestart=0, bool return_on_error=False, cb_kwargs=None,
             bool record_rhs_xvals=False, bool record_jac_xvals=False, bool record_order=False,
             bool record_fpe=False, dx0cb=None, dx_max_cb=None):
    """Integrate from ``x0`` to ``xend`` letting the driver choose the output points.

    Parameters
    ----------
    rhs : callable
        Right-hand-side callback passed to the ``PyOdeSys`` constructor.
    jac : callable or None
        Jacobian callback; required when ``method`` is in ``requires_jac``.
    y0 : 1-D C-contiguous float64 array
        Initial state (must not contain NaN); its length sets the system size.
    x0, xend : float
        Start and end of the integration interval.
    atol, rtol : float
        Absolute / relative tolerances, also echoed back in ``info``.
    method : str
        Stepper name, matched case-insensitively (e.g. one of ``steppers``).
    nsteps, dx0, dx_min, dx_max, autorestart, return_on_error
        Forwarded unchanged to ``simple_adaptive``.
    cb_kwargs, dx0cb, dx_max_cb
        Forwarded unchanged to the ``PyOdeSys`` constructor.
    record_rhs_xvals, record_jac_xvals, record_order, record_fpe : bool
        Copied onto the ``PyOdeSys`` instance before integrating.

    Returns
    -------
    xout : ndarray
        Abscissae returned by ``simple_adaptive``.
    yout : ndarray, shape (xout.size, ny)
    info : dict
        Output of ``get_last_info`` plus ``atol`` and ``rtol``. ``success`` is
        False when ``return_on_error`` is set and ``xout[-1] != xend``.

    Raises
    ------
    ValueError
        If ``method`` needs a Jacobian and ``jac`` is None, or ``y0`` has NaN.
    """
    cdef:
        int ny = y0.shape[y0.ndim - 1]
        int mlower=-1, mupper=-1, nquads=0, nroots=0, nnz=-1
        PyOdeSys_t * odesys
        str name

    name = _checked_method(method, jac, y0)

    odesys = new PyOdeSys_t(ny, <PyObject *>rhs, <PyObject *>jac, NULL, NULL, NULL,
                            <PyObject *>cb_kwargs, mlower, mupper, nquads, nroots,
                            <PyObject *>dx0cb, <PyObject *>dx_max_cb, nnz)
    odesys.record_rhs_xvals = record_rhs_xvals
    odesys.record_jac_xvals = record_jac_xvals
    odesys.record_order = record_order
    odesys.record_fpe = record_fpe

    try:
        xout, yout = map(np.asarray, simple_adaptive[PyOdeSys_t](
            odesys, atol, rtol, styp_from_name(name.encode('UTF-8')),
            &y0[0], x0, xend, nsteps, dx0, dx_min, dx_max, autorestart, return_on_error))
        # With return_on_error the run may end before xend; only then is it
        # reported as unsuccessful.
        info = get_last_info(odesys, False if return_on_error and xout[-1] != xend else True)
        info['atol'], info['rtol'] = atol, rtol
        return xout, yout.reshape((xout.size, ny)), info
    finally:
        # odesys was allocated with ``new`` above; free it on every exit path.
        del odesys


def predefined(rhs, jac,
               cnp.ndarray[cnp.float64_t, mode='c'] y0,
               cnp.ndarray[cnp.float64_t, ndim=1] xout,
               double atol, double rtol, str method='bsimp', int nsteps=500, double dx0=0.0,
               double dx_min=0.0, double dx_max=0.0, int autorestart=0,
               bool return_on_error=False, cb_kwargs=None, bool record_rhs_xvals=False,
               bool record_jac_xvals=False, bool record_order=False, bool record_fpe=False,
               dx0cb=None, dx_max_cb=None):
    """Integrate and report the solution at the caller-supplied points ``xout``.

    Parameters
    ----------
    rhs, jac, y0, atol, rtol, method, cb_kwargs, dx0cb, dx_max_cb, record_*
        As for ``adaptive``.
    xout : 1-D float64 array
        Output abscissae; may be non-contiguous (a contiguous copy is made).
    nsteps, dx0, dx_min, dx_max, autorestart, return_on_error
        Forwarded unchanged to ``simple_predefined``.

    Returns
    -------
    yout : ndarray, shape (xout.size, ny)
        Filled by ``simple_predefined``. Rows past ``info['nreached']`` may be
        left uninitialized since ``yout`` starts as ``np.empty`` (TODO confirm).
    info : dict
        Output of ``get_last_info`` plus ``nreached``, ``atol`` and ``rtol``.
        ``success`` is False when ``return_on_error`` is set and fewer than
        ``xout.size`` points were reached.

    Raises
    ------
    ValueError
        If ``method`` needs a Jacobian and ``jac`` is None, or ``y0`` has NaN.
    """
    cdef:
        int ny = y0.shape[y0.ndim - 1]
        # simple_predefined receives a raw pointer plus a count, so it needs
        # xout.size consecutive doubles; a strided view (e.g. x[::2]) would
        # feed it wrong values, hence the contiguous copy (no-op if already so).
        cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] xout_c = np.ascontiguousarray(xout)
        cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] yout = np.empty((xout.size, ny))
        int nreached
        PyOdeSys_t * odesys
        int mlower=-1, mupper=-1, nquads=0, nroots=0, nnz=-1
        str name

    name = _checked_method(method, jac, y0)

    odesys = new PyOdeSys_t(ny, <PyObject *>rhs, <PyObject *>jac, NULL, NULL, NULL,
                            <PyObject *>cb_kwargs, mlower, mupper, nquads, nroots,
                            <PyObject *>dx0cb, <PyObject *>dx_max_cb, nnz)
    odesys.record_rhs_xvals = record_rhs_xvals
    odesys.record_jac_xvals = record_jac_xvals
    odesys.record_order = record_order
    odesys.record_fpe = record_fpe

    try:
        nreached = simple_predefined[PyOdeSys_t](
            odesys, atol, rtol, styp_from_name(name.encode('UTF-8')), &y0[0],
            xout_c.size, &xout_c[0], <double *>yout.data, nsteps,
            dx0, dx_min, dx_max, autorestart, return_on_error)
        info = get_last_info(odesys, success=False if return_on_error and nreached < xout.size else True)
        info['nreached'] = nreached
        info['atol'], info['rtol'] = atol, rtol
        return yout.reshape((xout.size, ny)), info
    finally:
        # odesys was allocated with ``new`` above; free it on every exit path.
        del odesys