1 /* -----------------------------------------------------------------
2  * Programmer(s): Ting Yan @ SMU
3  *      Based on idasRoberts_ASAi_dns.c and modified to use SuperLUMT
4  * -----------------------------------------------------------------
5  * SUNDIALS Copyright Start
6  * Copyright (c) 2002-2021, Lawrence Livermore National Security
7  * and Southern Methodist University.
8  * All rights reserved.
9  *
10  * See the top-level LICENSE and NOTICE files for details.
11  *
12  * SPDX-License-Identifier: BSD-3-Clause
13  * SUNDIALS Copyright End
14  * -----------------------------------------------------------------
15  * Adjoint sensitivity example problem.
16  *
17  * This simple example problem for IDAS, due to Robertson,
18  * is from chemical kinetics, and consists of the following three
19  * equations:
20  *
21  *      dy1/dt + p1*y1 - p2*y2*y3            = 0
22  *      dy2/dt - p1*y1 + p2*y2*y3 + p3*y2**2 = 0
23  *                 y1  +  y2  +  y3  -  1    = 0
24  *
 * on the interval from t = 0.0 to t = 4.e10, with initial
 * conditions: y1 = 1, y2 = y3 = 0. The reaction rates are: p1=0.04,
 * p2=1e4, and p3=3e7.
28  *
29  * It uses a scalar relative tolerance and a vector absolute
30  * tolerance.
31  *
32  * IDAS can also compute sensitivities with respect to
33  * the problem parameters p1, p2, and p3 of the following quantity:
34  *   G = int_t0^t1 g(t,p,y) dt
35  * where
36  *   g(t,p,y) = y3
37  *
38  * The gradient dG/dp is obtained as:
39  *   dG/dp = int_t0^t1 (g_p - lambda^T F_p ) dt -
40  *           lambda^T*F_y'*y_p | _t0^t1
41  *         = int_t0^t1 (lambda^T*F_p) dt
 * where lambda is the solution of the adjoint system:
43  *   d(lambda^T * F_y' )/dt -lambda^T F_y = -g_y
44  *
45  * During the backward integration, IDAS also evaluates G as
46  *   G = - phi(t0)
47  * where
48  *   d(phi)/dt = g(t,y,p)
49  *   phi(t1) = 0
50  * -----------------------------------------------------------------*/
51 
52 #include <stdio.h>
53 #include <stdlib.h>
54 
55 #include <idas/idas.h>                     /* prototypes for IDA fcts., consts.    */
56 #include <nvector/nvector_serial.h>        /* access to serial N_Vector            */
57 #include <sunmatrix/sunmatrix_sparse.h>    /* access to sparse SUNMatrix           */
58 #include <sunlinsol/sunlinsol_superlumt.h> /* access to SuperLUMT linear solver    */
59 #include <sundials/sundials_types.h>       /* defs. of realtype, sunindextype      */
60 #include <sundials/sundials_math.h>        /* defs. of SUNRabs, SUNRexp, etc.      */
61 
62 /* Accessor macros */
63 
/* Ith(v,i) references component i of the serial vector v using 1-based
 * indexing (i = 1..NEQ). The index argument is parenthesized so that any
 * expression (e.g. a conditional or a shift) can be passed safely. */
#define Ith(v,i)    NV_Ith_S(v,(i)-1)
65 
/* Problem Constants */

#define NEQ      3             /* number of equations (length of yy, yp, yB, ypB) */

#define RTOL     RCONST(1e-06) /* scalar relative tolerance (forward, quadrature
                                  and backward problems)                        */

#define ATOL1    RCONST(1e-08) /* vector absolute tolerance components, used   */
#define ATOL2    RCONST(1e-12) /* per component by the ewt() error-weight      */
#define ATOL3    RCONST(1e-08) /* function                                     */

#define ATOLA    RCONST(1e-08) /* absolute tolerance for adjoint vars. */
#define ATOLQ    RCONST(1e-06) /* absolute tolerance for quadratures   */

#define T0       RCONST(0.0)   /* initial time                         */
#define TOUT     RCONST(4e10)  /* final time                           */

#define TB1      RCONST(50.0)  /* intermediate time: forward states are saved
                                  here and the second backward run starts here */
#define TB2      TOUT          /* starting point of the first backward run     */

#define T1B      RCONST(49.0)  /* time argument passed to IDACalcICB           */

#define STEPS    100           /* number of steps between check points */

#define NP       3             /* number of problem parameters (length of qB)  */

#define ONE     RCONST(1.0)    /* realtype constants                   */
#define ZERO    RCONST(0.0)
93 
94 
/* Type : UserData
 * Pointer to the parameter block handed to IDAS as user data for both the
 * forward and the backward problem. */

