1 /* -----------------------------------------------------------------
2  * Programmer(s): Radu Serban @ LLNL
3  * -----------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2021, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------
14  * Adjoint sensitivity example problem.
15  * The following is a simple example problem, with the coding
16  * needed for its solution by CVODES. The problem is from chemical
17  * kinetics, and consists of the following three rate equations.
18  *    dy1/dt = -p1*y1 + p2*y2*y3
19  *    dy2/dt =  p1*y1 - p2*y2*y3 - p3*(y2)^2
20  *    dy3/dt =  p3*(y2)^2
 * on the interval from t = 0.0 to t = 4.e7, with initial
22  * conditions: y1 = 1.0, y2 = y3 = 0. The reaction rates are:
23  * p1=0.04, p2=1e4, and p3=3e7. The problem is stiff.
24  * This program solves the problem with the BDF method, Newton
25  * iteration with the DENSE linear solver, and a user-supplied
26  * Jacobian routine.
 * It uses a scalar relative tolerance and a vector absolute
 * tolerance, applied through a user-supplied error-weight function.
 * The forward problem is integrated to t = 4.e7 in a single call
 * that stores checkpoints; the value of G and the number of steps
 * taken by each integration are printed.
31  *
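 * CVODES then uses adjoint sensitivity analysis to compute the
 * sensitivities, with respect to the problem parameters p1, p2,
 * and p3, of the following quantity:
 *   G = int_t0^t1 g(t,p,y) dt
 * where
 *   g(t,p,y) = y3
 * and t1 is the initial time tB0 of the backward problem (4.e7 for
 * the first backward run and 50.0 for the second).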
37  *
 * The gradient dG/dp is obtained as:
 *   dG/dp = int_t0^t1 (g_p + lambda^T f_p ) dt + lambda^T(t0)*y0_p
 *         = - xi^T(t0) + lambda^T(t0)*y0_p
 * where lambda and xi are solutions of:
 *   d(lambda)/dt = - (f_y)^T * lambda - (g_y)^T
 *   lambda(t1) = 0
 * and
 *   d(xi)/dt = (f_p)^T * lambda + (g_p)^T
 *   xi(t1) = 0
47  *
 * During the forward integration, CVODES also evaluates G (for
 * t1 = 4.e7) as the quadrature
 *   G = phi(t1)
 * where
 *   d(phi)/dt = g(t,y,p)
 *   phi(t0) = 0
53  * -----------------------------------------------------------------*/
54 
55 #include <stdio.h>
56 #include <stdlib.h>
57 
58 #include <cvodes/cvodes.h>             /* prototypes for CVODE fcts., consts.  */
59 #include <nvector/nvector_serial.h>    /* access to serial N_Vector            */
60 #include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix            */
61 #include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver      */
62 #include <sundials/sundials_types.h>   /* defs. of realtype, sunindextype      */
63 #include <sundials/sundials_math.h>    /* defs. of SUNRabs, SUNRexp, etc.      */
64 
/* Accessor macros. Indices are 1-based to match the mathematical
   notation used in the header; every argument is parenthesized so that
   any integer expression (e.g. a conditional or a shift) can be passed
   as an index without changing its meaning. */

#define Ith(v,i)    NV_Ith_S((v),(i)-1)           /* i-th vector component, i=1..NEQ */
#define IJth(A,i,j) SM_ELEMENT_D((A),(i)-1,(j)-1) /* (i,j)-th matrix el., i,j=1..NEQ */

/* Problem Constants */

#define NEQ      3             /* number of equations                  */

#define RTOL     RCONST(1e-6)  /* scalar relative tolerance (forward,
                                  quadrature and adjoint problems)     */

#define ATOL1    RCONST(1e-8)  /* vector absolute tolerance components */
#define ATOL2    RCONST(1e-14)
#define ATOL3    RCONST(1e-6)

#define ATOLl    RCONST(1e-8)  /* absolute tolerance for adjoint vars. */
#define ATOLq    RCONST(1e-6)  /* absolute tolerance for forward and
                                  backward quadratures                 */

#define T0       RCONST(0.0)   /* initial time                         */
#define TOUT     RCONST(4e7)   /* final time of the forward run        */

#define TB1      RCONST(4e7)   /* initial time of first adjoint run    */
#define TB2      RCONST(50.0)  /* initial time of second adjoint run   */
#define TBout1   RCONST(40.0)  /* intermediate output time used by
                                  both adjoint runs                    */

#define STEPS    150           /* number of steps between check points */

#define NP       3             /* number of problem parameters         */

#define ZERO     RCONST(0.0)
95 
96 
97 /* Type : UserData */
98 
99 typedef struct {
100   realtype p[3];
101 } *UserData;
102 
/* Prototypes of user-supplied functions */

/* Forward problem: right-hand side f, its dense Jacobian, the integrand
   g = y3 of the forward quadrature G, and the error-weight function */
static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data);
static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
               void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data);
static int ewt(N_Vector y, N_Vector w, void *user_data);

/* Backward (adjoint) problem: right-hand side for lambda, its dense
   Jacobian, and the integrand of the quadrature yielding dG/dp */
static int fB(realtype t, N_Vector y,
              N_Vector yB, N_Vector yBdot, void *user_dataB);
static int JacB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix JB,
                void *user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);
static int fQB(realtype t, N_Vector y, N_Vector yB,
               N_Vector qBdot, void *user_dataB);


/* Prototypes of private functions */

/* Output helpers and the SUNDIALS return-value checker */
static void PrintHead(realtype tB0);
static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB);
static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB);
static int check_retval(void *returnvalue, const char *funcname, int opt);
125 
126 /*
127  *--------------------------------------------------------------------
128  * MAIN PROGRAM
129  *--------------------------------------------------------------------
130  */
131 
main(int argc,char * argv[])132 int main(int argc, char *argv[])
133 {
134   UserData data;
135 
136   SUNMatrix A, AB;
137   SUNLinearSolver LS, LSB;
138   void *cvode_mem;
139 
140   realtype reltolQ, abstolQ;
141   N_Vector y, q;
142 
143   int steps;
144 
145   int indexB;
146 
147   realtype reltolB, abstolB, abstolQB;
148   N_Vector yB, qB;
149 
150   realtype time;
151   int retval, ncheck;
152 
153   long int nst, nstB;
154 
155   CVadjCheckPointRec *ckpnt;
156 
157   data = NULL;
158   A = AB = NULL;
159   LS = LSB = NULL;
160   cvode_mem = NULL;
161   ckpnt = NULL;
162   y = yB = qB = NULL;
163 
164   /* Print problem description */
165   printf("\nAdjoint Sensitivity Example for Chemical Kinetics\n");
166   printf("-------------------------------------------------\n\n");
167   printf("ODE: dy1/dt = -p1*y1 + p2*y2*y3\n");
168   printf("     dy2/dt =  p1*y1 - p2*y2*y3 - p3*(y2)^2\n");
169   printf("     dy3/dt =  p3*(y2)^2\n\n");
170   printf("Find dG/dp for\n");
171   printf("     G = int_t0^tB0 g(t,p,y) dt\n");
172   printf("     g(t,p,y) = y3\n\n\n");
173 
174   /* User data structure */
175   data = (UserData) malloc(sizeof *data);
176   if (check_retval((void *)data, "malloc", 2)) return(1);
177   data->p[0] = RCONST(0.04);
178   data->p[1] = RCONST(1.0e4);
179   data->p[2] = RCONST(3.0e7);
180 
181   /* Initialize y */
182   y = N_VNew_Serial(NEQ);
183   if (check_retval((void *)y, "N_VNew_Serial", 0)) return(1);
184   Ith(y,1) = RCONST(1.0);
185   Ith(y,2) = ZERO;
186   Ith(y,3) = ZERO;
187 
188   /* Initialize q */
189   q = N_VNew_Serial(1);
190   if (check_retval((void *)q, "N_VNew_Serial", 0)) return(1);
191   Ith(q,1) = ZERO;
192 
193   /* Set the scalar realtive and absolute tolerances reltolQ and abstolQ */
194   reltolQ = RTOL;
195   abstolQ = ATOLq;
196 
197   /* Create and allocate CVODES memory for forward run */
198   printf("Create and allocate CVODES memory for forward runs\n");
199 
200   /* Call CVodeCreate to create the solver memory and specify the
201      Backward Differentiation Formula */
202   cvode_mem = CVodeCreate(CV_BDF);
203   if (check_retval((void *)cvode_mem, "CVodeCreate", 0)) return(1);
204 
205   /* Call CVodeInit to initialize the integrator memory and specify the
206      user's right hand side function in y'=f(t,y), the initial time T0, and
207      the initial dependent variable vector y. */
208   retval = CVodeInit(cvode_mem, f, T0, y);
209   if (check_retval(&retval, "CVodeInit", 1)) return(1);
210 
211   /* Call CVodeWFtolerances to specify a user-supplied function ewt that sets
212      the multiplicative error weights w_i for use in the weighted RMS norm */
213   retval = CVodeWFtolerances(cvode_mem, ewt);
214   if (check_retval(&retval, "CVodeWFtolerances", 1)) return(1);
215 
216   /* Attach user data */
217   retval = CVodeSetUserData(cvode_mem, data);
218   if (check_retval(&retval, "CVodeSetUserData", 1)) return(1);
219 
220   /* Create dense SUNMatrix for use in linear solves */
221   A = SUNDenseMatrix(NEQ, NEQ);
222   if (check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
223 
224   /* Create dense SUNLinearSolver object */
225   LS = SUNLinSol_Dense(y, A);
226   if (check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
227 
228   /* Attach the matrix and linear solver */
229   retval = CVodeSetLinearSolver(cvode_mem, LS, A);
230   if (check_retval(&retval, "CVodeSetLinearSolver", 1)) return(1);
231 
232   /* Set the user-supplied Jacobian routine Jac */
233   retval = CVodeSetJacFn(cvode_mem, Jac);
234   if (check_retval(&retval, "CVodeSetJacFn", 1)) return(1);
235 
236   /* Call CVodeQuadInit to allocate initernal memory and initialize
237      quadrature integration*/
238   retval = CVodeQuadInit(cvode_mem, fQ, q);
239   if (check_retval(&retval, "CVodeQuadInit", 1)) return(1);
240 
241   /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
242      are to be used in the step size control mechanism within CVODES. Call
243      CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
244      tolerances for the quadrature variables. */
245   retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
246   if (check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
247 
248   /* Call CVodeQuadSStolerances to specify scalar relative and absolute
249      tolerances. */
250   retval = CVodeQuadSStolerances(cvode_mem, reltolQ, abstolQ);
251   if (check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
252 
253   /* Call CVodeSetMaxNumSteps to set the maximum number of steps the
254    * solver will take in an attempt to reach the next output time
255    * during forward integration. */
256   retval = CVodeSetMaxNumSteps(cvode_mem, 2500);
257   if (check_retval(&retval, "CVodeSetMaxNumSteps", 1)) return(1);
258 
259   /* Allocate global memory */
260 
261   /* Call CVodeAdjInit to update CVODES memory block by allocting the internal
262      memory needed for backward integration.*/
263   steps = STEPS; /* no. of integration steps between two consecutive ckeckpoints*/
264   retval = CVodeAdjInit(cvode_mem, steps, CV_HERMITE);
265   /*
266   retval = CVodeAdjInit(cvode_mem, steps, CV_POLYNOMIAL);
267   */
268   if (check_retval(&retval, "CVodeAdjInit", 1)) return(1);
269 
270   /* Perform forward run */
271   printf("Forward integration ... ");
272 
273   /* Call CVodeF to integrate the forward problem over an interval in time and
274      saves checkpointing data */
275   retval = CVodeF(cvode_mem, TOUT, y, &time, CV_NORMAL, &ncheck);
276   if (check_retval(&retval, "CVodeF", 1)) return(1);
277   retval = CVodeGetNumSteps(cvode_mem, &nst);
278   if (check_retval(&retval, "CVodeGetNumSteps", 1)) return(1);
279 
280   printf("done ( nst = %ld )\n",nst);
281   printf("\nncheck = %d\n\n", ncheck);
282 
283   retval = CVodeGetQuad(cvode_mem, &time, q);
284   if (check_retval(&retval, "CVodeGetQuad", 1)) return(1);
285 
286   printf("--------------------------------------------------------\n");
287 #if defined(SUNDIALS_EXTENDED_PRECISION)
288   printf("G:          %12.4Le \n",Ith(q,1));
289 #elif defined(SUNDIALS_DOUBLE_PRECISION)
290   printf("G:          %12.4e \n",Ith(q,1));
291 #else
292   printf("G:          %12.4e \n",Ith(q,1));
293 #endif
294   printf("--------------------------------------------------------\n\n");
295 
296   /* Test check point linked list
297      (uncomment next block to print check point information) */
298 
299   /*
300   {
301     int i;
302 
303     printf("\nList of Check Points (ncheck = %d)\n\n", ncheck);
304     ckpnt = (CVadjCheckPointRec *) malloc ( (ncheck+1)*sizeof(CVadjCheckPointRec));
305     CVodeGetAdjCheckPointsInfo(cvode_mem, ckpnt);
306     for (i=0;i<=ncheck;i++) {
307       printf("Address:       %p\n",ckpnt[i].my_addr);
308       printf("Next:          %p\n",ckpnt[i].next_addr);
309       printf("Time interval: %le  %le\n",ckpnt[i].t0, ckpnt[i].t1);
310       printf("Step number:   %ld\n",ckpnt[i].nstep);
311       printf("Order:         %d\n",ckpnt[i].order);
312       printf("Step size:     %le\n",ckpnt[i].step);
313       printf("\n");
314     }
315 
316   }
317   */
318 
319   /* Initialize yB */
320   yB = N_VNew_Serial(NEQ);
321   if (check_retval((void *)yB, "N_VNew_Serial", 0)) return(1);
322   Ith(yB,1) = ZERO;
323   Ith(yB,2) = ZERO;
324   Ith(yB,3) = ZERO;
325 
326   /* Initialize qB */
327   qB = N_VNew_Serial(NP);
328   if (check_retval((void *)qB, "N_VNew", 0)) return(1);
329   Ith(qB,1) = ZERO;
330   Ith(qB,2) = ZERO;
331   Ith(qB,3) = ZERO;
332 
333   /* Set the scalar relative tolerance reltolB */
334   reltolB = RTOL;
335 
336   /* Set the scalar absolute tolerance abstolB */
337   abstolB = ATOLl;
338 
339   /* Set the scalar absolute tolerance abstolQB */
340   abstolQB = ATOLq;
341 
342   /* Create and allocate CVODES memory for backward run */
343   printf("Create and allocate CVODES memory for backward run\n");
344 
345   /* Call CVodeCreateB to specify the solution method for the backward
346      problem. */
347   retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB);
348   if (check_retval(&retval, "CVodeCreateB", 1)) return(1);
349 
350   /* Call CVodeInitB to allocate internal memory and initialize the
351      backward problem. */
352   retval = CVodeInitB(cvode_mem, indexB, fB, TB1, yB);
353   if (check_retval(&retval, "CVodeInitB", 1)) return(1);
354 
355   /* Set the scalar relative and absolute tolerances. */
356   retval = CVodeSStolerancesB(cvode_mem, indexB, reltolB, abstolB);
357   if (check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
358 
359   /* Attach the user data for backward problem. */
360   retval = CVodeSetUserDataB(cvode_mem, indexB, data);
361   if (check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
362 
363   /* Create dense SUNMatrix for use in linear solves */
364   AB = SUNDenseMatrix(NEQ, NEQ);
365   if (check_retval((void *)AB, "SUNDenseMatrix", 0)) return(1);
366 
367   /* Create dense SUNLinearSolver object */
368   LSB = SUNLinSol_Dense(yB, AB);
369   if (check_retval((void *)LSB, "SUNLinSol_Dense", 0)) return(1);
370 
371   /* Attach the matrix and linear solver */
372   retval = CVodeSetLinearSolverB(cvode_mem, indexB, LSB, AB);
373   if (check_retval(&retval, "CVodeSetLinearSolverB", 1)) return(1);
374 
375   /* Set the user-supplied Jacobian routine JacB */
376   retval = CVodeSetJacFnB(cvode_mem, indexB, JacB);
377   if (check_retval(&retval, "CVodeSetJacFnB", 1)) return(1);
378 
379   /* Call CVodeQuadInitB to allocate internal memory and initialize backward
380      quadrature integration. */
381   retval = CVodeQuadInitB(cvode_mem, indexB, fQB, qB);
382   if (check_retval(&retval, "CVodeQuadInitB", 1)) return(1);
383 
384   /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
385      are to be used in the step size control mechanism within CVODES. Call
386      CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
387      tolerances for the quadrature variables. */
388   retval = CVodeSetQuadErrConB(cvode_mem, indexB, SUNTRUE);
389   if (check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
390 
391   /* Call CVodeQuadSStolerancesB to specify the scalar relative and absolute tolerances
392      for the backward problem. */
393   retval = CVodeQuadSStolerancesB(cvode_mem, indexB, reltolB, abstolQB);
394   if (check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
395 
396   /* Backward Integration */
397 
398   PrintHead(TB1);
399 
400   /* First get results at t = TBout1 */
401 
402   /* Call CVodeB to integrate the backward ODE problem. */
403   retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
404   if (check_retval(&retval, "CVodeB", 1)) return(1);
405 
406   /* Call CVodeGetB to get yB of the backward ODE problem. */
407   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
408   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
409 
410   /* Call CVodeGetAdjY to get the interpolated value of the forward solution
411      y during a backward integration. */
412   retval = CVodeGetAdjY(cvode_mem, TBout1, y);
413   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
414 
415   PrintOutput1(time, TBout1, y, yB);
416 
417   /* Then at t = T0 */
418 
419   retval = CVodeB(cvode_mem, T0, CV_NORMAL);
420   if (check_retval(&retval, "CVodeB", 1)) return(1);
421   CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
422   printf("Done ( nst = %ld )\n", nstB);
423 
424   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
425   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
426 
427   /* Call CVodeGetQuadB to get the quadrature solution vector after a
428      successful return from CVodeB. */
429   retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
430   if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
431 
432   retval = CVodeGetAdjY(cvode_mem, T0, y);
433   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
434 
435   PrintOutput(time, y, yB, qB);
436 
437   /* Reinitialize backward phase (new tB0) */
438 
439   Ith(yB,1) = ZERO;
440   Ith(yB,2) = ZERO;
441   Ith(yB,3) = ZERO;
442 
443   Ith(qB,1) = ZERO;
444   Ith(qB,2) = ZERO;
445   Ith(qB,3) = ZERO;
446 
447   printf("Re-initialize CVODES memory for backward run\n");
448 
449   retval = CVodeReInitB(cvode_mem, indexB, TB2, yB);
450   if (check_retval(&retval, "CVodeReInitB", 1)) return(1);
451 
452   retval = CVodeQuadReInitB(cvode_mem, indexB, qB);
453   if (check_retval(&retval, "CVodeQuadReInitB", 1)) return(1);
454 
455   PrintHead(TB2);
456 
457   /* First get results at t = TBout1 */
458 
459   retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
460   if (check_retval(&retval, "CVodeB", 1)) return(1);
461 
462   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
463   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
464 
465   retval = CVodeGetAdjY(cvode_mem, TBout1, y);
466   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
467 
468   PrintOutput1(time, TBout1, y, yB);
469 
470   /* Then at t = T0 */
471 
472   retval = CVodeB(cvode_mem, T0, CV_NORMAL);
473   if (check_retval(&retval, "CVodeB", 1)) return(1);
474   CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
475   printf("Done ( nst = %ld )\n", nstB);
476 
477   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
478   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
479 
480   retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
481   if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
482 
483   retval = CVodeGetAdjY(cvode_mem, T0, y);
484   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
485 
486   PrintOutput(time, y, yB, qB);
487 
488   /* Free memory */
489   printf("Free memory\n\n");
490 
491   CVodeFree(&cvode_mem);
492   N_VDestroy(y);
493   N_VDestroy(q);
494   N_VDestroy(yB);
495   N_VDestroy(qB);
496   SUNLinSolFree(LS);
497   SUNMatDestroy(A);
498   SUNLinSolFree(LSB);
499   SUNMatDestroy(AB);
500 
501   if (ckpnt != NULL) free(ckpnt);
502   free(data);
503 
504   return(0);
505 
506 }
507 
508 /*
509  *--------------------------------------------------------------------
510  * FUNCTIONS CALLED BY CVODES
511  *--------------------------------------------------------------------
512  */
513 
514 /*
515  * f routine. Compute f(t,y).
516 */
517 
f(realtype t,N_Vector y,N_Vector ydot,void * user_data)518 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data)
519 {
520   realtype y1, y2, y3, yd1, yd3;
521   UserData data;
522   realtype p1, p2, p3;
523 
524   y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
525   data = (UserData) user_data;
526   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
527 
528   yd1 = Ith(ydot,1) = -p1*y1 + p2*y2*y3;
529   yd3 = Ith(ydot,3) = p3*y2*y2;
530         Ith(ydot,2) = -yd1 - yd3;
531 
532   return(0);
533 }
534 
535 /*
536  * Jacobian routine. Compute J(t,y).
537 */
538 
Jac(realtype t,N_Vector y,N_Vector fy,SUNMatrix J,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)539 static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
540                void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
541 {
542   realtype y2, y3;
543   UserData data;
544   realtype p1, p2, p3;
545 
546   y2 = Ith(y,2); y3 = Ith(y,3);
547   data = (UserData) user_data;
548   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
549 
550   IJth(J,1,1) = -p1;  IJth(J,1,2) = p2*y3;          IJth(J,1,3) = p2*y2;
551   IJth(J,2,1) =  p1;  IJth(J,2,2) = -p2*y3-2*p3*y2; IJth(J,2,3) = -p2*y2;
552   IJth(J,3,1) = ZERO; IJth(J,3,2) = 2*p3*y2;        IJth(J,3,3) = ZERO;
553 
554   return(0);
555 }
556 
557 /*
558  * fQ routine. Compute fQ(t,y).
559 */
560 
fQ(realtype t,N_Vector y,N_Vector qdot,void * user_data)561 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data)
562 {
563   Ith(qdot,1) = Ith(y,3);
564 
565   return(0);
566 }
567 
568 /*
569  * EwtSet function. Computes the error weights at the current solution.
570  */
571 
ewt(N_Vector y,N_Vector w,void * user_data)572 static int ewt(N_Vector y, N_Vector w, void *user_data)
573 {
574   int i;
575   realtype yy, ww, rtol, atol[3];
576 
577   rtol    = RTOL;
578   atol[0] = ATOL1;
579   atol[1] = ATOL2;
580   atol[2] = ATOL3;
581 
582   for (i=1; i<=3; i++) {
583     yy = Ith(y,i);
584     ww = rtol * SUNRabs(yy) + atol[i-1];
585     if (ww <= 0.0) return (-1);
586     Ith(w,i) = 1.0/ww;
587   }
588 
589   return(0);
590 }
591 
592 /*
593  * fB routine. Compute fB(t,y,yB).
594 */
595 
fB(realtype t,N_Vector y,N_Vector yB,N_Vector yBdot,void * user_dataB)596 static int fB(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
597 {
598   UserData data;
599   realtype y2, y3;
600   realtype p1, p2, p3;
601   realtype l1, l2, l3;
602   realtype l21, l32;
603 
604   data = (UserData) user_dataB;
605 
606   /* The p vector */
607   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
608 
609   /* The y vector */
610   y2 = Ith(y,2); y3 = Ith(y,3);
611 
612   /* The lambda vector */
613   l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
614 
615   /* Temporary variables */
616   l21 = l2-l1;
617   l32 = l3-l2;
618 
619   /* Load yBdot */
620   Ith(yBdot,1) = - p1*l21;
621   Ith(yBdot,2) = p2*y3*l21 - RCONST(2.0)*p3*y2*l32;
622   Ith(yBdot,3) = p2*y2*l21 - RCONST(1.0);
623 
624   return(0);
625 }
626 
627 /*
628  * JacB routine. Compute JB(t,y,yB).
629  */
630 
JacB(realtype t,N_Vector y,N_Vector yB,N_Vector fyB,SUNMatrix JB,void * user_dataB,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)631 static int JacB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix JB,
632                 void *user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B)
633 {
634   UserData data;
635   realtype y2, y3;
636   realtype p1, p2, p3;
637 
638   data = (UserData) user_dataB;
639 
640   /* The p vector */
641   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
642 
643   /* The y vector */
644   y2 = Ith(y,2); y3 = Ith(y,3);
645 
646   /* Load JB */
647   IJth(JB,1,1) = p1;     IJth(JB,1,2) = -p1;             IJth(JB,1,3) = ZERO;
648   IJth(JB,2,1) = -p2*y3; IJth(JB,2,2) = p2*y3+2.0*p3*y2; IJth(JB,2,3) = RCONST(-2.0)*p3*y2;
649   IJth(JB,3,1) = -p2*y2; IJth(JB,3,2) = p2*y2;           IJth(JB,3,3) = ZERO;
650 
651   return(0);
652 }
653 
654 /*
655  * fQB routine. Compute integrand for quadratures
656 */
657 
fQB(realtype t,N_Vector y,N_Vector yB,N_Vector qBdot,void * user_dataB)658 static int fQB(realtype t, N_Vector y, N_Vector yB,
659                N_Vector qBdot, void *user_dataB)
660 {
661   realtype y1, y2, y3;
662   realtype l1, l2, l3;
663   realtype l21, l32, y23;
664 
665   /* The y vector */
666   y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
667 
668   /* The lambda vector */
669   l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
670 
671   /* Temporary variables */
672   l21 = l2-l1;
673   l32 = l3-l2;
674   y23 = y2*y3;
675 
676   Ith(qBdot,1) = y1*l21;
677   Ith(qBdot,2) = - y23*l21;
678   Ith(qBdot,3) = y2*y2*l32;
679 
680   return(0);
681 }
682 
683 /*
684  *--------------------------------------------------------------------
685  * PRIVATE FUNCTIONS
686  *--------------------------------------------------------------------
687  */
688 
689 /*
690  * Print heading for backward integration
691  */
692 
PrintHead(realtype tB0)693 static void PrintHead(realtype tB0)
694 {
695 #if defined(SUNDIALS_EXTENDED_PRECISION)
696   printf("Backward integration from tB0 = %12.4Le\n\n",tB0);
697 #elif defined(SUNDIALS_DOUBLE_PRECISION)
698   printf("Backward integration from tB0 = %12.4e\n\n",tB0);
699 #else
700   printf("Backward integration from tB0 = %12.4e\n\n",tB0);
701 #endif
702 }
703 
704 /*
705  * Print intermediate results during backward integration
706  */
707 
PrintOutput1(realtype time,realtype t,N_Vector y,N_Vector yB)708 static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB)
709 {
710   printf("--------------------------------------------------------\n");
711 #if defined(SUNDIALS_EXTENDED_PRECISION)
712   printf("returned t: %12.4Le\n",time);
713   printf("tout:       %12.4Le\n",t);
714   printf("lambda(t):  %12.4Le %12.4Le %12.4Le\n",
715          Ith(yB,1), Ith(yB,2), Ith(yB,3));
716   printf("y(t):       %12.4Le %12.4Le %12.4Le\n",
717          Ith(y,1), Ith(y,2), Ith(y,3));
718 #elif defined(SUNDIALS_DOUBLE_PRECISION)
719   printf("returned t: %12.4e\n",time);
720   printf("tout:       %12.4e\n",t);
721   printf("lambda(t):  %12.4e %12.4e %12.4e\n",
722          Ith(yB,1), Ith(yB,2), Ith(yB,3));
723   printf("y(t):       %12.4e %12.4e %12.4e\n",
724          Ith(y,1), Ith(y,2), Ith(y,3));
725 #else
726   printf("returned t: %12.4e\n",time);
727   printf("tout:       %12.4e\n",t);
728   printf("lambda(t):  %12.4e %12.4e %12.4e\n",
729          Ith(yB,1), Ith(yB,2), Ith(yB,3));
730   printf("y(t)      : %12.4e %12.4e %12.4e\n",
731          Ith(y,1), Ith(y,2), Ith(y,3));
732 #endif
733   printf("--------------------------------------------------------\n\n");
734 }
735 
736 /*
737  * Print final results of backward integration
738  */
739 
PrintOutput(realtype tfinal,N_Vector y,N_Vector yB,N_Vector qB)740 static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB)
741 {
742   printf("--------------------------------------------------------\n");
743 #if defined(SUNDIALS_EXTENDED_PRECISION)
744   printf("returned t: %12.4Le\n",tfinal);
745   printf("lambda(t0): %12.4Le %12.4Le %12.4Le\n",
746          Ith(yB,1), Ith(yB,2), Ith(yB,3));
747   printf("y(t0):      %12.4Le %12.4Le %12.4Le\n",
748          Ith(y,1), Ith(y,2), Ith(y,3));
749   printf("dG/dp:      %12.4Le %12.4Le %12.4Le\n",
750          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
751 #elif defined(SUNDIALS_DOUBLE_PRECISION)
752   printf("returned t: %12.4e\n",tfinal);
753   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
754          Ith(yB,1), Ith(yB,2), Ith(yB,3));
755   printf("y(t0):      %12.4e %12.4e %12.4e\n",
756          Ith(y,1), Ith(y,2), Ith(y,3));
757   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
758          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
759 #else
760   printf("returned t: %12.4e\n",tfinal);
761   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
762          Ith(yB,1), Ith(yB,2), Ith(yB,3));
763   printf("y(t0)     : %12.4e %12.4e %12.4e\n",
764          Ith(y,1), Ith(y,2), Ith(y,3));
765   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
766          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
767 #endif
768   printf("--------------------------------------------------------\n\n");
769 }
770 
/*
 * Check function return value.
 *    opt == 0 : returnvalue is a pointer returned by a SUNDIALS allocator;
 *               failure means it is NULL
 *    opt == 1 : returnvalue points to an int SUNDIALS return code;
 *               failure means the code is negative
 *    opt == 2 : returnvalue is a pointer returned by a non-SUNDIALS
 *               allocator (e.g. malloc); failure means it is NULL
 * On failure a diagnostic is written to stderr and 1 is returned;
 * otherwise (including for any other value of opt) 0 is returned.
 */

static int check_retval(void *returnvalue, const char *funcname, int opt)
{
  switch (opt) {
  case 0:
    if (returnvalue != NULL) break;
    fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);

  case 1: {
    int code = *(int *) returnvalue;
    if (code >= 0) break;
    fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
            funcname, code);
    return(1);
  }

  case 2:
    if (returnvalue != NULL) break;
    fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);

  default:
    break;
  }

  return(0);
}
807