1 /* -----------------------------------------------------------------
2 * Programmer(s): Radu Serban @ LLNL
3 * -----------------------------------------------------------------
4 * SUNDIALS Copyright Start
5 * Copyright (c) 2002-2021, Lawrence Livermore National Security
6 * and Southern Methodist University.
7 * All rights reserved.
8 *
9 * See the top-level LICENSE and NOTICE files for details.
10 *
11 * SPDX-License-Identifier: BSD-3-Clause
12 * SUNDIALS Copyright End
13 * -----------------------------------------------------------------
14 * Adjoint sensitivity example problem.
15 * The following is a simple example problem, with the coding
16 * needed for its solution by CVODES. The problem is from chemical
17 * kinetics, and consists of the following three rate equations.
18 * dy1/dt = -p1*y1 + p2*y2*y3
19 * dy2/dt = p1*y1 - p2*y2*y3 - p3*(y2)^2
20 * dy3/dt = p3*(y2)^2
 * on the interval from t = 0.0 to t = 4.e7, with initial
22 * conditions: y1 = 1.0, y2 = y3 = 0. The reaction rates are:
23 * p1=0.04, p2=1e4, and p3=3e7. The problem is stiff.
24 * This program solves the problem with the BDF method, Newton
25 * iteration with the DENSE linear solver, and a user-supplied
26 * Jacobian routine.
27 * It uses a scalar relative tolerance and a vector absolute
28 * tolerance.
 * The forward problem is integrated in a single call up to
 * t = 4.e7; adjoint results are printed at t = 40 and at t = 0.
30 * Run statistics (optional outputs) are printed at the end.
31 *
32 * Optionally, CVODES can compute sensitivities with respect to
33 * the problem parameters p1, p2, and p3 of the following quantity:
34 * G = int_t0^t1 g(t,p,y) dt
35 * where
36 * g(t,p,y) = y3
37 *
38 * The gradient dG/dp is obtained as:
39 * dG/dp = int_t0^t1 (g_p - lambda^T f_p ) dt - lambda^T(t0)*y0_p
40 * = - xi^T(t0) - lambda^T(t0)*y0_p
41 * where lambda and xi are solutions of:
42 * d(lambda)/dt = - (f_y)^T * lambda - (g_y)^T
43 * lambda(t1) = 0
44 * and
45 * d(xi)/dt = - (f_p)^T * lambda + (g_p)^T
46 * xi(t1) = 0
47 *
48 * During the backward integration, CVODES also evaluates G as
49 * G = - phi(t0)
50 * where
51 * d(phi)/dt = g(t,y,p)
52 * phi(t1) = 0
53 * -----------------------------------------------------------------*/
54
55 #include <stdio.h>
56 #include <stdlib.h>
57
58 #include <cvodes/cvodes.h> /* prototypes for CVODE fcts., consts. */
59 #include <nvector/nvector_serial.h> /* access to serial N_Vector */
60 #include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix */
61 #include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver */
62 #include <sundials/sundials_types.h> /* defs. of realtype, sunindextype */
63 #include <sundials/sundials_math.h> /* defs. of SUNRabs, SUNRexp, etc. */
64
/* Accessor macros (1-based indices, as in the mathematical description).
 * The index arguments are parenthesized so that any expression, for
 * example a conditional, can be passed without precedence surprises. */

#define Ith(v,i)    NV_Ith_S(v,(i)-1)             /* i-th vector component, i=1..NEQ */
#define IJth(A,i,j) SM_ELEMENT_D(A,(i)-1,(j)-1)   /* (i,j)-th matrix el., i,j=1..NEQ */
69
/* Problem constants */

/* Problem dimensions */
#define NEQ    3                /* number of equations          */
#define NP     3                /* number of problem parameters */

/* Checkpointing */
#define STEPS  150              /* number of steps between check points */

/* Forward integration interval */
#define T0     RCONST(0.0)      /* initial time */
#define TOUT   RCONST(4e7)      /* final time   */