typedef struct {
  realtype p[3];   /* reaction rates: p[0] = p1, p[1] = p2, p[2] = p3 */
} *UserData;
100 
/* Prototypes of user-supplied functions */

/* Forward problem: DAE residual and its sparse Jacobian. */
static int res(realtype t, N_Vector yy, N_Vector yp,
               N_Vector resval, void *user_data);
static int Jac(realtype t, realtype cj,
               N_Vector yy, N_Vector yp, N_Vector resvec,
               SUNMatrix JJ, void *user_data,
               N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);

/* Forward problem: quadrature integrand and error-weight function. */
static int rhsQ(realtype t, N_Vector yy, N_Vector yp, N_Vector qdot, void *user_data);
static int ewt(N_Vector y, N_Vector w, void *user_data);

/* Backward (adjoint) problem: residual. */
static int resB(realtype tt,
                N_Vector yy, N_Vector yp,
                N_Vector yyB, N_Vector ypB, N_Vector rrB,
                void *user_dataB);

/* Backward (adjoint) problem: sparse Jacobian. */
static int JacB(realtype tt, realtype cjB,
                N_Vector yy, N_Vector yp,
                N_Vector yyB, N_Vector ypB, N_Vector rrB,
                SUNMatrix JB, void *user_data,
                N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);


/* Backward (adjoint) problem: quadrature integrand. */
static int rhsQB(realtype tt,
                 N_Vector yy, N_Vector yp,
                 N_Vector yyB, N_Vector ypB,
                 N_Vector rrQB, void *user_dataB);

