1# -*- coding: utf-8; mode: cython -*-
2# distutils: language = c++
3# distutils: extra_compile_args = -std=c++11
4
5from cpython.ref cimport PyObject
6from libcpp cimport bool
7cimport numpy as cnp
8cnp.import_array()  # Numpy C-API initialization
9
10import numpy as np
11
12from anyode_numpy cimport PyOdeSys
13from odeint_anyode cimport simple_adaptive, simple_predefined, styp_from_name
14
# Stepper names known to this module (presumably the accepted values of the
# ``method`` argument of ``adaptive``/``predefined``; matching is done on the
# lower-cased name).
steppers = (
    'rosenbrock4',
    'dopri5',
    'bulirsch_stoer',
)
# Steppers that cannot run without an explicit jacobian callback; ``adaptive``
# and ``predefined`` raise ValueError when one of these is requested with
# ``jac=None``.
requires_jac = (
    'rosenbrock4',
)
17
# Concrete AnyODE system type used by every entry point below: PyOdeSys
# instantiated with ``double`` (presumably the real/state type) and ``int``
# (presumably the index type) -- TODO confirm against anyode_numpy.pxd.
ctypedef PyOdeSys[double, int] PyOdeSys_t
19
20
cdef dict get_last_info(PyOdeSys_t * odesys, success=True):
    """Gather the statistics of the last integration run into a Python dict.

    The integer table ``current_info.nfo_int`` is merged first, then the
    floating-point table ``current_info.nfo_dbl`` (a key present in both ends
    up with the floating-point value). Byte-string keys are decoded as UTF-8.
    The counters ``nfev`` and ``njev`` read from ``odesys`` and the
    caller-supplied ``success`` flag are appended last.
    """
    info = {}
    for table in (dict(odesys.current_info.nfo_int), dict(odesys.current_info.nfo_dbl)):
        for raw_key, value in table.items():
            info[str(raw_key.decode('utf-8'))] = value
    info['nfev'] = odesys.nfev
    info['njev'] = odesys.njev
    info['success'] = success
    return info
28
29
def adaptive(rhs, jac, cnp.ndarray[cnp.float64_t] y0, double x0, double xend,
             double atol, double rtol, double dx0=.0, double dx_max=.0, str method='rosenbrock4', int nsteps=500,
             int autorestart=0, bool return_on_error=False, dx0cb=None, dx_max_cb=None):
    """Integrate an ODE system from ``x0`` to ``xend`` with adaptive stepping.

    Parameters
    ----------
    rhs : callable
        Right-hand-side callback, forwarded to ``PyOdeSys``.
    jac : callable or None
        Jacobian callback; mandatory for the steppers in ``requires_jac``.
    y0 : 1-D float64 array
        Initial state; its length is the number of equations ``ny``. Any
        memory layout is accepted (a contiguous array is passed to C++).
    x0, xend : float
        Start and end of the integration interval.
    atol, rtol : float
        Absolute and relative tolerances (also echoed in the returned info).
    dx0, dx_max : float
        Initial / maximum step size, forwarded to ``simple_adaptive``.
    method : str
        Stepper name, matched case-insensitively (see ``steppers``).
    nsteps, autorestart : int
        Forwarded to ``simple_adaptive`` (presumably step budget and number
        of automatic restarts -- confirm against odeint_anyode).
    return_on_error : bool
        Forwarded to ``simple_adaptive``. When true and the last returned x
        differs from ``xend``, the info dict reports ``success=False``.
    dx0cb, dx_max_cb : callable or None
        Optional callbacks forwarded to ``PyOdeSys``.

    Returns
    -------
    xout : ndarray
        x values returned by the integrator.
    yout : ndarray, shape (xout.size, ny)
        State at each ``xout``.
    info : dict
        Output of ``get_last_info`` plus ``atol`` and ``rtol``.

    Raises
    ------
    ValueError
        If the stepper requires a jacobian and ``jac`` is None, or if ``y0``
        contains NaN.
    """
    cdef:
        int ny = y0.shape[y0.ndim - 1]
        PyOdeSys_t * odesys
        int mlower=-1, mupper=-1, nquads=0, nroots=0, nnz=-1
        cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] y0_c

    # Normalise once so the jacobian check and the stepper lookup agree;
    # otherwise e.g. 'Rosenbrock4' would skip the check but still select
    # the jacobian-requiring stepper.
    method_name = method.lower()
    if method_name in requires_jac and jac is None:
        raise ValueError("Method requires explicit jacobian callback")
    if np.isnan(y0).any():
        raise ValueError("NaN found in y0")
    # The C++ side reads ny consecutive doubles starting at &y0[0]; a strided
    # view (e.g. arr[::2]) would be read incorrectly. ascontiguousarray is a
    # no-op for arrays that are already C-contiguous float64.
    y0_c = np.ascontiguousarray(y0, dtype=np.float64)

    odesys = new PyOdeSys_t(ny, <PyObject *>rhs, <PyObject *>jac, NULL, NULL, NULL, NULL,
                            mlower, mupper, nquads, nroots, <PyObject *>dx0cb, <PyObject *>dx_max_cb, nnz)
    try:
        xout, yout = map(np.asarray, simple_adaptive[PyOdeSys_t](
            odesys, atol, rtol, styp_from_name(method_name.encode('UTF-8')),
            &y0_c[0], x0, xend, nsteps, dx0, dx_max, autorestart, return_on_error))
        # Only a run allowed to stop early can end short of xend.
        nfo = get_last_info(odesys, not (return_on_error and xout[-1] != xend))
        nfo['atol'], nfo['rtol'] = atol, rtol
        return xout, yout.reshape(xout.size, ny), nfo
    finally:
        # odesys was allocated with `new`; release it on every exit path.
        del odesys