/* Backward integration: two starting times and one intermediate output */
#define TB1    RCONST(4e7)      /* starting point of the first adjoint run  */
#define TB2    RCONST(50.0)     /* starting point of the second adjoint run */
#define TBout1 RCONST(40.0)     /* intermediate t for adjoint problem       */

/* Forward tolerances: scalar relative, per-component absolute */
#define RTOL   RCONST(1e-6)
#define ATOL1  RCONST(1e-8)
#define ATOL2  RCONST(1e-14)
#define ATOL3  RCONST(1e-6)

/* Adjoint and quadrature absolute tolerances
 * (ATOLl ends with a lowercase letter L, not the digit one) */
#define ATOLl  RCONST(1e-8)     /* absolute tolerance for adjoint vars. */
#define ATOLq  RCONST(1e-6)     /* absolute tolerance for quadratures   */

#define ZERO   RCONST(0.0)
95
96
97 /* Type : UserData */
98
99 typedef struct {
100 realtype p[3];
101 } *UserData;
102
103 /* Prototypes of user-supplied functions */
104
105 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data);
106 static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
107 void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
108 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data);
109 static int ewt(N_Vector y, N_Vector w, void *user_data);
110
111 static int fB(realtype t, N_Vector y,
112 N_Vector yB, N_Vector yBdot, void *user_dataB);
113 static int JacB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix JB,
114 void *user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);
115 static int fQB(realtype t, N_Vector y, N_Vector yB,
116 N_Vector qBdot, void *user_dataB);
117
118
119 /* Prototypes of private functions */
120
121 static void PrintHead(realtype tB0);
122 static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB);
123 static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB);
124 static int check_retval(void *returnvalue, const char *funcname, int opt);
125
126 /*
127 *--------------------------------------------------------------------
128 * MAIN PROGRAM
129 *--------------------------------------------------------------------
130 */
131
main(int argc,char * argv[])132 int main(int argc, char *argv[])
133 {
134 UserData data;
135
136 SUNMatrix A, AB;
137 SUNLinearSolver LS, LSB;
138 void *cvode_mem;
139
140 realtype reltolQ, abstolQ;
141 N_Vector y, q;
142
143 int steps;
144
145 int indexB;
146
147 realtype reltolB, abstolB, abstolQB;
148 N_Vector yB, qB;
149
150 realtype time;
151 int retval, ncheck;
152
153 long int nst, nstB;
154
155 CVadjCheckPointRec *ckpnt;
156
157 data = NULL;
158 A = AB = NULL;
159 LS = LSB = NULL;
160 cvode_mem = NULL;
161 ckpnt = NULL;
162 y = yB = qB = NULL;
163
164 /* Print problem description */
165 printf("\nAdjoint Sensitivity Example for Chemical Kinetics\n");
166 printf("-------------------------------------------------\n\n");
167 printf("ODE: dy1/dt = -p1*y1 + p2*y2*y3\n");
168 printf(" dy2/dt = p1*y1 - p2*y2*y3 - p3*(y2)^2\n");
169 printf(" dy3/dt = p3*(y2)^2\n\n");
170 printf("Find dG/dp for\n");
171 printf(" G = int_t0^tB0 g(t,p,y) dt\n");
172 printf(" g(t,p,y) = y3\n\n\n");
173
174 /* User data structure */
175 data = (UserData) malloc(sizeof *data);
176 if (check_retval((void *)data, "malloc", 2)) return(1);
177 data->p[0] = RCONST(0.04);
178 data->p[1] = RCONST(1.0e4);
179 data->p[2] = RCONST(3.0e7);
180
181 /* Initialize y */
182 y = N_VNew_Serial(NEQ);
183 if (check_retval((void *)y, "N_VNew_Serial", 0)) return(1);
184 Ith(y,1) = RCONST(1.0);
185 Ith(y,2) = ZERO;
186 Ith(y,3) = ZERO;
187
188 /* Initialize q */
189 q = N_VNew_Serial(1);
190 if (check_retval((void *)q, "N_VNew_Serial", 0)) return(1);
191 Ith(q,1) = ZERO;
192
193 /* Set the scalar realtive and absolute tolerances reltolQ and abstolQ */
194 reltolQ = RTOL;
195 abstolQ = ATOLq;
196
197 /* Create and allocate CVODES memory for forward run */
198 printf("Create and allocate CVODES memory for forward runs\n");
199
200 /* Call CVodeCreate to create the solver memory and specify the
201 Backward Differentiation Formula */
202 cvode_mem = CVodeCreate(CV_BDF);
203 if (check_retval((void *)cvode_mem, "CVodeCreate", 0)) return(1);
204
205 /* Call CVodeInit to initialize the integrator memory and specify the
206 user's right hand side function in y'=f(t,y), the initial time T0, and
207 the initial dependent variable vector y. */
208 retval = CVodeInit(cvode_mem, f, T0, y);
209 if (check_retval(&retval, "CVodeInit", 1)) return(1);
210
211 /* Call CVodeWFtolerances to specify a user-supplied function ewt that sets
212 the multiplicative error weights w_i for use in the weighted RMS norm */
213 retval = CVodeWFtolerances(cvode_mem, ewt);
214 if (check_retval(&retval, "CVodeWFtolerances", 1)) return(1);
215
216 /* Attach user data */
217 retval = CVodeSetUserData(cvode_mem, data);
218 if (check_retval(&retval, "CVodeSetUserData", 1)) return(1);
219
220 /* Create dense SUNMatrix for use in linear solves */
221 A = SUNDenseMatrix(NEQ, NEQ);
222 if (check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
223
224 /* Create dense SUNLinearSolver object */
225 LS = SUNLinSol_Dense(y, A);
226 if (check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
227
228 /* Attach the matrix and linear solver */
229 retval = CVodeSetLinearSolver(cvode_mem, LS, A);
230 if (check_retval(&retval, "CVodeSetLinearSolver", 1)) return(1);
231
232 /* Set the user-supplied Jacobian routine Jac */
233 retval = CVodeSetJacFn(cvode_mem, Jac);
234 if (check_retval(&retval, "CVodeSetJacFn", 1)) return(1);
235
236 /* Call CVodeQuadInit to allocate initernal memory and initialize
237 quadrature integration*/
238 retval = CVodeQuadInit(cvode_mem, fQ, q);
239 if (check_retval(&retval, "CVodeQuadInit", 1)) return(1);
240
241 /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
242 are to be used in the step size control mechanism within CVODES. Call
243 CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
244 tolerances for the quadrature variables. */
245 retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
246 if (check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
247
248 /* Call CVodeQuadSStolerances to specify scalar relative and absolute
249 tolerances. */
250 retval = CVodeQuadSStolerances(cvode_mem, reltolQ, abstolQ);
251 if (check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
252
253 /* Call CVodeSetMaxNumSteps to set the maximum number of steps the
254 * solver will take in an attempt to reach the next output time
255 * during forward integration. */
256 retval = CVodeSetMaxNumSteps(cvode_mem, 2500);
257 if (check_retval(&retval, "CVodeSetMaxNumSteps", 1)) return(1);
258
259 /* Allocate global memory */
260
261 /* Call CVodeAdjInit to update CVODES memory block by allocting the internal
262 memory needed for backward integration.*/
263 steps = STEPS; /* no. of integration steps between two consecutive ckeckpoints*/
264 retval = CVodeAdjInit(cvode_mem, steps, CV_HERMITE);
265 /*
266 retval = CVodeAdjInit(cvode_mem, steps, CV_POLYNOMIAL);
267 */
268 if (check_retval(&retval, "CVodeAdjInit", 1)) return(1);
269
270 /* Perform forward run */
271 printf("Forward integration ... ");
272
273 /* Call CVodeF to integrate the forward problem over an interval in time and
274 saves checkpointing data */
275 retval = CVodeF(cvode_mem, TOUT, y, &time, CV_NORMAL, &ncheck);
276 if (check_retval(&retval, "CVodeF", 1)) return(1);
277 retval = CVodeGetNumSteps(cvode_mem, &nst);
278 if (check_retval(&retval, "CVodeGetNumSteps", 1)) return(1);
279
280 printf("done ( nst = %ld )\n",nst);
281 printf("\nncheck = %d\n\n", ncheck);
282
283 retval = CVodeGetQuad(cvode_mem, &time, q);
284 if (check_retval(&retval, "CVodeGetQuad", 1)) return(1);
285
286 printf("--------------------------------------------------------\n");
287 #if defined(SUNDIALS_EXTENDED_PRECISION)
288 printf("G: %12.4Le \n",Ith(q,1));
289 #elif defined(SUNDIALS_DOUBLE_PRECISION)
290 printf("G: %12.4e \n",Ith(q,1));
291 #else
292 printf("G: %12.4e \n",Ith(q,1));
293 #endif
294 printf("--------------------------------------------------------\n\n");
295
296 /* Test check point linked list
297 (uncomment next block to print check point information) */
298
299 /*
300 {
301 int i;
302
303 printf("\nList of Check Points (ncheck = %d)\n\n", ncheck);
304 ckpnt = (CVadjCheckPointRec *) malloc ( (ncheck+1)*sizeof(CVadjCheckPointRec));
305 CVodeGetAdjCheckPointsInfo(cvode_mem, ckpnt);
306 for (i=0;i<=ncheck;i++) {
307 printf("Address: %p\n",ckpnt[i].my_addr);
308 printf("Next: %p\n",ckpnt[i].next_addr);
309 printf("Time interval: %le %le\n",ckpnt[i].t0, ckpnt[i].t1);
310 printf("Step number: %ld\n",ckpnt[i].nstep);
311 printf("Order: %d\n",ckpnt[i].order);
312 printf("Step size: %le\n",ckpnt[i].step);
313 printf("\n");
314 }
315
316 }
317 */
318
319 /* Initialize yB */
320 yB = N_VNew_Serial(NEQ);
321 if (check_retval((void *)yB, "N_VNew_Serial", 0)) return(1);
322 Ith(yB,1) = ZERO;
323 Ith(yB,2) = ZERO;
324 Ith(yB,3) = ZERO;
325
326 /* Initialize qB */
327 qB = N_VNew_Serial(NP);
328 if (check_retval((void *)qB, "N_VNew", 0)) return(1);
329 Ith(qB,1) = ZERO;
330 Ith(qB,2) = ZERO;
331 Ith(qB,3) = ZERO;
332
333 /* Set the scalar relative tolerance reltolB */
334 reltolB = RTOL;
335
336 /* Set the scalar absolute tolerance abstolB */
337 abstolB = ATOLl;
338
339 /* Set the scalar absolute tolerance abstolQB */
340 abstolQB = ATOLq;
341
342 /* Create and allocate CVODES memory for backward run */
343 printf("Create and allocate CVODES memory for backward run\n");
344
345 /* Call CVodeCreateB to specify the solution method for the backward
346 problem. */
347 retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB);
348 if (check_retval(&retval, "CVodeCreateB", 1)) return(1);
349
350 /* Call CVodeInitB to allocate internal memory and initialize the
351 backward problem. */
352 retval = CVodeInitB(cvode_mem, indexB, fB, TB1, yB);
353 if (check_retval(&retval, "CVodeInitB", 1)) return(1);
354
355 /* Set the scalar relative and absolute tolerances. */
356 retval = CVodeSStolerancesB(cvode_mem, indexB, reltolB, abstolB);
357 if (check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
358
359 /* Attach the user data for backward problem. */
360 retval = CVodeSetUserDataB(cvode_mem, indexB, data);
361 if (check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
362
363 /* Create dense SUNMatrix for use in linear solves */
364 AB = SUNDenseMatrix(NEQ, NEQ);
365 if (check_retval((void *)AB, "SUNDenseMatrix", 0)) return(1);
366
367 /* Create dense SUNLinearSolver object */
368 LSB = SUNLinSol_Dense(yB, AB);
369 if (check_retval((void *)LSB, "SUNLinSol_Dense", 0)) return(1);
370
371 /* Attach the matrix and linear solver */
372 retval = CVodeSetLinearSolverB(cvode_mem, indexB, LSB, AB);
373 if (check_retval(&retval, "CVodeSetLinearSolverB", 1)) return(1);
374
375 /* Set the user-supplied Jacobian routine JacB */
376 retval = CVodeSetJacFnB(cvode_mem, indexB, JacB);
377 if (check_retval(&retval, "CVodeSetJacFnB", 1)) return(1);
378
379 /* Call CVodeQuadInitB to allocate internal memory and initialize backward
380 quadrature integration. */
381 retval = CVodeQuadInitB(cvode_mem, indexB, fQB, qB);
382 if (check_retval(&retval, "CVodeQuadInitB", 1)) return(1);
383
384 /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
385 are to be used in the step size control mechanism within CVODES. Call
386 CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
387 tolerances for the quadrature variables. */
388 retval = CVodeSetQuadErrConB(cvode_mem, indexB, SUNTRUE);
389 if (check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
390
391 /* Call CVodeQuadSStolerancesB to specify the scalar relative and absolute tolerances
392 for the backward problem. */
393 retval = CVodeQuadSStolerancesB(cvode_mem, indexB, reltolB, abstolQB);
394 if (check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
395
396 /* Backward Integration */
397
398 PrintHead(TB1);
399
400 /* First get results at t = TBout1 */
401
402 /* Call CVodeB to integrate the backward ODE problem. */
403 retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
404 if (check_retval(&retval, "CVodeB", 1)) return(1);
405
406 /* Call CVodeGetB to get yB of the backward ODE problem. */
407 retval = CVodeGetB(cvode_mem, indexB, &time, yB);
408 if (check_retval(&retval, "CVodeGetB", 1)) return(1);
409
410 /* Call CVodeGetAdjY to get the interpolated value of the forward solution
411 y during a backward integration. */
412 retval = CVodeGetAdjY(cvode_mem, TBout1, y);
413 if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
414
415 PrintOutput1(time, TBout1, y, yB);
416
417 /* Then at t = T0 */
418
419 retval = CVodeB(cvode_mem, T0, CV_NORMAL);
420 if (check_retval(&retval, "CVodeB", 1)) return(1);
421 CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
422 printf("Done ( nst = %ld )\n", nstB);
423
424 retval = CVodeGetB(cvode_mem, indexB, &time, yB);
425 if (check_retval(&retval, "CVodeGetB", 1)) return(1);
426
427 /* Call CVodeGetQuadB to get the quadrature solution vector after a
428 successful return from CVodeB. */
429 retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
430 if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
431
432 retval = CVodeGetAdjY(cvode_mem, T0, y);
433 if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
434
435 PrintOutput(time, y, yB, qB);
436
437 /* Reinitialize backward phase (new tB0) */
438
439 Ith(yB,1) = ZERO;
440 Ith(yB,2) = ZERO;
441 Ith(yB,3) = ZERO;
442
443 Ith(qB,1) = ZERO;
444 Ith(qB,2) = ZERO;
445 Ith(qB,3) = ZERO;
446
447 printf("Re-initialize CVODES memory for backward run\n");
448
449 retval = CVodeReInitB(cvode_mem, indexB, TB2, yB);
450 if (check_retval(&retval, "CVodeReInitB", 1)) return(1);
451
452 retval = CVodeQuadReInitB(cvode_mem, indexB, qB);
453 if (check_retval(&retval, "CVodeQuadReInitB", 1)) return(1);
454
455 PrintHead(TB2);
456
457 /* First get results at t = TBout1 */
458
459 retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
460 if (check_retval(&retval, "CVodeB", 1)) return(1);
461
462 retval = CVodeGetB(cvode_mem, indexB, &time, yB);
463 if (check_retval(&retval, "CVodeGetB", 1)) return(1);
464
465 retval = CVodeGetAdjY(cvode_mem, TBout1, y);
466 if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
467
468 PrintOutput1(time, TBout1, y, yB);
469
470 /* Then at t = T0 */
471
472 retval = CVodeB(cvode_mem, T0, CV_NORMAL);
473 if (check_retval(&retval, "CVodeB", 1)) return(1);
474 CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
475 printf("Done ( nst = %ld )\n", nstB);
476
477 retval = CVodeGetB(cvode_mem, indexB, &time, yB);
478 if (check_retval(&retval, "CVodeGetB", 1)) return(1);
479
480 retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
481 if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
482
483 retval = CVodeGetAdjY(cvode_mem, T0, y);
484 if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
485
486 PrintOutput(time, y, yB, qB);
487
488 /* Free memory */
489 printf("Free memory\n\n");
490
491 CVodeFree(&cvode_mem);
492 N_VDestroy(y);
493 N_VDestroy(q);
494 N_VDestroy(yB);
495 N_VDestroy(qB);
496 SUNLinSolFree(LS);
497 SUNMatDestroy(A);
498 SUNLinSolFree(LSB);
499 SUNMatDestroy(AB);
500
501 if (ckpnt != NULL) free(ckpnt);
502 free(data);
503
504 return(0);
505
506 }
507
508 /*
509 *--------------------------------------------------------------------
510 * FUNCTIONS CALLED BY CVODES
511 *--------------------------------------------------------------------
512 */
513
514 /*
515 * f routine. Compute f(t,y).
516 */
517
f(realtype t,N_Vector y,N_Vector ydot,void * user_data)518 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data)
519 {
520 realtype y1, y2, y3, yd1, yd3;
521 UserData data;
522 realtype p1, p2, p3;
523
524 y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
525 data = (UserData) user_data;
526 p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
527
528 yd1 = Ith(ydot,1) = -p1*y1 + p2*y2*y3;
529 yd3 = Ith(ydot,3) = p3*y2*y2;
530 Ith(ydot,2) = -yd1 - yd3;
531
532 return(0);
533 }
534
535 /*
536 * Jacobian routine. Compute J(t,y).
537 */
538
Jac(realtype t,N_Vector y,N_Vector fy,SUNMatrix J,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)539 static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
540 void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
541 {
542 realtype y2, y3;
543 UserData data;
544 realtype p1, p2, p3;
545
546 y2 = Ith(y,2); y3 = Ith(y,3);
547 data = (UserData) user_data;
548 p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
549
550 IJth(J,1,1) = -p1; IJth(J,1,2) = p2*y3; IJth(J,1,3) = p2*y2;
551 IJth(J,2,1) = p1; IJth(J,2,2) = -p2*y3-2*p3*y2; IJth(J,2,3) = -p2*y2;
552 IJth(J,3,1) = ZERO; IJth(J,3,2) = 2*p3*y2; IJth(J,3,3) = ZERO;
553
554 return(0);
555 }
556
557 /*
558 * fQ routine. Compute fQ(t,y).
559 */
560
fQ(realtype t,N_Vector y,N_Vector qdot,void * user_data)561 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data)
562 {
563 Ith(qdot,1) = Ith(y,3);
564
565 return(0);
566 }
567
568 /*
569 * EwtSet function. Computes the error weights at the current solution.
570 */
571
ewt(N_Vector y,N_Vector w,void * user_data)572 static int ewt(N_Vector y, N_Vector w, void *user_data)
573 {
574 int i;
575 realtype yy, ww, rtol, atol[3];
576
577 rtol = RTOL;
578 atol[0] = ATOL1;
579 atol[1] = ATOL2;
580 atol[2] = ATOL3;
581
582 for (i=1; i<=3; i++) {
583 yy = Ith(y,i);
584 ww = rtol * SUNRabs(yy) + atol[i-1];
585 if (ww <= 0.0) return (-1);
586 Ith(w,i) = 1.0/ww;
587 }
588
589 return(0);
590 }
591
592 /*
593 * fB routine. Compute fB(t,y,yB).
594 */
595
fB(realtype t,N_Vector y,N_Vector yB,N_Vector yBdot,void * user_dataB)596 static int fB(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
597 {
598 UserData data;
599 realtype y2, y3;
600 realtype p1, p2, p3;
601 realtype l1, l2, l3;
602 realtype l21, l32;
603
604 data = (UserData) user_dataB;
605
606 /* The p vector */
607 p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
608
609 /* The y vector */
610 y2 = Ith(y,2); y3 = Ith(y,3);
611
612 /* The lambda vector */
613 l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
614
615 /* Temporary variables */
616 l21 = l2-l1;
617 l32 = l3-l2;
618
619 /* Load yBdot */
620 Ith(yBdot,1) = - p1*l21;
621 Ith(yBdot,2) = p2*y3*l21 - RCONST(2.0)*p3*y2*l32;
622 Ith(yBdot,3) = p2*y2*l21 - RCONST(1.0);
623
624 return(0);
625 }
626
627 /*
628 * JacB routine. Compute JB(t,y,yB).
629 */
630
JacB(realtype t,N_Vector y,N_Vector yB,N_Vector fyB,SUNMatrix JB,void * user_dataB,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)631 static int JacB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix JB,
632 void *user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B)
633 {
634 UserData data;
635 realtype y2, y3;
636 realtype p1, p2, p3;
637
638 data = (UserData) user_dataB;
639
640 /* The p vector */
641 p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
642
643 /* The y vector */
644 y2 = Ith(y,2); y3 = Ith(y,3);
645
646 /* Load JB */
647 IJth(JB,1,1) = p1; IJth(JB,1,2) = -p1; IJth(JB,1,3) = ZERO;
648 IJth(JB,2,1) = -p2*y3; IJth(JB,2,2) = p2*y3+2.0*p3*y2; IJth(JB,2,3) = RCONST(-2.0)*p3*y2;
649 IJth(JB,3,1) = -p2*y2; IJth(JB,3,2) = p2*y2; IJth(JB,3,3) = ZERO;
650
651 return(0);
652 }
653
654 /*
655 * fQB routine. Compute integrand for quadratures
656 */
657
fQB(realtype t,N_Vector y,N_Vector yB,N_Vector qBdot,void * user_dataB)658 static int fQB(realtype t, N_Vector y, N_Vector yB,
659 N_Vector qBdot, void *user_dataB)
660 {
661 realtype y1, y2, y3;
662 realtype l1, l2, l3;
663 realtype l21, l32, y23;
664
665 /* The y vector */
666 y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
667
668 /* The lambda vector */
669 l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
670
671 /* Temporary variables */
672 l21 = l2-l1;
673 l32 = l3-l2;
674 y23 = y2*y3;
675
676 Ith(qBdot,1) = y1*l21;
677 Ith(qBdot,2) = - y23*l21;
678 Ith(qBdot,3) = y2*y2*l32;
679
680 return(0);
681 }
682
683 /*
684 *--------------------------------------------------------------------
685 * PRIVATE FUNCTIONS
686 *--------------------------------------------------------------------
687 */
688
689 /*
690 * Print heading for backward integration
691 */
692
PrintHead(realtype tB0)693 static void PrintHead(realtype tB0)
694 {
695 #if defined(SUNDIALS_EXTENDED_PRECISION)
696 printf("Backward integration from tB0 = %12.4Le\n\n",tB0);
697 #elif defined(SUNDIALS_DOUBLE_PRECISION)
698 printf("Backward integration from tB0 = %12.4e\n\n",tB0);
699 #else
700 printf("Backward integration from tB0 = %12.4e\n\n",tB0);
701 #endif
702 }
703
704 /*
705 * Print intermediate results during backward integration
706 */
707
PrintOutput1(realtype time,realtype t,N_Vector y,N_Vector yB)708 static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB)
709 {
710 printf("--------------------------------------------------------\n");
711 #if defined(SUNDIALS_EXTENDED_PRECISION)
712 printf("returned t: %12.4Le\n",time);
713 printf("tout: %12.4Le\n",t);
714 printf("lambda(t): %12.4Le %12.4Le %12.4Le\n",
715 Ith(yB,1), Ith(yB,2), Ith(yB,3));
716 printf("y(t): %12.4Le %12.4Le %12.4Le\n",
717 Ith(y,1), Ith(y,2), Ith(y,3));
718 #elif defined(SUNDIALS_DOUBLE_PRECISION)
719 printf("returned t: %12.4e\n",time);
720 printf("tout: %12.4e\n",t);
721 printf("lambda(t): %12.4e %12.4e %12.4e\n",
722 Ith(yB,1), Ith(yB,2), Ith(yB,3));
723 printf("y(t): %12.4e %12.4e %12.4e\n",
724 Ith(y,1), Ith(y,2), Ith(y,3));
725 #else
726 printf("returned t: %12.4e\n",time);
727 printf("tout: %12.4e\n",t);
728 printf("lambda(t): %12.4e %12.4e %12.4e\n",
729 Ith(yB,1), Ith(yB,2), Ith(yB,3));
730 printf("y(t) : %12.4e %12.4e %12.4e\n",
731 Ith(y,1), Ith(y,2), Ith(y,3));
732 #endif
733 printf("--------------------------------------------------------\n\n");
734 }
735
736 /*
737 * Print final results of backward integration
738 */
739
PrintOutput(realtype tfinal,N_Vector y,N_Vector yB,N_Vector qB)740 static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB)
741 {
742 printf("--------------------------------------------------------\n");
743 #if defined(SUNDIALS_EXTENDED_PRECISION)
744 printf("returned t: %12.4Le\n",tfinal);
745 printf("lambda(t0): %12.4Le %12.4Le %12.4Le\n",
746 Ith(yB,1), Ith(yB,2), Ith(yB,3));
747 printf("y(t0): %12.4Le %12.4Le %12.4Le\n",
748 Ith(y,1), Ith(y,2), Ith(y,3));
749 printf("dG/dp: %12.4Le %12.4Le %12.4Le\n",
750 -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
751 #elif defined(SUNDIALS_DOUBLE_PRECISION)
752 printf("returned t: %12.4e\n",tfinal);
753 printf("lambda(t0): %12.4e %12.4e %12.4e\n",
754 Ith(yB,1), Ith(yB,2), Ith(yB,3));
755 printf("y(t0): %12.4e %12.4e %12.4e\n",
756 Ith(y,1), Ith(y,2), Ith(y,3));
757 printf("dG/dp: %12.4e %12.4e %12.4e\n",
758 -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
759 #else
760 printf("returned t: %12.4e\n",tfinal);
761 printf("lambda(t0): %12.4e %12.4e %12.4e\n",
762 Ith(yB,1), Ith(yB,2), Ith(yB,3));
763 printf("y(t0) : %12.4e %12.4e %12.4e\n",
764 Ith(y,1), Ith(y,2), Ith(y,3));
765 printf("dG/dp: %12.4e %12.4e %12.4e\n",
766 -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
767 #endif
768 printf("--------------------------------------------------------\n\n");
769 }
770
/*
 * Check function return value.
 *
 *   opt == 0  returnvalue is a pointer produced by a SUNDIALS allocating
 *             function; NULL is an error
 *   opt == 1  returnvalue points to an int status; a negative value is
 *             an error (the pointer itself is dereferenced unchecked)
 *   opt == 2  returnvalue is a pointer produced by a plain allocation;
 *             NULL is an error
 *
 * Any other opt is accepted silently.  On error a message naming
 * funcname is written to stderr and 1 is returned; otherwise 0.
 */

static int check_retval(void *returnvalue, const char *funcname, int opt)
{
  int status;

  switch (opt) {

  case 0:
    /* SUNDIALS function returned NULL pointer - no memory allocated */
    if (returnvalue != NULL) break;
    fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);

  case 1:
    /* integer status: negative means failure */
    status = *((int *) returnvalue);
    if (status >= 0) break;
    fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
            funcname, status);
    return(1);

  case 2:
    /* function returned NULL pointer - no memory allocated */
    if (returnvalue != NULL) break;
    fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);

  default:
    break;
  }

  return(0);
}
807