/* Prototypes of private functions */
/* Print the results of a backward run. */
static void PrintOutput(realtype tfinal, N_Vector yB, N_Vector ypB, N_Vector qB);
/* Report a failed allocation (opt 0 or 2) or a negative return flag (opt 1). */
static int check_retval(void *returnvalue, char *funcname, int opt);
133 
134 /*
135  *--------------------------------------------------------------------
136  * MAIN PROGRAM
137  *--------------------------------------------------------------------
138  */
139 
main(int argc,char * argv[])140 int main(int argc, char *argv[])
141 {
142   UserData data;
143 
144   void *ida_mem;
145   SUNMatrix A, AB;
146   SUNLinearSolver LS, LSB;
147 
148   realtype reltolQ, abstolQ;
149   N_Vector yy, yp, q;
150   N_Vector yyTB1, ypTB1;
151   N_Vector id;
152 
153   int steps;
154 
155   int indexB;
156 
157   realtype reltolB, abstolB, abstolQB;
158   N_Vector yB, ypB, qB;
159   realtype time;
160   int retval, nthreads, nnz, ncheck;
161 
162   IDAadjCheckPointRec *ckpnt;
163 
164   long int nst, nstB;
165 
166   data = NULL;
167   ckpnt = NULL;
168   ida_mem = NULL;
169   yy = yp = yB = qB = NULL;
170   A = AB = NULL;
171   LS = LSB = NULL;
172 
173   /* Print problem description */
174   printf("\nAdjoint Sensitivity Example for Chemical Kinetics\n");
175   printf("-------------------------------------------------\n\n");
176   printf("DAE: dy1/dt + p1*y1 - p2*y2*y3 = 0\n");
177   printf("     dy2/dt - p1*y1 + p2*y2*y3 + p3*(y2)^2 = 0\n");
178   printf("               y1  +  y2  +  y3 = 0\n\n");
179   printf("Find dG/dp for\n");
180   printf("     G = int_t0^tB0 g(t,p,y) dt\n");
181   printf("     g(t,p,y) = y3\n\n\n");
182 
183   /* User data structure */
184   data = (UserData) malloc(sizeof *data);
185   if (check_retval((void *)data, "malloc", 2)) return(1);
186   data->p[0] = RCONST(0.04);
187   data->p[1] = RCONST(1.0e4);
188   data->p[2] = RCONST(3.0e7);
189 
190   /* Initialize y */
191   yy = N_VNew_Serial(NEQ);
192   if (check_retval((void *)yy, "N_VNew_Serial", 0)) return(1);
193   Ith(yy,1) = ONE;
194   Ith(yy,2) = ZERO;
195   Ith(yy,3) = ZERO;
196 
197   /* Initialize yprime */
198   yp = N_VNew_Serial(NEQ);
199   if (check_retval((void *)yp, "N_VNew_Serial", 0)) return(1);
200   Ith(yp,1) = RCONST(-0.04);
201   Ith(yp,2) = RCONST( 0.04);
202   Ith(yp,3) = ZERO;
203 
204   /* Initialize q */
205   q = N_VNew_Serial(1);
206   if (check_retval((void *)q, "N_VNew_Serial", 0)) return(1);
207   Ith(q,1) = ZERO;
208 
209   /* Set the scalar realtive and absolute tolerances reltolQ and abstolQ */
210   reltolQ = RTOL;
211   abstolQ = ATOLQ;
212 
213   /* Create and allocate IDAS memory for forward run */
214   printf("Create and allocate IDAS memory for forward runs\n");
215 
216   ida_mem = IDACreate();
217   if (check_retval((void *)ida_mem, "IDACreate", 0)) return(1);
218 
219   retval = IDAInit(ida_mem, res, T0, yy, yp);
220   if (check_retval(&retval, "IDAInit", 1)) return(1);
221 
222   retval = IDAWFtolerances(ida_mem, ewt);
223   if (check_retval(&retval, "IDAWFtolerances", 1)) return(1);
224 
225   retval = IDASetUserData(ida_mem, data);
226   if (check_retval(&retval, "IDASetUserData", 1)) return(1);
227 
228   /* Create sparse SUNMatrix for use in linear solves */
229   nnz = NEQ * NEQ;
230   A = SUNSparseMatrix(NEQ, NEQ, nnz, CSC_MAT);
231   if(check_retval((void *)A, "SUNSparseMatrix", 0)) return(1);
232 
233   /* Create SuperLUMT SUNLinearSolver object (one thread) */
234   nthreads = 1;
235   LS = SUNLinSol_SuperLUMT(yy, A, nthreads);
236   if(check_retval((void *)LS, "SUNLinSol_SuperLUMT", 0)) return(1);
237 
238   /* Attach the matrix and linear solver */
239   retval = IDASetLinearSolver(ida_mem, LS, A);
240   if(check_retval(&retval, "IDASetLinearSolver", 1)) return(1);
241 
242   /* Set the user-supplied Jacobian routine */
243   retval = IDASetJacFn(ida_mem, Jac);
244   if(check_retval(&retval, "IDASetJacFn", 1)) return(1);
245 
246   /* Setup quadrature integration */
247   retval = IDAQuadInit(ida_mem, rhsQ, q);
248   if (check_retval(&retval, "IDAQuadInit", 1)) return(1);
249 
250   retval = IDAQuadSStolerances(ida_mem, reltolQ, abstolQ);
251   if (check_retval(&retval, "IDAQuadSStolerances", 1)) return(1);
252 
253   retval = IDASetQuadErrCon(ida_mem, SUNTRUE);
254   if (check_retval(&retval, "IDASetQuadErrCon", 1)) return(1);
255 
256   /* Call IDASetMaxNumSteps to set the maximum number of steps the
257    * solver will take in an attempt to reach the next output time
258    * during forward integration. */
259   retval = IDASetMaxNumSteps(ida_mem, 2500);
260   if (check_retval(&retval, "IDASetMaxNumSteps", 1)) return(1);
261 
262   /* Allocate global memory */
263 
264   steps = STEPS;
265   retval = IDAAdjInit(ida_mem, steps, IDA_HERMITE);
266   /*retval = IDAAdjInit(ida_mem, steps, IDA_POLYNOMIAL);*/
267   if (check_retval(&retval, "IDAAdjInit", 1)) return(1);
268 
269   /* Perform forward run */
270   printf("Forward integration ... ");
271 
272   /* Integrate till TB1 and get the solution (y, y') at that time. */
273   retval = IDASolveF(ida_mem, TB1, &time, yy, yp, IDA_NORMAL, &ncheck);
274   if (check_retval(&retval, "IDASolveF", 1)) return(1);
275 
276   yyTB1 = N_VClone(yy);
277   ypTB1 = N_VClone(yp);
278   /* Save the states at t=TB1. */
279   N_VScale(ONE, yy, yyTB1);
280   N_VScale(ONE, yp, ypTB1);
281 
282   /* Continue integrating till TOUT is reached. */
283   retval = IDASolveF(ida_mem, TOUT, &time, yy, yp, IDA_NORMAL, &ncheck);
284   if (check_retval(&retval, "IDASolveF", 1)) return(1);
285 
286   retval = IDAGetNumSteps(ida_mem, &nst);
287   if (check_retval(&retval, "IDAGetNumSteps", 1)) return(1);
288 
289   printf("done ( nst = %ld )\n",nst);
290 
291   retval = IDAGetQuad(ida_mem, &time, q);
292   if (check_retval(&retval, "IDAGetQuad", 1)) return(1);
293 
294   printf("--------------------------------------------------------\n");
295 #if defined(SUNDIALS_EXTENDED_PRECISION)
296   printf("G:          %12.4Le \n",Ith(q,1));
297 #elif defined(SUNDIALS_DOUBLE_PRECISION)
298   printf("G:          %12.4e \n",Ith(q,1));
299 #else
300   printf("G:          %12.4e \n",Ith(q,1));
301 #endif
302   printf("--------------------------------------------------------\n\n");
303 
304   /* Test check point linked list
305      (uncomment next block to print check point information) */
306 
307   /*
308   {
309     int i;
310 
311     printf("\nList of Check Points (ncheck = %d)\n\n", ncheck);
312     ckpnt = (IDAadjCheckPointRec *) malloc ( (ncheck+1)*sizeof(IDAadjCheckPointRec));
313     IDAGetAdjCheckPointsInfo(ida_mem, ckpnt);
314     for (i=0;i<=ncheck;i++) {
315       printf("Address:       %p\n",ckpnt[i].my_addr);
316       printf("Next:          %p\n",ckpnt[i].next_addr);
317       printf("Time interval: %le  %le\n",ckpnt[i].t0, ckpnt[i].t1);
318       printf("Step number:   %ld\n",ckpnt[i].nstep);
319       printf("Order:         %d\n",ckpnt[i].order);
320       printf("Step size:     %le\n",ckpnt[i].step);
321       printf("\n");
322     }
323 
324   }
325   */
326 
327 
328   /* Create BACKWARD problem. */
329 
330   /* Allocate yB (i.e. lambda_0). */
331   yB = N_VNew_Serial(NEQ);
332   if (check_retval((void *)yB, "N_VNew_Serial", 0)) return(1);
333 
334   /* Consistently initialize yB. */
335   Ith(yB,1) = ZERO;
336   Ith(yB,2) = ZERO;
337   Ith(yB,3) = ONE;
338 
339 
340   /* Allocate ypB (i.e. lambda'_0). */
341   ypB = N_VNew_Serial(NEQ);
342   if (check_retval((void *)ypB, "N_VNew_Serial", 0)) return(1);
343 
344   /* Consistently initialize ypB. */
345   Ith(ypB,1) = ONE;
346   Ith(ypB,2) = ONE;
347   Ith(ypB,3) = ZERO;
348 
349 
350   /* Set the scalar relative tolerance reltolB */
351   reltolB = RTOL;
352 
353   /* Set the scalar absolute tolerance abstolB */
354   abstolB = ATOLA;
355 
356   /* Set the scalar absolute tolerance abstolQB */
357   abstolQB = ATOLQ;
358 
359   /* Create and allocate IDAS memory for backward run */
360   printf("Create and allocate IDAS memory for backward run\n");
361 
362   retval = IDACreateB(ida_mem, &indexB);
363   if (check_retval(&retval, "IDACreateB", 1)) return(1);
364 
365   retval = IDAInitB(ida_mem, indexB, resB, TB2, yB, ypB);
366   if (check_retval(&retval, "IDAInitB", 1)) return(1);
367 
368   retval = IDASStolerancesB(ida_mem, indexB, reltolB, abstolB);
369   if (check_retval(&retval, "IDASStolerancesB", 1)) return(1);
370 
371   retval = IDASetUserDataB(ida_mem, indexB, data);
372   if (check_retval(&retval, "IDASetUserDataB", 1)) return(1);
373 
374   retval = IDASetMaxNumStepsB(ida_mem, indexB, 1000);
375   if (check_retval(&retval, "IDASetMaxNumStepsB", 1)) return(1);
376 
377   /* Create sparse SUNMatrix for use in linear solves */
378   AB = SUNSparseMatrix(NEQ, NEQ, nnz, CSC_MAT);
379   if(check_retval((void *)AB, "SUNSparseMatrix", 0)) return(1);
380 
381   /* Create SuperLUMT SUNLinearSolver object (one thread) */
382   LSB = SUNLinSol_SuperLUMT(yB, AB, nthreads);
383   if(check_retval((void *)LSB, "SUNLinSol_SuperLUMT", 0)) return(1);
384 
385   /* Attach the matrix and linear solver */
386   retval = IDASetLinearSolverB(ida_mem, indexB, LSB, AB);
387   if(check_retval(&retval, "IDASetLinearSolverB", 1)) return(1);
388 
389   /* Set the user-supplied Jacobian routine */
390   retval = IDASetJacFnB(ida_mem, indexB, JacB);
391   if(check_retval(&retval, "IDASetJacFnB", 1)) return(1);
392 
393   /* Quadrature for backward problem. */
394 
395   /* Initialize qB */
396   qB = N_VNew_Serial(NP);
397   if (check_retval((void *)qB, "N_VNew", 0)) return(1);
398   Ith(qB,1) = ZERO;
399   Ith(qB,2) = ZERO;
400   Ith(qB,3) = ZERO;
401 
402   retval = IDAQuadInitB(ida_mem, indexB, rhsQB, qB);
403   if (check_retval(&retval, "IDAQuadInitB", 1)) return(1);
404 
405   retval = IDAQuadSStolerancesB(ida_mem, indexB, reltolB, abstolQB);
406   if (check_retval(&retval, "IDAQuadSStolerancesB", 1)) return(1);
407 
408   /* Include quadratures in error control. */
409   retval = IDASetQuadErrConB(ida_mem, indexB, SUNTRUE);
410   if (check_retval(&retval, "IDASetQuadErrConB", 1)) return(1);
411 
412 
413   /* Backward Integration */
414   printf("Backward integration ... ");
415 
416   retval = IDASolveB(ida_mem, T0, IDA_NORMAL);
417   if (check_retval(&retval, "IDASolveB", 1)) return(1);
418 
419   IDAGetNumSteps(IDAGetAdjIDABmem(ida_mem, indexB), &nstB);
420   printf("done ( nst = %ld )\n", nstB);
421 
422   retval = IDAGetB(ida_mem, indexB, &time, yB, ypB);
423   if (check_retval(&retval, "IDAGetB", 1)) return(1);
424 
425   retval = IDAGetQuadB(ida_mem, indexB, &time, qB);
426   if (check_retval(&retval, "IDAGetB", 1)) return(1);
427 
428   PrintOutput(TB2, yB, ypB, qB);
429 
430 
431   /* Reinitialize backward phase and start from a different time (TB1). */
432   printf("Re-initialize IDAS memory for backward run\n");
433 
434   /* Both algebraic part from y and the entire y' are computed by IDACalcIC. */
435   Ith(yB,1) = ZERO;
436   Ith(yB,2) = ZERO;
437   Ith(yB,3) = RCONST(0.50); /* not consistent */
438 
439   /* Rough guess for ypB. */
440   Ith(ypB,1) = RCONST(0.80);
441   Ith(ypB,2) = RCONST(0.75);
442   Ith(ypB,3) = ZERO;
443 
444   /* Initialize qB */
445   Ith(qB,1) = ZERO;
446   Ith(qB,2) = ZERO;
447   Ith(qB,3) = ZERO;
448 
449   retval = IDAReInitB(ida_mem, indexB, TB1, yB, ypB);
450   if (check_retval(&retval, "IDAReInitB", 1)) return(1);
451 
452   /* Also reinitialize quadratures. */
453   retval = IDAQuadReInitB(ida_mem, indexB, qB);
454   if (check_retval(&retval, "IDAQuadReInitB", 1)) return(1);
455 
456   /* Use IDACalcICB to compute consistent initial conditions
457      for this backward problem. */
458 
459   id = N_VNew_Serial(NEQ);
460   Ith(id,1) = 1.0;
461   Ith(id,2) = 1.0;
462   Ith(id,3) = 0.0;
463 
464   /* Specify which variables are differential (1) and which algebraic (0).*/
465   retval = IDASetIdB(ida_mem, indexB, id);
466   if (check_retval(&retval, "IDASetId", 1)) return(1);
467 
468   retval = IDACalcICB(ida_mem, indexB, T1B, yyTB1, ypTB1);
469   if (check_retval(&retval, "IDACalcICB", 1)) return(1);
470 
471   /* Get the consistent IC found by IDAS. */
472   retval = IDAGetConsistentICB(ida_mem, indexB, yB, ypB);
473   if (check_retval(&retval, "IDAGetConsistentICB", 1)) return(1);
474 
475   printf("Backward integration ... ");
476 
477   retval = IDASolveB(ida_mem, T0, IDA_NORMAL);
478   if (check_retval(&retval, "IDASolveB", 1)) return(1);
479 
480   IDAGetNumSteps(IDAGetAdjIDABmem(ida_mem, indexB), &nstB);
481   printf("done ( nst = %ld )\n", nstB);
482 
483   retval = IDAGetB(ida_mem, indexB, &time, yB, ypB);
484   if (check_retval(&retval, "IDAGetB", 1)) return(1);
485 
486   retval = IDAGetQuadB(ida_mem, indexB, &time, qB);
487   if (check_retval(&retval, "IDAGetQuadB", 1)) return(1);
488 
489   PrintOutput(TB1, yB, ypB, qB);
490 
491   /* Free any memory used.*/
492 
493   printf("Free memory\n\n");
494 
495   IDAFree(&ida_mem);
496   SUNLinSolFree(LS);
497   SUNMatDestroy(A);
498   SUNLinSolFree(LSB);
499   SUNMatDestroy(AB);
500   N_VDestroy(yy);
501   N_VDestroy(yp);
502   N_VDestroy(q);
503   N_VDestroy(yB);
504   N_VDestroy(ypB);
505   N_VDestroy(qB);
506   N_VDestroy(id);
507   N_VDestroy(yyTB1);
508   N_VDestroy(ypTB1);
509 
510   if (ckpnt != NULL) free(ckpnt);
511   free(data);
512 
513   return(0);
514 
515 }
516 
517 /*
518  *--------------------------------------------------------------------
519  * FUNCTIONS CALLED BY IDAS
520  *--------------------------------------------------------------------
521  */
522 
/*
 * Residual routine. Compute F(t,y,y').
*/
526 
/*
 * Forward DAE residual. Writes into resval:
 *   resval[0] = y1' + p1*y1 - p2*y2*y3
 *   resval[1] = y2' - p1*y1 + p2*y2*y3 + p3*y2^2
 *   resval[2] = y1 + y2 + y3 - 1
 * The parameters p1..p3 are read from user_data (a UserData). The time t is
 * not used. Always returns 0.
 */