54
55
def predefined(rhs, jac,
               cnp.ndarray[cnp.float64_t] y0,
               cnp.ndarray[cnp.float64_t, ndim=1] xout,
               double atol, double rtol, double dx0=.0, double dx_max=.0, method='rosenbrock4',
               int nsteps=500, int autorestart=0, bool return_on_error=False, dx0cb=None, dx_max_cb=None):
    """Integrate an ODE system and report the state at the given ``xout`` points.

    Parameters
    ----------
    rhs : callable
        Right-hand-side callback, forwarded to ``PyOdeSys``.
    jac : callable or None
        Jacobian callback; mandatory for the steppers in ``requires_jac``.
    y0 : 1-D float64 array
        Initial state; its length is the number of equations ``ny``.
    xout : 1-D float64 array
        Points at which the solution is requested.
        For both ``y0`` and ``xout``, any memory layout is accepted; contiguous
        arrays are passed to C++.
    atol, rtol : float
        Absolute and relative tolerances (also echoed in the returned info).
    dx0, dx_max : float
        Initial / maximum step size, forwarded to ``simple_predefined``.
    method : str
        Stepper name, matched case-insensitively (see ``steppers``).
    nsteps, autorestart : int
        Forwarded to ``simple_predefined`` (presumably step budget and number
        of automatic restarts -- confirm against odeint_anyode).
    return_on_error : bool
        Forwarded to ``simple_predefined``. When true and fewer than
        ``xout.size`` points were reached, the info dict reports
        ``success=False``.
    dx0cb, dx_max_cb : callable or None
        Optional callbacks forwarded to ``PyOdeSys``.

    Returns
    -------
    yout : ndarray, shape (xout.size, ny)
        Filled in place by ``simple_predefined``. Rows beyond ``nreached``
        are presumably left uninitialised (``np.empty``).
    info : dict
        Output of ``get_last_info`` plus ``nreached``, ``atol`` and ``rtol``.

    Raises
    ------
    ValueError
        If the stepper requires a jacobian and ``jac`` is None, or if ``y0``
        contains NaN.
    """
    cdef:
        int ny = y0.shape[y0.ndim - 1]
        int nreached
        cnp.ndarray[cnp.float64_t, ndim=2] yout
        cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] y0_c
        cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] xout_c
        PyOdeSys_t * odesys
        int mlower=-1, mupper=-1, nquads=0, nroots=0, nnz=-1

    # Normalise once so the jacobian check and the stepper lookup agree.
    method_name = method.lower()
    if method_name in requires_jac and jac is None:
        raise ValueError("Method requires explicit jacobian callback")
    if np.isnan(y0).any():
        raise ValueError("NaN found in y0")
    # Raw pointers &y0[0] / &xout[0] are handed to C++, which assumes packed
    # doubles; strided views must be made contiguous first (no copy otherwise).
    y0_c = np.ascontiguousarray(y0, dtype=np.float64)
    xout_c = np.ascontiguousarray(xout, dtype=np.float64)

    odesys = new PyOdeSys_t(ny, <PyObject *>rhs, <PyObject *>jac, NULL, NULL, NULL, NULL, mlower, mupper,
                            nquads, nroots, <PyObject *>dx0cb, <PyObject *>dx_max_cb, nnz)
    try:
        # np.empty is C-contiguous, so &yout[0, 0] addresses a packed buffer.
        yout = np.empty((xout_c.size, ny))
        nreached = simple_predefined[PyOdeSys_t](odesys, atol, rtol, styp_from_name(method_name.encode('UTF-8')),
                                                 &y0_c[0], xout_c.size, &xout_c[0], &yout[0, 0], nsteps, dx0, dx_max,
                                                 autorestart, return_on_error)
        # Only a run allowed to stop early can reach fewer than all points.
        info = get_last_info(odesys, success=not (return_on_error and nreached < xout_c.size))
        info['nreached'] = nreached
        info['atol'], info['rtol'] = atol, rtol
        return yout, info
    finally:
        # odesys was allocated with `new`; release it on every exit path.
        del odesys
85