1 /*
2  *    This file is part of CasADi.
3  *
4  *    CasADi -- A symbolic framework for dynamic optimization.
5  *    Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6  *                            K.U. Leuven. All rights reserved.
7  *    Copyright (C) 2011-2014 Greg Horn
8  *
9  *    CasADi is free software; you can redistribute it and/or
10  *    modify it under the terms of the GNU Lesser General Public
11  *    License as published by the Free Software Foundation; either
12  *    version 3 of the License, or (at your option) any later version.
13  *
14  *    CasADi is distributed in the hope that it will be useful,
15  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  *    Lesser General Public License for more details.
18  *
19  *    You should have received a copy of the GNU Lesser General Public
20  *    License along with CasADi; if not, write to the Free Software
21  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
22  *
23  */
24 
25 
26 #include "cvodes_interface.hpp"
27 #include "casadi/core/casadi_misc.hpp"
28 
// Call the CVODES routine `fcn` with the given arguments and pass its return
// flag, together with CASADI_STR(fcn), on to cvodes_error().
#define THROWING(fcn, ...) \
  cvodes_error(CASADI_STR(fcn), fcn(__VA_ARGS__))
31 
32 using namespace std;
33 namespace casadi {
34 
35   extern "C"
36   int CASADI_INTEGRATOR_CVODES_EXPORT
casadi_register_integrator_cvodes(Integrator::Plugin * plugin)37   casadi_register_integrator_cvodes(Integrator::Plugin* plugin) {
38     plugin->creator = CvodesInterface::creator;
39     plugin->name = "cvodes";
40     plugin->doc = CvodesInterface::meta_doc.c_str();;
41     plugin->version = CASADI_VERSION;
42     plugin->options = &CvodesInterface::options_;
43     plugin->deserialize = &CvodesInterface::deserialize;
44     return 0;
45   }
46 
47   extern "C"
casadi_load_integrator_cvodes()48   void CASADI_INTEGRATOR_CVODES_EXPORT casadi_load_integrator_cvodes() {
49     Integrator::registerPlugin(casadi_register_integrator_cvodes);
50   }
51 
CvodesInterface(const std::string & name,const Function & dae)52   CvodesInterface::CvodesInterface(const std::string& name, const Function& dae)
53     : SundialsInterface(name, dae) {
54   }
55 
~CvodesInterface()56   CvodesInterface::~CvodesInterface() {
57     clear_mem();
58   }
59 
60   const Options CvodesInterface::options_
61   = {{&SundialsInterface::options_},
62      {{"linear_multistep_method",
63        {OT_STRING,
64         "Integrator scheme: BDF|adams"}},
65       {"nonlinear_solver_iteration",
66        {OT_STRING,
67         "Nonlinear solver type: NEWTON|functional"}},
68       {"min_step_size",
69        {OT_DOUBLE,
70         "Min step size [default: 0/0.0]"}},
71       {"fsens_all_at_once",
72        {OT_BOOL,
73         "Calculate all right hand sides of the sensitivity equations at once"}}
74      }
75   };
76 
init(const Dict & opts)77   void CvodesInterface::init(const Dict& opts) {
78     if (verbose_) casadi_message(name_ + "::init");
79 
80     // Initialize the base classes
81     SundialsInterface::init(opts);
82 
83     // Default options
84     string linear_multistep_method = "bdf";
85     string nonlinear_solver_iteration = "newton";
86     min_step_size_ = 0;
87 
88     // Read options
89     for (auto&& op : opts) {
90       if (op.first=="linear_multistep_method") {
91         linear_multistep_method = op.second.to_string();
92       } else if (op.first=="min_step_size") {
93         min_step_size_ = op.second;
94       } else if (op.first=="nonlinear_solver_iteration") {
95         nonlinear_solver_iteration = op.second.to_string();
96       }
97     }
98 
99     // Create function
100     create_function("odeF", {"x", "p", "t"}, {"ode"});
101     create_function("quadF", {"x", "p", "t"}, {"quad"});
102     create_function("odeB", {"rx", "rp", "x", "p", "t"}, {"rode"});
103     create_function("quadB", {"rx", "rp", "x", "p", "t"}, {"rquad"});
104 
105     // Algebraic variables not supported
106     casadi_assert(nz_==0 && nrz_==0,
107       "CVODES does not support algebraic variables");
108 
109     if (linear_multistep_method=="adams") {
110       lmm_ = CV_ADAMS;
111     } else if (linear_multistep_method=="bdf") {
112       lmm_ = CV_BDF;
113     } else {
114       casadi_error("Unknown linear multistep method: " + linear_multistep_method);
115     }
116 
117     if (nonlinear_solver_iteration=="newton") {
118       iter_ = CV_NEWTON;
119     } else if (nonlinear_solver_iteration=="functional") {
120       iter_ = CV_FUNCTIONAL;
121     } else {
122       casadi_error("Unknown nonlinear solver iteration: " + nonlinear_solver_iteration);
123     }
124 
125     // Attach functions for jacobian information
126     if (newton_scheme_!=SD_DIRECT || (ns_>0 && second_order_correction_)) {
127       create_function("jtimesF", {"t", "x", "p", "fwd:x"}, {"fwd:ode"});
128       if (nrx_>0) {
129         create_function("jtimesB",
130                         {"t", "x", "p", "rx", "rp", "fwd:rx"}, {"fwd:rode"});
131       }
132     }
133   }
134 
init_mem(void * mem) const135   int CvodesInterface::init_mem(void* mem) const {
136     if (SundialsInterface::init_mem(mem)) return 1;
137     auto m = to_mem(mem);
138 
139     // Create CVodes memory block
140     m->mem = CVodeCreate(lmm_, iter_);
141     casadi_assert(m->mem!=nullptr, "CVodeCreate: Creation failed");
142 
143     // Set error handler function
144     THROWING(CVodeSetErrHandlerFn, m->mem, ehfun, m);
145 
146     // Set user data
147     THROWING(CVodeSetUserData, m->mem, m);
148 
149     // Initialize CVodes
150     double t0 = 0;
151     THROWING(CVodeInit, m->mem, rhs, t0, m->xz);
152 
153     // Set tolerances
154     THROWING(CVodeSStolerances, m->mem, reltol_, abstol_);
155 
156     // Maximum number of steps
157     THROWING(CVodeSetMaxNumSteps, m->mem, max_num_steps_);
158 
159     // Initial step size
160     if (step0_!=0) THROWING(CVodeSetInitStep, m->mem, step0_);
161 
162     // Min step size
163     if (min_step_size_!=0) THROWING(CVodeSetMinStep, m->mem, min_step_size_);
164 
165     // Max step size
166     if (max_step_size_!=0) THROWING(CVodeSetMaxStep, m->mem, max_step_size_);
167 
168     // Maximum order of method
169     if (max_order_) THROWING(CVodeSetMaxOrd, m->mem, max_order_);
170 
171     // Coeff. in the nonlinear convergence test
172     if (nonlin_conv_coeff_!=0) THROWING(CVodeSetNonlinConvCoef, m->mem, nonlin_conv_coeff_);
173 
174     // attach a linear solver
175     if (newton_scheme_==SD_DIRECT) {
176       // Direct scheme
177       CVodeMem cv_mem = static_cast<CVodeMem>(m->mem);
178       cv_mem->cv_lmem   = m;
179       cv_mem->cv_lsetup = lsetup;
180       cv_mem->cv_lsolve = lsolve;
181       cv_mem->cv_setupNonNull = TRUE;
182     } else {
183       // Iterative scheme
184       casadi_int pretype = use_precon_ ? PREC_LEFT : PREC_NONE;
185       switch (newton_scheme_) {
186       case SD_DIRECT: casadi_assert_dev(0);
187       case SD_GMRES: THROWING(CVSpgmr, m->mem, pretype, max_krylov_); break;
188       case SD_BCGSTAB: THROWING(CVSpbcg, m->mem, pretype, max_krylov_); break;
189       case SD_TFQMR: THROWING(CVSptfqmr, m->mem, pretype, max_krylov_); break;
190       }
191       THROWING(CVSpilsSetJacTimesVecFn, m->mem, jtimes);
192       if (use_precon_) THROWING(CVSpilsSetPreconditioner, m->mem, psetup, psolve);
193     }
194 
195     // Quadrature equations
196     if (nq_>0) {
197       // Initialize quadratures in CVodes
198       THROWING(CVodeQuadInit, m->mem, rhsQ, m->q);
199 
200       // Should the quadrature errors be used for step size control?
201       if (quad_err_con_) {
202         THROWING(CVodeSetQuadErrCon, m->mem, true);
203 
204         // Quadrature error tolerances
205         // TODO(Joel): vector absolute tolerances
206         THROWING(CVodeQuadSStolerances, m->mem, reltol_, abstol_);
207       }
208     }
209 
210     // Initialize adjoint sensitivities
211     if (nrx_>0) {
212       casadi_int interpType = interp_==SD_HERMITE ? CV_HERMITE : CV_POLYNOMIAL;
213       THROWING(CVodeAdjInit, m->mem, steps_per_checkpoint_, interpType);
214     }
215 
216     m->first_callB = true;
217     return 0;
218   }
219 
rhs(double t,N_Vector x,N_Vector xdot,void * user_data)220   int CvodesInterface::rhs(double t, N_Vector x, N_Vector xdot, void *user_data) {
221     try {
222       casadi_assert_dev(user_data);
223       auto m = to_mem(user_data);
224       auto& s = m->self;
225       m->arg[0] = NV_DATA_S(x);
226       m->arg[1] = m->p;
227       m->arg[2] = &t;
228       m->res[0] = NV_DATA_S(xdot);
229       s.calc_function(m, "odeF");
230       return 0;
231     } catch(int flag) { // recoverable error
232       return flag;
233     } catch(exception& e) { // non-recoverable error
234       uerr() << "rhs failed: " << e.what() << endl;
235       return -1;
236     }
237   }
238 
reset(IntegratorMemory * mem,double t,const double * x,const double * z,const double * _p) const239   void CvodesInterface::reset(IntegratorMemory* mem, double t, const double* x,
240                               const double* z, const double* _p) const {
241     if (verbose_) casadi_message(name_ + "::reset");
242     auto m = to_mem(mem);
243 
244     // Reset the base classes
245     SundialsInterface::reset(mem, t, x, z, _p);
246 
247     // Re-initialize
248     THROWING(CVodeReInit, m->mem, t, m->xz);
249 
250     // Re-initialize quadratures
251     if (nq_>0) {
252       N_VConst(0.0, m->q);
253       THROWING(CVodeQuadReInit, m->mem, m->q);
254     }
255 
256     // Re-initialize backward integration
257     if (nrx_>0) {
258       THROWING(CVodeAdjReInit, m->mem);
259     }
260 
261     // Set the stop time of the integration -- don't integrate past this point
262     if (stop_at_end_) setStopTime(m, grid_.back());
263   }
264 
advance(IntegratorMemory * mem,double t,double * x,double * z,double * q) const265   void CvodesInterface::advance(IntegratorMemory* mem, double t, double* x,
266                                 double* z, double* q) const {
267     auto m = to_mem(mem);
268 
269     casadi_assert(t>=grid_.front(),
270       "CvodesInterface::integrate(" + str(t) + "): "
271       "Cannot integrate to a time earlier than t0 (" + str(grid_.front()) + ")");
272     casadi_assert(t<=grid_.back() || !stop_at_end_,
273       "CvodesInterface::integrate(" + str(t) + "): "
274       "Cannot integrate past a time later than tf (" + str(grid_.back()) + ") "
275       "unless stop_at_end is set to False.");
276 
277     // Integrate, unless already at desired time
278     const double ttol = 1e-9;
279     if (fabs(m->t-t)>=ttol) {
280       // Integrate forward ...
281       if (nrx_>0) {
282         // ... with taping
283         THROWING(CVodeF, m->mem, t, m->xz, &m->t, CV_NORMAL, &m->ncheck);
284       } else {
285         // ... without taping
286         THROWING(CVode, m->mem, t, m->xz, &m->t, CV_NORMAL);
287       }
288 
289       // Get quadratures
290       if (nq_>0) {
291         double tret;
292         THROWING(CVodeGetQuad, m->mem, &tret, m->q);
293       }
294     }
295 
296     // Set function outputs
297     casadi_copy(NV_DATA_S(m->xz), nx_, x);
298     casadi_copy(NV_DATA_S(m->q), nq_, q);
299 
300     // Get stats
301     THROWING(CVodeGetIntegratorStats, m->mem, &m->nsteps, &m->nfevals, &m->nlinsetups,
302              &m->netfails, &m->qlast, &m->qcur, &m->hinused,
303              &m->hlast, &m->hcur, &m->tcur);
304     THROWING(CVodeGetNonlinSolvStats, m->mem, &m->nniters, &m->nncfails);
305   }
306 
resetB(IntegratorMemory * mem,double t,const double * rx,const double * rz,const double * rp) const307   void CvodesInterface::resetB(IntegratorMemory* mem, double t, const double* rx,
308                                const double* rz, const double* rp) const {
309     auto m = to_mem(mem);
310 
311     // Reset the base classes
312     SundialsInterface::resetB(mem, t, rx, rz, rp);
313 
314     if (m->first_callB) {
315       // Create backward problem
316       THROWING(CVodeCreateB, m->mem, lmm_, iter_, &m->whichB);
317       THROWING(CVodeInitB, m->mem, m->whichB, rhsB, grid_.back(), m->rxz);
318       THROWING(CVodeSStolerancesB, m->mem, m->whichB, reltol_, abstol_);
319       THROWING(CVodeSetUserDataB, m->mem, m->whichB, m);
320       if (newton_scheme_==SD_DIRECT) {
321         // Direct scheme
322         CVodeMem cv_mem = static_cast<CVodeMem>(m->mem);
323         CVadjMem ca_mem = cv_mem->cv_adj_mem;
324         CVodeBMem cvB_mem = ca_mem->cvB_mem;
325         cvB_mem->cv_lmem   = m;
326         cvB_mem->cv_mem->cv_lmem = m;
327         cvB_mem->cv_mem->cv_lsetup = lsetupB;
328         cvB_mem->cv_mem->cv_lsolve = lsolveB;
329         cvB_mem->cv_mem->cv_setupNonNull = TRUE;
330       } else {
331         // Iterative scheme
332         casadi_int pretype = use_precon_ ? PREC_LEFT : PREC_NONE;
333         switch (newton_scheme_) {
334         case SD_DIRECT: casadi_assert_dev(0);
335         case SD_GMRES: THROWING(CVSpgmrB, m->mem, m->whichB, pretype, max_krylov_); break;
336         case SD_BCGSTAB: THROWING(CVSpbcgB, m->mem, m->whichB, pretype, max_krylov_); break;
337         case SD_TFQMR: THROWING(CVSptfqmrB, m->mem, m->whichB, pretype, max_krylov_); break;
338         }
339         THROWING(CVSpilsSetJacTimesVecFnB, m->mem, m->whichB, jtimesB);
340         if (use_precon_) THROWING(CVSpilsSetPreconditionerB, m->mem, m->whichB, psetupB, psolveB);
341       }
342 
343       // Quadratures for the backward problem
344       THROWING(CVodeQuadInitB, m->mem, m->whichB, rhsQB, m->rq);
345       if (quad_err_con_) {
346         THROWING(CVodeSetQuadErrConB, m->mem, m->whichB, true);
347         THROWING(CVodeQuadSStolerancesB, m->mem, m->whichB, reltol_, abstol_);
348       }
349 
350       // Mark initialized
351       m->first_callB = false;
352     } else {
353       THROWING(CVodeReInitB, m->mem, m->whichB, grid_.back(), m->rxz);
354       THROWING(CVodeQuadReInitB, m->mem, m->whichB, m->rq);
355     }
356   }
357 
retreat(IntegratorMemory * mem,double t,double * rx,double * rz,double * rq) const358   void CvodesInterface::retreat(IntegratorMemory* mem, double t,
359                                 double* rx, double* rz, double* rq) const {
360     auto m = to_mem(mem);
361     // Integrate, unless already at desired time
362     if (t<m->t) {
363       THROWING(CVodeB, m->mem, t, CV_NORMAL);
364       THROWING(CVodeGetB, m->mem, m->whichB, &m->t, m->rxz);
365       if (nrq_>0) {
366         THROWING(CVodeGetQuadB, m->mem, m->whichB, &m->t, m->rq);
367       }
368     }
369 
370     // Save outputs
371     casadi_copy(NV_DATA_S(m->rxz), nrx_, rx);
372     casadi_copy(NV_DATA_S(m->rq), nrq_, rq);
373 
374     // Get stats
375     CVodeMem cv_mem = static_cast<CVodeMem>(m->mem);
376     CVadjMem ca_mem = cv_mem->cv_adj_mem;
377     CVodeBMem cvB_mem = ca_mem->cvB_mem;
378     THROWING(CVodeGetIntegratorStats, cvB_mem->cv_mem, &m->nstepsB,
379            &m->nfevalsB, &m->nlinsetupsB, &m->netfailsB, &m->qlastB,
380            &m->qcurB, &m->hinusedB, &m->hlastB, &m->hcurB, &m->tcurB);
381     THROWING(CVodeGetNonlinSolvStats, cvB_mem->cv_mem, &m->nnitersB, &m->nncfailsB);
382   }
383 
cvodes_error(const char * module,int flag)384   void CvodesInterface::cvodes_error(const char* module, int flag) {
385     // Successfull return or warning
386     if (flag>=CV_SUCCESS) return;
387     // Construct error message
388     char* flagname = CVodeGetReturnFlagName(flag);
389     stringstream ss;
390     ss << module << " returned \"" << flagname << "\". Consult CVODES documentation.";
391     free(flagname); // NOLINT
392     casadi_error(ss.str());
393   }
394 
ehfun(int error_code,const char * module,const char * function,char * msg,void * user_data)395   void CvodesInterface::ehfun(int error_code, const char *module, const char *function,
396                               char *msg, void *user_data) {
397     try {
398       casadi_assert_dev(user_data);
399       auto m = to_mem(user_data);
400       auto& s = m->self;
401       if (!s.disable_internal_warnings_) {
402         uerr() << msg << endl;
403       }
404     } catch(exception& e) {
405       uerr() << "ehfun failed: " << e.what() << endl;
406     }
407   }
408 
rhsQ(double t,N_Vector x,N_Vector qdot,void * user_data)409   int CvodesInterface::rhsQ(double t, N_Vector x, N_Vector qdot, void *user_data) {
410     try {
411       auto m = to_mem(user_data);
412       auto& s = m->self;
413       m->arg[0] = NV_DATA_S(x);
414       m->arg[1] = m->p;
415       m->arg[2] = &t;
416       m->res[0] = NV_DATA_S(qdot);
417       s.calc_function(m, "quadF");
418       return 0;
419     } catch(int flag) { // recoverable error
420       return flag;
421     } catch(exception& e) { // non-recoverable error
422       uerr() << "rhsQ failed: " << e.what() << endl;
423       return -1;
424     }
425   }
426 
rhsB(double t,N_Vector x,N_Vector rx,N_Vector rxdot,void * user_data)427   int CvodesInterface::rhsB(double t, N_Vector x, N_Vector rx, N_Vector rxdot,
428                             void *user_data) {
429     try {
430       casadi_assert_dev(user_data);
431       auto m = to_mem(user_data);
432       auto& s = m->self;
433       m->arg[0] = NV_DATA_S(rx);
434       m->arg[1] = m->rp;
435       m->arg[2] = NV_DATA_S(x);
436       m->arg[3] = m->p;
437       m->arg[4] = &t;
438       m->res[0] = NV_DATA_S(rxdot);
439       s.calc_function(m, "odeB");
440 
441       // Negate (note definition of g)
442       casadi_scal(s.nrx_, -1., NV_DATA_S(rxdot));
443 
444       return 0;
445     } catch(int flag) { // recoverable error
446       return flag;
447     } catch(exception& e) { // non-recoverable error
448       uerr() << "rhsB failed: " << e.what() << endl;
449       return -1;
450     }
451   }
452 
rhsQB(double t,N_Vector x,N_Vector rx,N_Vector rqdot,void * user_data)453   int CvodesInterface::rhsQB(double t, N_Vector x, N_Vector rx,
454                              N_Vector rqdot, void *user_data) {
455     try {
456       casadi_assert_dev(user_data);
457       auto m = to_mem(user_data);
458       auto& s = m->self;
459       m->arg[0] = NV_DATA_S(rx);
460       m->arg[1] = m->rp;
461       m->arg[2] = NV_DATA_S(x);
462       m->arg[3] = m->p;
463       m->arg[4] = &t;
464       m->res[0] = NV_DATA_S(rqdot);
465       s.calc_function(m, "quadB");
466 
467       // Negate (note definition of g)
468       casadi_scal(s.nrq_, -1., NV_DATA_S(rqdot));
469 
470       return 0;
471     } catch(int flag) { // recoverable error
472       return flag;
473     } catch(exception& e) { // non-recoverable error
474       uerr() << "rhsQB failed: " << e.what() << endl;
475       return -1;
476     }
477   }
478 
jtimes(N_Vector v,N_Vector Jv,double t,N_Vector x,N_Vector xdot,void * user_data,N_Vector tmp)479   int CvodesInterface::jtimes(N_Vector v, N_Vector Jv, double t, N_Vector x,
480                               N_Vector xdot, void *user_data, N_Vector tmp) {
481     try {
482       auto m = to_mem(user_data);
483       auto& s = m->self;
484       m->arg[0] = &t;
485       m->arg[1] = NV_DATA_S(x);
486       m->arg[2] = m->p;
487       m->arg[3] = NV_DATA_S(v);
488       m->res[0] = NV_DATA_S(Jv);
489       s.calc_function(m, "jtimesF");
490       return 0;
491     } catch(casadi_int flag) { // recoverable error
492       return flag;
493     } catch(exception& e) { // non-recoverable error
494       uerr() << "jtimes failed: " << e.what() << endl;
495       return -1;
496     }
497   }
498 
jtimesB(N_Vector v,N_Vector Jv,double t,N_Vector x,N_Vector rx,N_Vector rxdot,void * user_data,N_Vector tmpB)499   int CvodesInterface::jtimesB(N_Vector v, N_Vector Jv, double t, N_Vector x,
500                                N_Vector rx, N_Vector rxdot, void *user_data ,
501                                N_Vector tmpB) {
502     try {
503       auto m = to_mem(user_data);
504       auto& s = m->self;
505       m->arg[0] = &t;
506       m->arg[1] = NV_DATA_S(x);
507       m->arg[2] = m->p;
508       m->arg[3] = NV_DATA_S(rx);
509       m->arg[4] = m->rp;
510       m->arg[5] = NV_DATA_S(v);
511       m->res[0] = NV_DATA_S(Jv);
512       s.calc_function(m, "jtimesB");
513       return 0;
514     } catch(int flag) { // recoverable error
515       return flag;
516     } catch(exception& e) { // non-recoverable error
517       uerr() << "jtimes failed: " << e.what() << endl;
518       return -1;
519     }
520   }
521 
setStopTime(IntegratorMemory * mem,double tf) const522   void CvodesInterface::setStopTime(IntegratorMemory* mem, double tf) const {
523     // Set the stop time of the integration -- don't integrate past this point
524     auto m = to_mem(mem);
525     THROWING(CVodeSetStopTime, m->mem, tf);
526   }
527 
psolve(double t,N_Vector x,N_Vector xdot,N_Vector r,N_Vector z,double gamma,double delta,int lr,void * user_data,N_Vector tmp)528   int CvodesInterface::psolve(double t, N_Vector x, N_Vector xdot, N_Vector r,
529                               N_Vector z, double gamma, double delta, int lr,
530                               void *user_data, N_Vector tmp) {
531     try {
532       auto m = to_mem(user_data);
533       auto& s = m->self;
534 
535       // Get right-hand sides in m->v1
536       double* v = NV_DATA_S(r);
537       casadi_copy(v, s.nx_, m->v1);
538 
539       // Solve for undifferentiated right-hand-side, save to output
540       if (s.linsolF_.solve(m->jac, m->v1, 1, false, m->mem_linsolF))
541         casadi_error("Linear system solve failed");
542       v = NV_DATA_S(z); // possibly different from r
543       casadi_copy(m->v1, s.nx1_, v);
544 
545       // Sensitivity equations
546       if (s.ns_>0) {
547         // Second order correction
548         if (s.second_order_correction_) {
549           // The outputs will double as seeds for jtimesF
550           casadi_clear(v + s.nx1_, s.nx_ - s.nx1_);
551           m->arg[0] = &t; // t
552           m->arg[1] = NV_DATA_S(x); // x
553           m->arg[2] = m->p; // p
554           m->arg[3] = v; // fwd:x
555           m->res[0] = m->v2; // fwd:ode
556           s.calc_function(m, "jtimesF");
557 
558           // Subtract m->v2 from m->v1, scaled with -gamma
559           casadi_axpy(s.nx_ - s.nx1_, m->gamma, m->v2 + s.nx1_, m->v1 + s.nx1_);
560         }
561 
562         // Solve for sensitivity right-hand-sides
563         if (s.linsolF_.solve(m->jac, m->v1 + s.nx1_, s.ns_, false, m->mem_linsolF))
564           casadi_error("Linear solve failed");
565 
566         // Save to output, reordered
567         casadi_copy(m->v1 + s.nx1_, s.nx_-s.nx1_, v+s.nx1_);
568       }
569 
570       return 0;
571     } catch(int flag) { // recoverable error
572       return flag;
573     } catch(exception& e) { // non-recoverable error
574       uerr() << "psolve failed: " << e.what() << endl;
575       return -1;
576     }
577   }
578 
psolveB(double t,N_Vector x,N_Vector xB,N_Vector xdotB,N_Vector rvecB,N_Vector zvecB,double gammaB,double deltaB,int lr,void * user_data,N_Vector tmpB)579   int CvodesInterface::psolveB(double t, N_Vector x, N_Vector xB, N_Vector xdotB,
580                                N_Vector rvecB, N_Vector zvecB, double gammaB,
581                                double deltaB, int lr, void *user_data, N_Vector tmpB) {
582     try {
583       auto m = to_mem(user_data);
584       auto& s = m->self;
585 
586       // Get right-hand sides in m->v1
587       double* v = NV_DATA_S(rvecB);
588       casadi_copy(v, s.nrx_, m->v1);
589 
590       // Solve for undifferentiated right-hand-side, save to output
591       if (s.linsolB_.solve(m->jacB, m->v1, 1, false, m->mem_linsolB))
592         casadi_error("Linear solve failed");
593       v = NV_DATA_S(zvecB); // possibly different from rvecB
594       casadi_copy(m->v1, s.nrx1_, v);
595 
596       // Sensitivity equations
597       if (s.ns_>0) {
598         // Second order correction
599         if (s.second_order_correction_) {
600           // The outputs will double as seeds for jtimesF
601           casadi_clear(v + s.nrx1_, s.nrx_ - s.nrx1_);
602           m->arg[0] = &t; // t
603           m->arg[1] = NV_DATA_S(x); // x
604           m->arg[2] = m->p; // p
605           m->arg[3] = NV_DATA_S(xB); // rx
606           m->arg[4] = m->rp; // rp
607           m->arg[5] = v; // fwd:rx
608           m->res[0] = m->v2; // fwd:rode
609           s.calc_function(m, "jtimesB");
610 
611           // Subtract m->v2 from m->v1, scaled with gammaB
612           casadi_axpy(s.nrx_-s.nrx1_, -m->gammaB, m->v2 + s.nrx1_, m->v1 + s.nrx1_);
613         }
614 
615         // Solve for sensitivity right-hand-sides
616         if (s.linsolB_.solve(m->jacB, m->v1 + s.nx1_, s.ns_, false, m->mem_linsolB)) {
617           casadi_error("Linear solve failed");
618         }
619 
620         // Save to output, reordered
621         casadi_copy(m->v1 + s.nx1_, s.nx_-s.nx1_, v+s.nx1_);
622       }
623 
624       return 0;
625     } catch(int flag) { // recoverable error
626       return flag;
627     } catch(exception& e) { // non-recoverable error
628       uerr() << "psolveB failed: " << e.what() << endl;
629       return -1;
630     }
631   }
632 
psetup(double t,N_Vector x,N_Vector xdot,booleantype jok,booleantype * jcurPtr,double gamma,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)633   int CvodesInterface::psetup(double t, N_Vector x, N_Vector xdot, booleantype jok,
634                               booleantype *jcurPtr, double gamma, void *user_data,
635                               N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) {
636     try {
637       auto m = to_mem(user_data);
638       auto& s = m->self;
639       // Store gamma for later
640       m->gamma = gamma;
641 
642       // Calculate Jacobian
643       double d1 = -gamma, d2 = 1.;
644       m->arg[0] = &t;
645       m->arg[1] = NV_DATA_S(x);
646       m->arg[2] = m->p;
647       m->arg[3] = &d1;
648       m->arg[4] = &d2;
649       m->res[0] = m->jac;
650       if (s.calc_function(m, "jacF")) casadi_error("'jacF' calculation failed");
651 
652       // Prepare the solution of the linear system (e.g. factorize)
653       if (s.linsolF_.nfact(m->jac, m->mem_linsolF)) casadi_error("'jacF' factorization failed");
654 
655       return 0;
656     } catch(int flag) { // recoverable error
657       return flag;
658     } catch(exception& e) { // non-recoverable error
659       uerr() << "psetup failed: " << e.what() << endl;
660       return -1;
661     }
662   }
663 
psetupB(double t,N_Vector x,N_Vector rx,N_Vector rxdot,booleantype jokB,booleantype * jcurPtrB,double gammaB,void * user_data,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)664   int CvodesInterface::psetupB(double t, N_Vector x, N_Vector rx, N_Vector rxdot,
665                                booleantype jokB, booleantype *jcurPtrB, double gammaB,
666                                void *user_data, N_Vector tmp1B, N_Vector tmp2B,
667                                N_Vector tmp3B) {
668     try {
669       auto m = to_mem(user_data);
670       auto& s = m->self;
671       // Store gamma for later
672       m->gammaB = gammaB;
673       // Calculate Jacobian
674       double one=1;
675       m->arg[0] = &t;
676       m->arg[1] = NV_DATA_S(rx);
677       m->arg[2] = m->rp;
678       m->arg[3] = NV_DATA_S(x);
679       m->arg[4] = m->p;
680       m->arg[5] = &gammaB;
681       m->arg[6] = &one;
682       m->res[0] = m->jacB;
683       if (s.calc_function(m, "jacB")) casadi_error("'jacB' calculation failed");
684 
685       // Prepare the solution of the linear system (e.g. factorize)
686       if (s.linsolB_.nfact(m->jacB, m->mem_linsolB)) casadi_error("'jacB' factorization failed");
687 
688       return 0;
689     } catch(int flag) { // recoverable error
690       return flag;
691     } catch(exception& e) { // non-recoverable error
692       uerr() << "psetupB failed: " << e.what() << endl;
693       return -1;
694     }
695   }
696 
lsetup(CVodeMem cv_mem,int convfail,N_Vector x,N_Vector xdot,booleantype * jcurPtr,N_Vector vtemp1,N_Vector vtemp2,N_Vector vtemp3)697   int CvodesInterface::lsetup(CVodeMem cv_mem, int convfail, N_Vector x, N_Vector xdot,
698                               booleantype *jcurPtr,
699                               N_Vector vtemp1, N_Vector vtemp2, N_Vector vtemp3) {
700     try {
701       auto m = to_mem(cv_mem->cv_lmem);
702       //auto& s = m->self;
703 
704       // Current time
705       double t = cv_mem->cv_tn;
706 
707       // Scaling factor before J
708       double gamma = cv_mem->cv_gamma;
709 
710       // Call the preconditioner setup function (which sets up the linear solver)
711       if (psetup(t, x, xdot, FALSE, jcurPtr,
712                  gamma, static_cast<void*>(m), vtemp1, vtemp2, vtemp3)) return 1;
713 
714       return 0;
715     } catch(int flag) { // recoverable error
716       return flag;
717     } catch(exception& e) { // non-recoverable error
718       uerr() << "lsetup failed: " << e.what() << endl;
719       return -1;
720     }
721   }
722 
lsetupB(CVodeMem cv_mem,int convfail,N_Vector x,N_Vector xdot,booleantype * jcurPtr,N_Vector vtemp1,N_Vector vtemp2,N_Vector vtemp3)723   int CvodesInterface::lsetupB(CVodeMem cv_mem, int convfail, N_Vector x, N_Vector xdot,
724                                booleantype *jcurPtr,
725                                N_Vector vtemp1, N_Vector vtemp2, N_Vector vtemp3) {
726     try {
727       auto m = to_mem(cv_mem->cv_lmem);
728       CVadjMem ca_mem;
729       //CVodeBMem cvB_mem;
730 
731       int flag;
732 
733       // Current time
734       double t = cv_mem->cv_tn; // TODO(Joel): is this correct?
735       double gamma = cv_mem->cv_gamma;
736 
737       cv_mem = static_cast<CVodeMem>(cv_mem->cv_user_data);
738 
739       ca_mem = cv_mem->cv_adj_mem;
740       //cvB_mem = ca_mem->ca_bckpbCrt;
741 
742       // Get FORWARD solution from interpolation.
743       flag = ca_mem->ca_IMget(cv_mem, t, ca_mem->ca_ytmp, nullptr);
744       if (flag != CV_SUCCESS) casadi_error("Could not interpolate forward states");
745 
746       // Call the preconditioner setup function (which sets up the linear solver)
747       if (psetupB(t, ca_mem->ca_ytmp, x, xdot, FALSE, jcurPtr,
748                   gamma, static_cast<void*>(m), vtemp1, vtemp2, vtemp3)) return 1;
749 
750       return 0;
751     } catch(int flag) { // recoverable error
752       return flag;
753     } catch(exception& e) { // non-recoverable error
754       uerr() << "lsetupB failed: " << e.what() << endl;
755       return -1;
756     }
757   }
758 
lsolve(CVodeMem cv_mem,N_Vector b,N_Vector weight,N_Vector x,N_Vector xdot)759   int CvodesInterface::lsolve(CVodeMem cv_mem, N_Vector b, N_Vector weight,
760                               N_Vector x, N_Vector xdot) {
761     try {
762       auto m = to_mem(cv_mem->cv_lmem);
763       //auto& s = m->self;
764 
765       // Current time
766       double t = cv_mem->cv_tn;
767 
768       // Scaling factor before J
769       double gamma = cv_mem->cv_gamma;
770 
771       // Accuracy
772       double delta = 0.0;
773 
774       // Left/right preconditioner
775       casadi_int lr = 1;
776 
777       // Call the preconditioner solve function (which solves the linear system)
778       if (psolve(t, x, xdot, b, b, gamma, delta,
779                  lr, static_cast<void*>(m), nullptr)) return 1;
780 
781       return 0;
782     } catch(int flag) { // recoverable error
783       return flag;
784     } catch(exception& e) { // non-recoverable error
785       uerr() << "lsolve failed: " << e.what() << endl;
786       return -1;
787     }
788   }
789 
lsolveB(CVodeMem cv_mem,N_Vector b,N_Vector weight,N_Vector x,N_Vector xdot)790   int CvodesInterface::lsolveB(CVodeMem cv_mem, N_Vector b, N_Vector weight,
791                                N_Vector x, N_Vector xdot) {
792     try {
793       auto m = to_mem(cv_mem->cv_lmem);
794       CVadjMem ca_mem;
795       //CVodeBMem cvB_mem;
796 
797       int flag;
798 
799       // Current time
800       double t = cv_mem->cv_tn; // TODO(Joel): is this correct?
801       double gamma = cv_mem->cv_gamma;
802 
803       cv_mem = static_cast<CVodeMem>(cv_mem->cv_user_data);
804 
805       ca_mem = cv_mem->cv_adj_mem;
806       //cvB_mem = ca_mem->ca_bckpbCrt;
807 
808       // Get FORWARD solution from interpolation.
809       flag = ca_mem->ca_IMget(cv_mem, t, ca_mem->ca_ytmp, nullptr);
810       if (flag != CV_SUCCESS) casadi_error("Could not interpolate forward states");
811 
812 
813 
814       // Accuracy
815       double delta = 0.0;
816 
817       // Left/right preconditioner
818       int lr = 1;
819 
820       // Call the preconditioner solve function (which solves the linear system)
821       if (psolveB(t, ca_mem->ca_ytmp, x, xdot, b, b, gamma, delta, lr,
822                   static_cast<void*>(m), nullptr)) return 1;
823 
824       return 0;
825     } catch(int flag) { // recoverable error
826       return flag;
827     } catch(exception& e) { // non-recoverable error
828       uerr() << "lsolveB failed: " << e.what() << endl;
829       return -1;
830     }
831   }
832 
getJ(bool b) const833   Function CvodesInterface::getJ(bool b) const {
834     return oracle_.is_a("SXFunction") ? getJ<SX>(b) : getJ<MX>(b);
835   }
836 
837   template<typename MatType>
getJ(bool backward) const838   Function CvodesInterface::getJ(bool backward) const {
839     vector<MatType> a = MatType::get_input(oracle_);
840     vector<MatType> r = const_cast<Function&>(oracle_)(a); // NOLINT
841     MatType c_x = MatType::sym("c_x");
842     MatType c_xdot = MatType::sym("c_xdot");
843 
844     // Get the Jacobian in the Newton iteration
845     if (backward) {
846       MatType jac = c_x*MatType::jacobian(r[DE_RODE], a[DE_RX])
847                   + c_xdot*MatType::eye(nrx_);
848       return Function("jacB",
849                       {a[DE_T], a[DE_RX], a[DE_RP],
850                        a[DE_X], a[DE_P], c_x, c_xdot}, {jac});
851      } else {
852       MatType jac = c_x*MatType::jacobian(r[DE_ODE], a[DE_X])
853                   + c_xdot*MatType::eye(nx_);
854       return Function("jacF", {a[DE_T], a[DE_X], a[DE_P], c_x, c_xdot}, {jac});
855     }
856   }
857 
CvodesMemory(const CvodesInterface & s)858   CvodesMemory::CvodesMemory(const CvodesInterface& s) : self(s) {
859     this->mem = nullptr;
860 
861     // Reset checkpoints counter
862     this->ncheck = 0;
863   }
864 
~CvodesMemory()865   CvodesMemory::~CvodesMemory() {
866     if (this->mem) CVodeFree(&this->mem);
867   }
868 
CvodesInterface(DeserializingStream & s)869   CvodesInterface::CvodesInterface(DeserializingStream& s) : SundialsInterface(s) {
870     int version = s.version("CvodesInterface", 1, 2);
871     s.unpack("CvodesInterface::lmm", lmm_);
872     s.unpack("CvodesInterface::iter", iter_);
873 
874     if (version>=2) {
875       s.unpack("CvodesInterface::min_step_size", min_step_size_);
876     } else {
877       min_step_size_ = 0;
878     }
879   }
880 
serialize_body(SerializingStream & s) const881   void CvodesInterface::serialize_body(SerializingStream &s) const {
882     SundialsInterface::serialize_body(s);
883     s.version("CvodesInterface", 2);
884 
885     s.pack("CvodesInterface::lmm", lmm_);
886     s.pack("CvodesInterface::iter", iter_);
887     s.pack("CvodesInterface::min_step_size", min_step_size_);
888 
889   }
890 
891 } // namespace casadi
892