static int res(realtype t, N_Vector yy, N_Vector yp, N_Vector resval, void *user_data)
{
  UserData params = (UserData) user_data;
  realtype *F = N_VGetArrayPointer(resval);

  const realtype k1 = params->p[0];
  const realtype k2 = params->p[1];
  const realtype k3 = params->p[2];

  const realtype y1 = Ith(yy,1);
  const realtype y2 = Ith(yy,2);
  const realtype y3 = Ith(yy,3);

  const realtype dy1 = Ith(yp,1);
  const realtype dy2 = Ith(yp,2);

  /* Term shared (with opposite sign) by the two differential equations. */
  const realtype flux = k1*y1-k2*y2*y3;

  (void) t;

  F[0] = flux + dy1;
  F[1] = -flux + k3*y2*y2 + dy2;
  F[2] = y1+y2+y3-1;

  return(0);
}
547 
548 /*
549  * Jacobian routine. Compute J(t,y).
550 */
551 
/*
 * Forward Jacobian dF/dy + cj*dF/dy', stored in JJ as a full 3x3 pattern
 * in compressed-sparse-column form (column pointers 0,3,6,9; rows 0,1,2 in
 * every column). Only yy, cj, JJ and user_data are read; yp, resvec and the
 * tmp vectors are unused. Always returns 0.
 */
