1 /* 2 * This file is part of CasADi. 3 * 4 * CasADi -- A symbolic framework for dynamic optimization. 5 * Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl, 6 * K.U. Leuven. All rights reserved. 7 * Copyright (C) 2011-2014 Greg Horn 8 * 9 * CasADi is free software; you can redistribute it and/or 10 * modify it under the terms of the GNU Lesser General Public 11 * License as published by the Free Software Foundation; either 12 * version 3 of the License, or (at your option) any later version. 13 * 14 * CasADi is distributed in the hope that it will be useful, 15 * but WITHOUT ANY WARRANTY; without even the implied warranty of 16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 17 * Lesser General Public License for more details. 18 * 19 * You should have received a copy of the GNU Lesser General Public 20 * License along with CasADi; if not, write to the Free Software 21 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 22 * 23 */ 24 25 26 #include "cvodes_interface.hpp" 27 #include "casadi/core/casadi_misc.hpp" 28 29 #define THROWING(fcn, ...) \ 30 cvodes_error(CASADI_STR(fcn), fcn(__VA_ARGS__)) 31 32 using namespace std; 33 namespace casadi { 34 35 extern "C" 36 int CASADI_INTEGRATOR_CVODES_EXPORT casadi_register_integrator_cvodes(Integrator::Plugin * plugin)37 casadi_register_integrator_cvodes(Integrator::Plugin* plugin) { 38 plugin->creator = CvodesInterface::creator; 39 plugin->name = "cvodes"; 40 plugin->doc = CvodesInterface::meta_doc.c_str();; 41 plugin->version = CASADI_VERSION; 42 plugin->options = &CvodesInterface::options_; 43 plugin->deserialize = &CvodesInterface::deserialize; 44 return 0; 45 } 46 47 extern "C" casadi_load_integrator_cvodes()48 void CASADI_INTEGRATOR_CVODES_EXPORT casadi_load_integrator_cvodes() { 49 Integrator::registerPlugin(casadi_register_integrator_cvodes); 50 } 51 CvodesInterface(const std::string & name,const Function & dae)52 CvodesInterface::CvodesInterface(const std::string& name, const Function& dae) 53 : SundialsInterface(name, dae) { 54 } 55 ~CvodesInterface()56 CvodesInterface::~CvodesInterface() { 57 clear_mem(); 58 } 59 60 const Options CvodesInterface::options_ 61 = {{&SundialsInterface::options_}, 62 {{"linear_multistep_method", 63 {OT_STRING, 64 "Integrator scheme: BDF|adams"}}, 65 {"nonlinear_solver_iteration", 66 {OT_STRING, 67 "Nonlinear solver type: NEWTON|functional"}}, 68 {"min_step_size", 69 {OT_DOUBLE, 70 "Min step size [default: 0/0.0]"}}, 71 {"fsens_all_at_once", 72 {OT_BOOL, 73 "Calculate all right hand sides of the sensitivity equations at once"}} 74 } 75 }; 76 init(const Dict & opts)77 void CvodesInterface::init(const Dict& opts) { 78 if (verbose_) casadi_message(name_ + "::init"); 79 80 // Initialize the base classes 81 SundialsInterface::init(opts); 82 83 // Default options 84 string linear_multistep_method = "bdf"; 85 string nonlinear_solver_iteration = "newton"; 86 min_step_size_ = 0; 87 88 // Read options 89 for (auto&& op : opts) { 90 if (op.first=="linear_multistep_method") { 91 linear_multistep_method = op.second.to_string(); 92 } else if (op.first=="min_step_size") { 93 min_step_size_ = op.second; 94 } else if (op.first=="nonlinear_solver_iteration") { 95 nonlinear_solver_iteration = op.second.to_string(); 96 } 97 } 98 99 // Create function 100 create_function("odeF", {"x", "p", "t"}, {"ode"}); 101 create_function("quadF", {"x", "p", "t"}, {"quad"}); 102 create_function("odeB", {"rx", "rp", "x", "p", "t"}, {"rode"}); 103 create_function("quadB", {"rx", "rp", "x", "p", "t"}, {"rquad"}); 104 105 // Algebraic variables not supported 106 casadi_assert(nz_==0 && nrz_==0, 107 "CVODES does not support algebraic variables"); 108 109 if (linear_multistep_method=="adams") { 110 lmm_ = CV_ADAMS; 111 } else if (linear_multistep_method=="bdf") { 112 lmm_ = CV_BDF; 113 } else { 114 casadi_error("Unknown linear multistep method: " + linear_multistep_method); 115 } 116 117 if (nonlinear_solver_iteration=="newton") { 118 iter_ = CV_NEWTON; 119 } else if (nonlinear_solver_iteration=="functional") { 120 iter_ = CV_FUNCTIONAL; 121 } else { 122 casadi_error("Unknown nonlinear solver iteration: " + nonlinear_solver_iteration); 123 } 124 125 // Attach functions for jacobian information 126 if (newton_scheme_!=SD_DIRECT || (ns_>0 && second_order_correction_)) { 127 create_function("jtimesF", {"t", "x", "p", "fwd:x"}, {"fwd:ode"}); 128 if (nrx_>0) { 129 create_function("jtimesB", 130 {"t", "x", "p", "rx", "rp", "fwd:rx"}, {"fwd:rode"}); 131 } 132 } 133 } 134 init_mem(void * mem) const135 int CvodesInterface::init_mem(void* mem) const { 136 if (SundialsInterface::init_mem(mem)) return 1; 137 auto m = to_mem(mem); 138 139 // Create CVodes memory block 140 m->mem = CVodeCreate(lmm_, iter_); 141 casadi_assert(m->mem!=nullptr, "CVodeCreate: Creation failed"); 142 143 // Set error handler function 144 THROWING(CVodeSetErrHandlerFn, m->mem, ehfun, m); 145 146 // Set user data 147 THROWING(CVodeSetUserData, m->mem, m); 148 149 // Initialize CVodes 150 double t0 = 0; 151 THROWING(CVodeInit, m->mem, rhs, t0, m->xz); 152 153 // Set tolerances 154 THROWING(CVodeSStolerances, m->mem, reltol_, abstol_); 155 156 // Maximum number of steps 157 THROWING(CVodeSetMaxNumSteps, m->mem, max_num_steps_); 158 159 // Initial step size 160 if (step0_!=0) THROWING(CVodeSetInitStep, m->mem, step0_); 161 162 // Min step size 163 if (min_step_size_!=0) THROWING(CVodeSetMinStep, m->mem, min_step_size_); 164 165 // Max step size 166 if (max_step_size_!=0) THROWING(CVodeSetMaxStep, m->mem, max_step_size_); 167 168 // Maximum order of method 169 if (max_order_) THROWING(CVodeSetMaxOrd, m->mem, max_order_); 170 171 // Coeff. in the nonlinear convergence test 172 if (nonlin_conv_coeff_!=0) THROWING(CVodeSetNonlinConvCoef, m->mem, nonlin_conv_coeff_); 173 174 // attach a linear solver 175 if (newton_scheme_==SD_DIRECT) { 176 // Direct scheme 177 CVodeMem cv_mem = static_cast<CVodeMem>(m->mem); 178 cv_mem->cv_lmem = m; 179 cv_mem->cv_lsetup = lsetup; 180 cv_mem->cv_lsolve = lsolve; 181 cv_mem->cv_setupNonNull = TRUE; 182 } else { 183 // Iterative scheme 184 casadi_int pretype = use_precon_ ? PREC_LEFT : PREC_NONE; 185 switch (newton_scheme_) { 186 case SD_DIRECT: casadi_assert_dev(0); 187 case SD_GMRES: THROWING(CVSpgmr, m->mem, pretype, max_krylov_); break; 188 case SD_BCGSTAB: THROWING(CVSpbcg, m->mem, pretype, max_krylov_); break; 189 case SD_TFQMR: THROWING(CVSptfqmr, m->mem, pretype, max_krylov_); break; 190 } 191 THROWING(CVSpilsSetJacTimesVecFn, m->mem, jtimes); 192 if (use_precon_) THROWING(CVSpilsSetPreconditioner, m->mem, psetup, psolve); 193 } 194 195 // Quadrature equations 196 if (nq_>0) { 197 // Initialize quadratures in CVodes 198 THROWING(CVodeQuadInit, m->mem, rhsQ, m->q); 199 200 // Should the quadrature errors be used for step size control? 201 if (quad_err_con_) { 202 THROWING(CVodeSetQuadErrCon, m->mem, true); 203 204 // Quadrature error tolerances 205 // TODO(Joel): vector absolute tolerances 206 THROWING(CVodeQuadSStolerances, m->mem, reltol_, abstol_); 207 } 208 } 209 210 // Initialize adjoint sensitivities 211 if (nrx_>0) { 212 casadi_int interpType = interp_==SD_HERMITE ? CV_HERMITE : CV_POLYNOMIAL; 213 THROWING(CVodeAdjInit, m->mem, steps_per_checkpoint_, interpType); 214 } 215 216 m->first_callB = true; 217 return 0; 218 } 219 rhs(double t,N_Vector x,N_Vector xdot,void * user_data)220 int CvodesInterface::rhs(double t, N_Vector x, N_Vector xdot, void *user_data) { 221 try { 222 casadi_assert_dev(user_data); 223 auto m = to_mem(user_data); 224 auto& s = m->self; 225 m->arg[0] = NV_DATA_S(x); 226 m->arg[1] = m->p; 227 m->arg[2] = &t; 228 m->res[0] = NV_DATA_S(xdot); 229 s.calc_function(m, "odeF"); 230 return 0; 231 } catch(int flag) { // recoverable error 232 return flag; 233 } catch(exception& e) { // non-recoverable error 234 uerr() << "rhs failed: " << e.what() << endl; 235 return -1; 236 } 237 } 238 reset(IntegratorMemory * mem,double t,const double * x,const double * z,const double * _p) const239 void CvodesInterface::reset(IntegratorMemory* mem, double t, const double* x, 240 const double* z, const double* _p) const { 241 if (verbose_) casadi_message(name_ + "::reset"); 242 auto m = to_mem(mem); 243 244 // Reset the base classes 245 SundialsInterface::reset(mem, t, x, z, _p); 246 247 // Re-initialize 248 THROWING(CVodeReInit, m->mem, t, m->xz); 249 250 // Re-initialize quadratures 251 if (nq_>0) { 252 N_VConst(0.0, m->q); 253 THROWING(CVodeQuadReInit, m->mem, m->q); 254 } 255 256 // Re-initialize backward integration 257 if (nrx_>0) { 258 THROWING(CVodeAdjReInit, m->mem); 259 } 260 261 // Set the stop time of the integration -- don't integrate past this point 262 if (stop_at_end_) setStopTime(m, grid_.back()); 263 } 264 advance(IntegratorMemory * mem,double t,double * x,double * z,double * q) const265 void CvodesInterface::advance(IntegratorMemory* mem, double t, double* x, 266 double* z, double* q) const { 267 auto m = to_mem(mem); 268 269 casadi_assert(t>=grid_.front(), 270 "CvodesInterface::integrate(" + str(t) + "): " 271 "Cannot integrate to a time earlier than t0 (" + str(grid_.front()) + ")"); 272 casadi_assert(t<=grid_.back() || !stop_at_end_, 273 "CvodesInterface::integrate(" + str(t) + "): " 274 "Cannot integrate past a time later than tf (" + str(grid_.back()) + ") " 275 "unless stop_at_end is set to False."); 276 277 // Integrate, unless already at desired time 278 const double ttol = 1e-9; 279 if (fabs(m->t-t)>=ttol) { 280 // Integrate forward ... 281 if (nrx_>0) { 282 // ... with taping 283 THROWING(CVodeF, m->mem, t, m->xz, &m->t, CV_NORMAL, &m->ncheck); 284 } else { 285 // ... without taping 286 THROWING(CVode, m->mem, t, m->xz, &m->t, CV_NORMAL); 287 } 288 289 // Get quadratures 290 if (nq_>0) { 291 double tret; 292 THROWING(CVodeGetQuad, m->mem, &tret, m->q); 293 } 294 } 295 296 // Set function outputs 297 casadi_copy(NV_DATA_S(m->xz), nx_, x); 298 casadi_copy(NV_DATA_S(m->q), nq_, q); 299 300 // Get stats 301 THROWING(CVodeGetIntegratorStats, m->mem, &m->nsteps, &m->nfevals, &m->nlinsetups, 302 &m->netfails, &m->qlast, &m->qcur, &m->hinused, 303 &m->hlast, &m->hcur, &m->tcur); 304 THROWING(CVodeGetNonlinSolvStats, m->mem, &m->nniters, &m->nncfails); 305 } 306 resetB(IntegratorMemory * mem,double t,const double * rx,const double * rz,const double * rp) const307 void CvodesInterface::resetB(IntegratorMemory* mem, double t, const double* rx, 308 const double* rz, const double* rp) const { 309 auto m = to_mem(mem); 310 311 // Reset the base classes 312 SundialsInterface::resetB(mem, t, rx, rz, rp); 313 314 if (m->first_callB) { 315 // Create backward problem 316 THROWING(CVodeCreateB, m->mem, lmm_, iter_, &m->whichB); 317 THROWING(CVodeInitB, m->mem, m->whichB, rhsB, grid_.back(), m->rxz); 318 THROWING(CVodeSStolerancesB, m->mem, m->whichB, reltol_, abstol_); 319 THROWING(CVodeSetUserDataB, m->mem, m->whichB, m); 320 if (newton_scheme_==SD_DIRECT) { 321 // Direct scheme 322 CVodeMem cv_mem = static_cast<CVodeMem>(m->mem); 323 CVadjMem ca_mem = cv_mem->cv_adj_mem; 324 CVodeBMem cvB_mem = ca_mem->cvB_mem; 325 cvB_mem->cv_lmem = m; 326 cvB_mem->cv_mem->cv_lmem = m; 327 cvB_mem->cv_mem->cv_lsetup = lsetupB; 328 cvB_mem->cv_mem->cv_lsolve = lsolveB; 329 cvB_mem->cv_mem->cv_setupNonNull = TRUE; 330 } else { 331 // Iterative scheme 332 casadi_int pretype = use_precon_ ? PREC_LEFT : PREC_NONE; 333 switch (newton_scheme_) { 334 case SD_DIRECT: casadi_assert_dev(0); 335 case SD_GMRES: THROWING(CVSpgmrB, m->mem, m->whichB, pretype, max_krylov_); break; 336 case SD_BCGSTAB: THROWING(CVSpbcgB, m->mem, m->whichB, pretype, max_krylov_); break; 337 case SD_TFQMR: THROWING(CVSptfqmrB, m->mem, m->whichB, pretype, max_krylov_); break; 338 } 339 THROWING(CVSpilsSetJacTimesVecFnB, m->mem, m->whichB, jtimesB); 340 if (use_precon_) THROWING(CVSpilsSetPreconditionerB, m->mem, m->whichB, psetupB, psolveB); 341 } 342 343 // Quadratures for the backward problem 344 THROWING(CVodeQuadInitB, m->mem, m->whichB, rhsQB, m->rq); 345 if (quad_err_con_) { 346 THROWING(CVodeSetQuadErrConB, m->mem, m->whichB, true); 347 THROWING(CVodeQuadSStolerancesB, m->mem, m->whichB, reltol_, abstol_); 348 } 349 350 // Mark initialized 351 m->first_callB = false; 352 } else { 353 THROWING(CVodeReInitB, m->mem, m->whichB, grid_.back(), m->rxz); 354 THROWING(CVodeQuadReInitB, m->mem, m->whichB, m->rq); 355 } 356 } 357 retreat(IntegratorMemory * mem,double t,double * rx,double * rz,double * rq) const358 void CvodesInterface::retreat(IntegratorMemory* mem, double t, 359 double* rx, double* rz, double* rq) const { 360 auto m = to_mem(mem); 361 // Integrate, unless already at desired time 362 if (t<m->t) { 363 THROWING(CVodeB, m->mem, t, CV_NORMAL); 364 THROWING(CVodeGetB, m->mem, m->whichB, &m->t, m->rxz); 365 if (nrq_>0) { 366 THROWING(CVodeGetQuadB, m->mem, m->whichB, &m->t, m->rq); 367 } 368 } 369 370 // Save outputs 371 casadi_copy(NV_DATA_S(m->rxz), nrx_, rx); 372 casadi_copy(NV_DATA_S(m->rq), nrq_, rq); 373 374 // Get stats 375 CVodeMem cv_mem = static_cast<CVodeMem>(m->mem); 376 CVadjMem ca_mem = cv_mem->cv_adj_mem; 377 CVodeBMem cvB_mem = ca_mem->cvB_mem; 378 THROWING(CVodeGetIntegratorStats, cvB_mem->cv_mem, &m->nstepsB, 379 &m->nfevalsB, &m->nlinsetupsB, &m->netfailsB, &m->qlastB, 380 &m->qcurB, &m->hinusedB, &m->hlastB, &m->hcurB, &m->tcurB); 381 THROWING(CVodeGetNonlinSolvStats, cvB_mem->cv_mem, &m->nnitersB, &m->nncfailsB); 382 } 383 cvodes_error(const char * module,int flag)384 void CvodesInterface::cvodes_error(const char* module, int flag) { 385 // Successfull return or warning 386 if (flag>=CV_SUCCESS) return; 387 // Construct error message 388 char* flagname = CVodeGetReturnFlagName(flag); 389 stringstream ss; 390 ss << module << " returned \"" << flagname << "\". Consult CVODES documentation."; 391 free(flagname); // NOLINT 392 casadi_error(ss.str()); 393 } 394 ehfun(int error_code,const char * module,const char * function,char * msg,void * user_data)395 void CvodesInterface::ehfun(int error_code, const char *module, const char *function, 396 char *msg, void *user_data) { 397 try { 398 casadi_assert_dev(user_data); 399 auto m = to_mem(user_data); 400 auto& s = m->self; 401 if (!s.disable_internal_warnings_) { 402 uerr() << msg << endl; 403 } 404 } catch(exception& e) { 405 uerr() << "ehfun failed: " << e.what() << endl; 406 } 407 } 408 rhsQ(double t,N_Vector x,N_Vector qdot,void * user_data)409 int CvodesInterface::rhsQ(double t, N_Vector x, N_Vector qdot, void *user_data) { 410 try { 411 auto m = to_mem(user_data); 412 auto& s = m->self; 413 m->arg[0] = NV_DATA_S(x); 414 m->arg[1] = m->p; 415 m->arg[2] = &t; 416 m->res[0] = NV_DATA_S(qdot); 417 s.calc_function(m, "quadF"); 418 return 0; 419 } catch(int flag) { // recoverable error 420 return flag; 421 } catch(exception& e) { // non-recoverable error 422 uerr() << "rhsQ failed: " << e.what() << endl; 423 return -1; 424 } 425 } 426 rhsB(double t,N_Vector x,N_Vector rx,N_Vector rxdot,void * user_data)427 int CvodesInterface::rhsB(double t, N_Vector x, N_Vector rx, N_Vector rxdot, 428 void *user_data) { 429 try { 430 casadi_assert_dev(user_data); 431 auto m = to_mem(user_data); 432 auto& s = m->self; 433 m->arg[0] = NV_DATA_S(rx); 434 m->arg[1] = m->rp; 435 m->arg[2] = NV_DATA_S(x); 436 m->arg[3] = m->p; 437 m->arg[4] = &t; 438 m->res[0] = NV_DATA_S(rxdot); 439 s.calc_function(m, "odeB"); 440 441 // Negate (note definition of g) 442 casadi_scal(s.nrx_, -1., NV_DATA_S(rxdot)); 443 444 return 0; 445 } catch(int flag) { // recoverable error 446 return flag; 447 } catch(exception& e) { // non-recoverable error 448 uerr() << "rhsB failed: " << e.what() << endl; 449 return -1; 450 } 451 } 452 rhsQB(double t,N_Vector x,N_Vector rx,N_Vector rqdot,void * user_data)453 int CvodesInterface::rhsQB(double t, N_Vector x, N_Vector rx, 454 N_Vector rqdot, void *user_data) { 455 try { 456 casadi_assert_dev(user_data); 457 auto m = to_mem(user_data); 458 auto& s = m->self; 459 m->arg[0] = NV_DATA_S(rx); 460 m->arg[1] = m->rp; 461 m->arg[2] = NV_DATA_S(x); 462 m->arg[3] = m->p; 463 m->arg[4] = &t; 464 m->res[0] = NV_DATA_S(rqdot); 465 s.calc_function(m, "quadB"); 466 467 // Negate (note definition of g) 468 casadi_scal(s.nrq_, -1., NV_DATA_S(rqdot)); 469 470 return 0; 471 } catch(int flag) { // recoverable error 472 return flag; 473 } catch(exception& e) { // non-recoverable error 474 uerr() << "rhsQB failed: " << e.what() << endl; 475 return -1; 476 } 477 } 478 jtimes(N_Vector v,N_Vector Jv,double t,N_Vector x,N_Vector xdot,void * user_data,N_Vector tmp)479 int CvodesInterface::jtimes(N_Vector v, N_Vector Jv, double t, N_Vector x, 480 N_Vector xdot, void *user_data, N_Vector tmp) { 481 try { 482 auto m = to_mem(user_data); 483 auto& s = m->self; 484 m->arg[0] = &t; 485 m->arg[1] = NV_DATA_S(x); 486 m->arg[2] = m->p; 487 m->arg[3] = NV_DATA_S(v); 488 m->res[0] = NV_DATA_S(Jv); 489 s.calc_function(m, "jtimesF"); 490 return 0; 491 } catch(casadi_int flag) { // recoverable error 492 return flag; 493 } catch(exception& e) { // non-recoverable error 494 uerr() << "jtimes failed: " << e.what() << endl; 495 return -1; 496 } 497 } 498 jtimesB(N_Vector v,N_Vector Jv,double t,N_Vector x,N_Vector rx,N_Vector rxdot,void * user_data,N_Vector tmpB)499 int CvodesInterface::jtimesB(N_Vector v, N_Vector Jv, double t, N_Vector x, 500 N_Vector rx, N_Vector rxdot, void *user_data , 501 N_Vector tmpB) { 502 try { 503 auto m = to_mem(user_data); 504 auto& s = m->self; 505 m->arg[0] = &t; 506 m->arg[1] = NV_DATA_S(x); 507 m->arg[2] = m->p; 508 m->arg[3] = NV_DATA_S(rx); 509 m->arg[4] = m->rp; 510 m->arg[5] = NV_DATA_S(v); 511 m->res[0] = NV_DATA_S(Jv); 512 s.calc_function(m, "jtimesB"); 513 return 0; 514 } catch(int flag) { // recoverable error 515 return flag; 516 } catch(exception& e) { // non-recoverable error 517 uerr() << "jtimes failed: " << e.what() << endl; 518 return -1; 519 } 520 } 521 setStopTime(IntegratorMemory * mem,double tf) const522 void CvodesInterface::setStopTime(IntegratorMemory* mem, double tf) const { 523 // Set the stop time of the integration -- don't integrate past this point 524 auto m = to_mem(mem); 525 THROWING(CVodeSetStopTime, m->mem, tf); 526 } 527 psolve(double t,N_Vector x,N_Vector xdot,N_Vector r,N_Vector z,double gamma,double delta,int lr,void * user_data,N_Vector tmp)528 int CvodesInterface::psolve(double t, N_Vector x, N_Vector xdot, N_Vector r, 529 N_Vector z, double gamma, double delta, int lr, 530 void *user_data, N_Vector tmp) { 531 try { 532 auto m = to_mem(user_data); 533 auto& s = m->self; 534 535 // Get right-hand sides in m->v1 536 double* v = NV_DATA_S(r); 537 casadi_copy(v, s.nx_, m->v1); 538 539 // Solve for undifferentiated right-hand-side, save to output 540 if (s.linsolF_.solve(m->jac, m->v1, 1, false, m->mem_linsolF)) 541 casadi_error("Linear system solve failed"); 542 v = NV_DATA_S(z); // possibly different from r 543 casadi_copy(m->v1, s.nx1_, v); 544 545 // Sensitivity equations 546 if (s.ns_>0) { 547 // Second order correction 548 if (s.second_order_correction_) { 549 // The outputs will double as seeds for jtimesF 550 casadi_clear(v + s.nx1_, s.nx_ - s.nx1_); 551 m->arg[0] = &t; // t 552 m->arg[1] = NV_DATA_S(x); // x 553 m->arg[2] = m->p; // p 554 m->arg[3] = v; // fwd:x 555 m->res[0] = m->v2; // fwd:ode 556 s.calc_function(m, "jtimesF"); 557 558 // Subtract m->v2 from m->v1, scaled with -gamma 559 casadi_axpy(s.nx_ - s.nx1_, m->gamma, m->v2 + s.nx1_, m->v1 + s.nx1_); 560 } 561 562 // Solve for sensitivity right-hand-sides 563 if (s.linsolF_.solve(m->jac, m->v1 + s.nx1_, s.ns_, false, m->mem_linsolF)) 564 casadi_error("Linear solve failed"); 565 566 // Save to output, reordered 567 casadi_copy(m->v1 + s.nx1_, s.nx_-s.nx1_, v+s.nx1_); 568 } 569 570 return 0; 571 } catch(int flag) { // recoverable error 572 return flag; 573 } catch(exception& e) { // non-recoverable error 574 uerr() << "psolve failed: " << e.what() << endl; 575 return -1; 576 } 577 } 578 psolveB(double t,N_Vector x,N_Vector xB,N_Vector xdotB,N_Vector rvecB,N_Vector zvecB,double gammaB,double deltaB,int lr,void * user_data,N_Vector tmpB)579 int CvodesInterface::psolveB(double t, N_Vector x, N_Vector xB, N_Vector xdotB, 580 N_Vector rvecB, N_Vector zvecB, double gammaB, 581 double deltaB, int lr, void *user_data, N_Vector tmpB) { 582 try { 583 auto m = to_mem(user_data); 584 auto& s = m->self; 585 586 // Get right-hand sides in m->v1 587 double* v = NV_DATA_S(rvecB); 588 casadi_copy(v, s.nrx_, m->v1); 589 590 // Solve for undifferentiated right-hand-side, save to output 591 if (s.linsolB_.solve(m->jacB, m->v1, 1, false, m->mem_linsolB)) 592 casadi_error("Linear solve failed"); 593 v = NV_DATA_S(zvecB); // possibly different from rvecB 594 casadi_copy(m->v1, s.nrx1_, v); 595 596 // Sensitivity equations 597 if (s.ns_>0) { 598 // Second order correction 599 if (s.second_order_correction_) { 600 // The outputs will double as seeds for jtimesF 601 casadi_clear(v + s.nrx1_, s.nrx_ - s.nrx1_); 602 m->arg[0] = &t; // t 603 m->arg[1] = NV_DATA_S(x); // x 604 m->arg[2] = m->p; // p 605 m->arg[3] = NV_DATA_S(xB); // rx 606 m->arg[4] = m->rp; // rp 607 m->arg[5] = v; // fwd:rx 608 m->res[0] = m->v2; // fwd:rode 609 s.calc_function(m, "jtimesB"); 610 611 // Subtract m->v2 from m->v1, scaled with gammaB 612 casadi_axpy(s.nrx_-s.nrx1_, -m->gammaB, m->v2 + s.nrx1_, m->v1 + s.nrx1_); 613 } 614 615 // Solve for sensitivity right-hand-sides 616 if (s.linsolB_.solve(m->jacB, m->v1 + s.nx1_, s.ns_, false, m->mem_linsolB)) { 617 casadi_error("Linear solve failed"); 618 } 619 620 // Save to output, reordered 621 casadi_copy(m->v1 + s.nx1_, s.nx_-s.nx1_, v+s.nx1_); 622 } 623 624 return 0; 625 } catch(int flag) { // recoverable error 626 return flag; 627 } catch(exception& e) { // non-recoverable error 628 uerr() << "psolveB failed: " << e.what() << endl; 629 return -1; 630 } 631 } 632 psetup(double t,N_Vector x,N_Vector xdot,booleantype jok,booleantype * jcurPtr,double gamma,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)633 int CvodesInterface::psetup(double t, N_Vector x, N_Vector xdot, booleantype jok, 634 booleantype *jcurPtr, double gamma, void *user_data, 635 N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) { 636 try { 637 auto m = to_mem(user_data); 638 auto& s = m->self; 639 // Store gamma for later 640 m->gamma = gamma; 641 642 // Calculate Jacobian 643 double d1 = -gamma, d2 = 1.; 644 m->arg[0] = &t; 645 m->arg[1] = NV_DATA_S(x); 646 m->arg[2] = m->p; 647 m->arg[3] = &d1; 648 m->arg[4] = &d2; 649 m->res[0] = m->jac; 650 if (s.calc_function(m, "jacF")) casadi_error("'jacF' calculation failed"); 651 652 // Prepare the solution of the linear system (e.g. factorize) 653 if (s.linsolF_.nfact(m->jac, m->mem_linsolF)) casadi_error("'jacF' factorization failed"); 654 655 return 0; 656 } catch(int flag) { // recoverable error 657 return flag; 658 } catch(exception& e) { // non-recoverable error 659 uerr() << "psetup failed: " << e.what() << endl; 660 return -1; 661 } 662 } 663 psetupB(double t,N_Vector x,N_Vector rx,N_Vector rxdot,booleantype jokB,booleantype * jcurPtrB,double gammaB,void * user_data,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)664 int CvodesInterface::psetupB(double t, N_Vector x, N_Vector rx, N_Vector rxdot, 665 booleantype jokB, booleantype *jcurPtrB, double gammaB, 666 void *user_data, N_Vector tmp1B, N_Vector tmp2B, 667 N_Vector tmp3B) { 668 try { 669 auto m = to_mem(user_data); 670 auto& s = m->self; 671 // Store gamma for later 672 m->gammaB = gammaB; 673 // Calculate Jacobian 674 double one=1; 675 m->arg[0] = &t; 676 m->arg[1] = NV_DATA_S(rx); 677 m->arg[2] = m->rp; 678 m->arg[3] = NV_DATA_S(x); 679 m->arg[4] = m->p; 680 m->arg[5] = &gammaB; 681 m->arg[6] = &one; 682 m->res[0] = m->jacB; 683 if (s.calc_function(m, "jacB")) casadi_error("'jacB' calculation failed"); 684 685 // Prepare the solution of the linear system (e.g. factorize) 686 if (s.linsolB_.nfact(m->jacB, m->mem_linsolB)) casadi_error("'jacB' factorization failed"); 687 688 return 0; 689 } catch(int flag) { // recoverable error 690 return flag; 691 } catch(exception& e) { // non-recoverable error 692 uerr() << "psetupB failed: " << e.what() << endl; 693 return -1; 694 } 695 } 696 lsetup(CVodeMem cv_mem,int convfail,N_Vector x,N_Vector xdot,booleantype * jcurPtr,N_Vector vtemp1,N_Vector vtemp2,N_Vector vtemp3)697 int CvodesInterface::lsetup(CVodeMem cv_mem, int convfail, N_Vector x, N_Vector xdot, 698 booleantype *jcurPtr, 699 N_Vector vtemp1, N_Vector vtemp2, N_Vector vtemp3) { 700 try { 701 auto m = to_mem(cv_mem->cv_lmem); 702 //auto& s = m->self; 703 704 // Current time 705 double t = cv_mem->cv_tn; 706 707 // Scaling factor before J 708 double gamma = cv_mem->cv_gamma; 709 710 // Call the preconditioner setup function (which sets up the linear solver) 711 if (psetup(t, x, xdot, FALSE, jcurPtr, 712 gamma, static_cast<void*>(m), vtemp1, vtemp2, vtemp3)) return 1; 713 714 return 0; 715 } catch(int flag) { // recoverable error 716 return flag; 717 } catch(exception& e) { // non-recoverable error 718 uerr() << "lsetup failed: " << e.what() << endl; 719 return -1; 720 } 721 } 722 lsetupB(CVodeMem cv_mem,int convfail,N_Vector x,N_Vector xdot,booleantype * jcurPtr,N_Vector vtemp1,N_Vector vtemp2,N_Vector vtemp3)723 int CvodesInterface::lsetupB(CVodeMem cv_mem, int convfail, N_Vector x, N_Vector xdot, 724 booleantype *jcurPtr, 725 N_Vector vtemp1, N_Vector vtemp2, N_Vector vtemp3) { 726 try { 727 auto m = to_mem(cv_mem->cv_lmem); 728 CVadjMem ca_mem; 729 //CVodeBMem cvB_mem; 730 731 int flag; 732 733 // Current time 734 double t = cv_mem->cv_tn; // TODO(Joel): is this correct? 735 double gamma = cv_mem->cv_gamma; 736 737 cv_mem = static_cast<CVodeMem>(cv_mem->cv_user_data); 738 739 ca_mem = cv_mem->cv_adj_mem; 740 //cvB_mem = ca_mem->ca_bckpbCrt; 741 742 // Get FORWARD solution from interpolation. 743 flag = ca_mem->ca_IMget(cv_mem, t, ca_mem->ca_ytmp, nullptr); 744 if (flag != CV_SUCCESS) casadi_error("Could not interpolate forward states"); 745 746 // Call the preconditioner setup function (which sets up the linear solver) 747 if (psetupB(t, ca_mem->ca_ytmp, x, xdot, FALSE, jcurPtr, 748 gamma, static_cast<void*>(m), vtemp1, vtemp2, vtemp3)) return 1; 749 750 return 0; 751 } catch(int flag) { // recoverable error 752 return flag; 753 } catch(exception& e) { // non-recoverable error 754 uerr() << "lsetupB failed: " << e.what() << endl; 755 return -1; 756 } 757 } 758 lsolve(CVodeMem cv_mem,N_Vector b,N_Vector weight,N_Vector x,N_Vector xdot)759 int CvodesInterface::lsolve(CVodeMem cv_mem, N_Vector b, N_Vector weight, 760 N_Vector x, N_Vector xdot) { 761 try { 762 auto m = to_mem(cv_mem->cv_lmem); 763 //auto& s = m->self; 764 765 // Current time 766 double t = cv_mem->cv_tn; 767 768 // Scaling factor before J 769 double gamma = cv_mem->cv_gamma; 770 771 // Accuracy 772 double delta = 0.0; 773 774 // Left/right preconditioner 775 casadi_int lr = 1; 776 777 // Call the preconditioner solve function (which solves the linear system) 778 if (psolve(t, x, xdot, b, b, gamma, delta, 779 lr, static_cast<void*>(m), nullptr)) return 1; 780 781 return 0; 782 } catch(int flag) { // recoverable error 783 return flag; 784 } catch(exception& e) { // non-recoverable error 785 uerr() << "lsolve failed: " << e.what() << endl; 786 return -1; 787 } 788 } 789 lsolveB(CVodeMem cv_mem,N_Vector b,N_Vector weight,N_Vector x,N_Vector xdot)790 int CvodesInterface::lsolveB(CVodeMem cv_mem, N_Vector b, N_Vector weight, 791 N_Vector x, N_Vector xdot) { 792 try { 793 auto m = to_mem(cv_mem->cv_lmem); 794 CVadjMem ca_mem; 795 //CVodeBMem cvB_mem; 796 797 int flag; 798 799 // Current time 800 double t = cv_mem->cv_tn; // TODO(Joel): is this correct? 801 double gamma = cv_mem->cv_gamma; 802 803 cv_mem = static_cast<CVodeMem>(cv_mem->cv_user_data); 804 805 ca_mem = cv_mem->cv_adj_mem; 806 //cvB_mem = ca_mem->ca_bckpbCrt; 807 808 // Get FORWARD solution from interpolation. 809 flag = ca_mem->ca_IMget(cv_mem, t, ca_mem->ca_ytmp, nullptr); 810 if (flag != CV_SUCCESS) casadi_error("Could not interpolate forward states"); 811 812 813 814 // Accuracy 815 double delta = 0.0; 816 817 // Left/right preconditioner 818 int lr = 1; 819 820 // Call the preconditioner solve function (which solves the linear system) 821 if (psolveB(t, ca_mem->ca_ytmp, x, xdot, b, b, gamma, delta, lr, 822 static_cast<void*>(m), nullptr)) return 1; 823 824 return 0; 825 } catch(int flag) { // recoverable error 826 return flag; 827 } catch(exception& e) { // non-recoverable error 828 uerr() << "lsolveB failed: " << e.what() << endl; 829 return -1; 830 } 831 } 832 getJ(bool b) const833 Function CvodesInterface::getJ(bool b) const { 834 return oracle_.is_a("SXFunction") ? getJ<SX>(b) : getJ<MX>(b); 835 } 836 837 template<typename MatType> getJ(bool backward) const838 Function CvodesInterface::getJ(bool backward) const { 839 vector<MatType> a = MatType::get_input(oracle_); 840 vector<MatType> r = const_cast<Function&>(oracle_)(a); // NOLINT 841 MatType c_x = MatType::sym("c_x"); 842 MatType c_xdot = MatType::sym("c_xdot"); 843 844 // Get the Jacobian in the Newton iteration 845 if (backward) { 846 MatType jac = c_x*MatType::jacobian(r[DE_RODE], a[DE_RX]) 847 + c_xdot*MatType::eye(nrx_); 848 return Function("jacB", 849 {a[DE_T], a[DE_RX], a[DE_RP], 850 a[DE_X], a[DE_P], c_x, c_xdot}, {jac}); 851 } else { 852 MatType jac = c_x*MatType::jacobian(r[DE_ODE], a[DE_X]) 853 + c_xdot*MatType::eye(nx_); 854 return Function("jacF", {a[DE_T], a[DE_X], a[DE_P], c_x, c_xdot}, {jac}); 855 } 856 } 857 CvodesMemory(const CvodesInterface & s)858 CvodesMemory::CvodesMemory(const CvodesInterface& s) : self(s) { 859 this->mem = nullptr; 860 861 // Reset checkpoints counter 862 this->ncheck = 0; 863 } 864 ~CvodesMemory()865 CvodesMemory::~CvodesMemory() { 866 if (this->mem) CVodeFree(&this->mem); 867 } 868 CvodesInterface(DeserializingStream & s)869 CvodesInterface::CvodesInterface(DeserializingStream& s) : SundialsInterface(s) { 870 int version = s.version("CvodesInterface", 1, 2); 871 s.unpack("CvodesInterface::lmm", lmm_); 872 s.unpack("CvodesInterface::iter", iter_); 873 874 if (version>=2) { 875 s.unpack("CvodesInterface::min_step_size", min_step_size_); 876 } else { 877 min_step_size_ = 0; 878 } 879 } 880 serialize_body(SerializingStream & s) const881 void CvodesInterface::serialize_body(SerializingStream &s) const { 882 SundialsInterface::serialize_body(s); 883 s.version("CvodesInterface", 2); 884 885 s.pack("CvodesInterface::lmm", lmm_); 886 s.pack("CvodesInterface::iter", iter_); 887 s.pack("CvodesInterface::min_step_size", min_step_size_); 888 889 } 890 891 } // namespace casadi 892