static int Jac(realtype t, realtype cj,
               N_Vector yy, N_Vector yp, N_Vector resvec,
               SUNMatrix JJ, void *user_data,
               N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
{
  sunindextype *colptrs = SUNSparseMatrix_IndexPointers(JJ);
  sunindextype *rowvals = SUNSparseMatrix_IndexValues(JJ);
  realtype *nz = SUNSparseMatrix_Data(JJ);

  realtype *y = N_VGetArrayPointer(yy);
  UserData params = (UserData) user_data;

  const realtype p1 = params->p[0];
  const realtype p2 = params->p[1];
  const realtype p3 = params->p[2];
  const realtype y2 = y[1];
  const realtype y3 = y[2];

  realtype entries[9];
  sunindextype k;

  /* column 0: derivatives with respect to y1 */
  entries[0] = p1+cj;
  entries[1] = -p1;
  entries[2] = ONE;

  /* column 1: derivatives with respect to y2 */
  entries[3] = -p2*y3;
  entries[4] = p2*y3+2*p3*y2+cj;
  entries[5] = ONE;

  /* column 2: derivatives with respect to y3 */
  entries[6] = -p2*y2;
  entries[7] = p2*y2;
  entries[8] = ONE;

  SUNMatZero(JJ);

  /* Every column holds three entries, in row order 0,1,2. */
  for (k = 0; k <= 3; k++) colptrs[k] = 3*k;

  for (k = 0; k < 9; k++) {
    rowvals[k] = k % 3;
    nz[k] = entries[k];
  }

  return(0);
}
603 
604 /*
605  * rhsQ routine. Compute fQ(t,y).
606 */
607 
/*
 * Forward quadrature integrand: qdot[0] = y3, i.e. g(t,p,y) = y3.
 * t, yp and user_data are unused. Always returns 0.
 */
static int rhsQ(realtype t, N_Vector yy, N_Vector yp, N_Vector qdot, void *user_data)
{
  const realtype y3 = Ith(yy,3);

  (void) t;
  (void) yp;
  (void) user_data;

  Ith(qdot,1) = y3;

  return(0);
}
613 
614 /*
615  * EwtSet function. Computes the error weights at the current solution.
616  */
617 
/*
 * Error-weight function: for each of the three components,
 *   w[k] = 1 / (RTOL*|y[k]| + ATOLk).
 * Returns -1 as soon as a denominator is not positive (components already
 * processed stay written), 0 otherwise. user_data is unused.
 */
static int ewt(N_Vector y, N_Vector w, void *user_data)
{
  realtype abs_tol[3];
  realtype denom;
  int k;

  (void) user_data;

  abs_tol[0] = ATOL1;
  abs_tol[1] = ATOL2;
  abs_tol[2] = ATOL3;

  for (k = 0; k < 3; k++) {
    denom = RTOL * SUNRabs(Ith(y,k+1)) + abs_tol[k];
    if (denom <= 0.0) return (-1);
    Ith(w,k+1) = 1.0/denom;
  }

  return(0);
}
637 
638 
639 /*
640  * resB routine.
641 */
642 
/*
 * Residual of the backward (adjoint) problem. Reads the forward solution
 * components y2, y3 from yy, the adjoint variables lambda from yyB, their
 * derivatives from ypB and the parameters from user_dataB, then fills rrB.
 * tt and yp are unused. Always returns 0.
 */
static int resB(realtype tt,
                N_Vector yy, N_Vector yp,
                N_Vector yyB, N_Vector ypB, N_Vector rrB,
                void *user_dataB)
{
  UserData params = (UserData) user_dataB;

  /* Problem parameters */
  const realtype p1 = params->p[0];
  const realtype p2 = params->p[1];
  const realtype p3 = params->p[2];

  /* Forward solution */
  const realtype y2 = Ith(yy,2);
  const realtype y3 = Ith(yy,3);

  /* Adjoint variables and their time derivatives */
  const realtype lam1 = Ith(yyB,1);
  const realtype lam2 = Ith(yyB,2);
  const realtype lam3 = Ith(yyB,3);
  const realtype dlam1 = Ith(ypB,1);
  const realtype dlam2 = Ith(ypB,2);

  /* lambda2 - lambda1 appears in all three residual components */
  const realtype lam21 = lam2-lam1;

  (void) tt;
  (void) yp;

  Ith(rrB,1) = dlam1 + p1*lam21 - lam3;
  Ith(rrB,2) = dlam2 - p2*y3*lam21 - RCONST(2.0)*p3*y2*lam2-lam3;
  Ith(rrB,3) = - p2*y2*lam21 -lam3 + RCONST(1.0);

  return(0);
}
679 
680 /*Jacobian for backward problem. */
/*
 * Jacobian of the backward residual with respect to the adjoint variables,
 * plus cjB on the two differential diagonal entries. Stored in JB as a full
 * 3x3 compressed-sparse-column matrix (column pointers 0,3,6,9; rows 0,1,2
 * in every column). The entries depend on the FORWARD solution yy and on the
 * parameters in user_data; yyB, ypB, rrB and the tmp vectors are unused.
 * Always returns 0.
 */
static int JacB(realtype tt, realtype cjB,
                N_Vector yy, N_Vector yp,
                N_Vector yyB, N_Vector ypB, N_Vector rrB,
                SUNMatrix JB, void *user_data,
                N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B)
{
  sunindextype *colptrsB = SUNSparseMatrix_IndexPointers(JB);
  sunindextype *rowvalsB = SUNSparseMatrix_IndexValues(JB);
  realtype *nzB = SUNSparseMatrix_Data(JB);

  /* Forward solution values (not the adjoint ones) */
  realtype *y = N_VGetArrayPointer(yy);
  UserData params = (UserData) user_data;

  const realtype p1 = params->p[0];
  const realtype p2 = params->p[1];
  const realtype p3 = params->p[2];
  const realtype y2 = y[1];
  const realtype y3 = y[2];

  realtype entries[9];
  sunindextype k;

  /* column 0: derivatives with respect to lambda1 */
  entries[0] = -p1+cjB;
  entries[1] = p2*y3;
  entries[2] = p2*y2;

  /* column 1: derivatives with respect to lambda2 */
  entries[3] = p1;
  entries[4] = -(p2*y3+RCONST(2.0)*p3*y2)+cjB;
  entries[5] = -p2*y2;

  /* column 2: derivatives with respect to lambda3 */
  entries[6] = -ONE;
  entries[7] = -ONE;
  entries[8] = -ONE;

  SUNMatZero(JB);

  /* Every column holds three entries, in row order 0,1,2. */
  for (k = 0; k <= 3; k++) colptrsB[k] = 3*k;

  for (k = 0; k < 9; k++) {
    rowvalsB[k] = k % 3;
    nzB[k] = entries[k];
  }

  return(0);
}
733 
/*
 * Backward quadrature integrand. With l21 = lambda2 - lambda1:
 *   rrQB[0] =  y1*l21
 *   rrQB[1] = -y3*y2*l21
 *   rrQB[2] = -y2*y2*lambda2
 * tt, yp, ypB and user_dataB are unused. Always returns 0.
 */
static int rhsQB(realtype tt,
                 N_Vector yy, N_Vector yp,
                 N_Vector yyB, N_Vector ypB,
                 N_Vector rrQB, void *user_dataB)
{
  /* Forward solution */
  const realtype y1 = Ith(yy,1);
  const realtype y2 = Ith(yy,2);
  const realtype y3 = Ith(yy,3);

  /* Adjoint variables */
  const realtype lam1 = Ith(yyB,1);
  const realtype lam2 = Ith(yyB,2);

  const realtype lam21 = lam2-lam1;

  (void) tt;
  (void) yp;
  (void) ypB;
  (void) user_dataB;

  Ith(rrQB,1) = y1*lam21;
  Ith(rrQB,2) = -y3*y2*lam21;
  Ith(rrQB,3) = -y2*y2*lam2;

  return(0);
}
758 
759 
760 /*
761  *--------------------------------------------------------------------
762  * PRIVATE FUNCTIONS
763  *--------------------------------------------------------------------
764  */
765 
766 /*
767  * Print results after backward integration
768  */
769 
/*
 * Print the outcome of a backward run between two separator lines: the
 * backward start time tfinal (labelled tB0), the negated components of qB
 * (labelled dG/dp) and the components of yB (labelled lambda(t0)).
 * ypB is accepted but not printed.
 */
static void PrintOutput(realtype tfinal, N_Vector yB, N_Vector ypB, N_Vector qB)
{
  /* The gradient is reported as minus the backward quadrature. */
  const realtype dGdp1 = -Ith(qB,1);
  const realtype dGdp2 = -Ith(qB,2);
  const realtype dGdp3 = -Ith(qB,3);

  const realtype lam1 = Ith(yB,1);
  const realtype lam2 = Ith(yB,2);
  const realtype lam3 = Ith(yB,3);

  (void) ypB;

  printf("--------------------------------------------------------\n");
#if defined(SUNDIALS_EXTENDED_PRECISION)
  printf("tB0:        %12.4Le\n",tfinal);
  printf("dG/dp:      %12.4Le %12.4Le %12.4Le\n",
         dGdp1, dGdp2, dGdp3);
  printf("lambda(t0): %12.4Le %12.4Le %12.4Le\n",
         lam1, lam2, lam3);
#else
  /* Double and single precision use the same conversions. */
  printf("tB0:        %12.4e\n",tfinal);
  printf("dG/dp:      %12.4e %12.4e %12.4e\n",
         dGdp1, dGdp2, dGdp3);
  printf("lambda(t0): %12.4e %12.4e %12.4e\n",
         lam1, lam2, lam3);
#endif
  printf("--------------------------------------------------------\n\n");
}
794 
795 /*
796  * Check function return value.
797  *    opt == 0 means SUNDIALS function allocates memory so check if
798  *             returned NULL pointer
799  *    opt == 1 means SUNDIALS function returns an integer value so check if
800  *             retval < 0
801  *    opt == 2 means function allocates memory so check if returned
802  *             NULL pointer
803  */
804 
/*
 * Check a function's result and report failures on stderr.
 *
 *   returnvalue  opt 0 / 2: the pointer returned by an allocating function
 *                opt 1:     address of the int flag returned by the function
 *   funcname     name printed in the error message
 *   opt          0: SUNDIALS allocation, fails when returnvalue is NULL
 *                1: integer flag, fails when the flag is negative; a NULL
 *                   returnvalue is reported as a failure, not dereferenced
 *                2: generic allocation, fails when returnvalue is NULL
 *
 * Returns 1 when a failure was detected and reported, 0 otherwise
 * (including for any other value of opt).
 */
static int check_retval(void *returnvalue, char *funcname, int opt)
{
  switch (opt) {

  case 0:
    /* SUNDIALS function returned NULL pointer - no memory allocated */
    if (returnvalue == NULL) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      return(1);
    }
    break;

  case 1:
    /* Guard against a missing flag before reading it */
    if (returnvalue == NULL) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - no return value to check\n\n",
              funcname);
      return(1);
    }
    /* Negative flags signal an error */
    if (*(int *) returnvalue < 0) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
              funcname, *(int *) returnvalue);
      return(1);
    }
    break;

  case 2:
    /* Function returned NULL pointer - no memory allocated */
    if (returnvalue == NULL) {
      fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      return(1);
    }
    break;

  default:
    break;
  }

  return(0);
}
831