1 /* -----------------------------------------------------------------
2  * Programmer(s): Radu Serban and Cosmin Petra @ LLNL
3  * -----------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2021, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------
14  * Adjoint sensitivity example problem.
15  *
16  * This simple example problem for IDAS, due to Robertson,
17  * is from chemical kinetics, and consists of the following three
18  * equations:
19  *
20  *      dy1/dt + p1*y1 - p2*y2*y3            = 0
21  *      dy2/dt - p1*y1 + p2*y2*y3 + p3*y2**2 = 0
22  *                 y1  +  y2  +  y3  -  1    = 0
23  *
 * on the interval from t = 0.0 to t = 4.e10, with initial
 * conditions: y1 = 1, y2 = y3 = 0. The reaction rates are: p1=0.04,
 * p2=1e4, and p3=3e7.
27  *
28  * It uses a scalar relative tolerance and a vector absolute
29  * tolerance.
30  *
31  * IDAS can also compute sensitivities with respect to
32  * the problem parameters p1, p2, and p3 of the following quantity:
33  *   G = int_t0^t1 g(t,p,y) dt
34  * where
35  *   g(t,p,y) = y3
36  *
37  * The gradient dG/dp is obtained as:
38  *   dG/dp = int_t0^t1 (g_p - lambda^T F_p ) dt -
39  *           lambda^T*F_y'*y_p | _t0^t1
40  *         = int_t0^t1 (lambda^T*F_p) dt
 * where lambda is the solution of the adjoint system:
42  *   d(lambda^T * F_y' )/dt -lambda^T F_y = -g_y
43  *
44  * During the backward integration, IDAS also evaluates G as
45  *   G = - phi(t0)
46  * where
47  *   d(phi)/dt = g(t,y,p)
48  *   phi(t1) = 0
49  * -----------------------------------------------------------------*/
50 
51 #include <stdio.h>
52 #include <stdlib.h>
53 
54 #include <idas/idas.h>                 /* prototypes for IDA fcts., consts.    */
55 #include <nvector/nvector_serial.h>    /* access to serial N_Vector            */
56 #include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix            */
57 #include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver      */
58 #include <sundials/sundials_types.h>   /* defs. of realtype, sunindextype      */
59 #include <sundials/sundials_math.h>    /* defs. of SUNRabs, SUNRexp, etc.      */
60 
/* Accessor macros
 *
 * Both macros take 1-based indices and forward 0-based ones to the
 * underlying serial-vector / dense-matrix accessors. The index arguments
 * are parenthesised so that an expression argument (for example
 * `c ? 2 : 3`) is shifted as a whole instead of being re-associated by
 * operator precedence. */

#define Ith(v,i)    NV_Ith_S(v,(i)-1)             /* i-th vector component i= 1..NEQ */
#define IJth(A,i,j) SM_ELEMENT_D(A,(i)-1,(j)-1)   /* (i,j)-th matrix component i,j = 1..NEQ */
65 
/* Problem Constants */

#define NEQ      3             /* number of equations                  */

#define RTOL     RCONST(1e-06) /* scalar relative tolerance            */

/* Per-component absolute tolerances; consumed by ewt() for the forward run. */
#define ATOL1    RCONST(1e-08) /* vector absolute tolerance components */
#define ATOL2    RCONST(1e-12)
#define ATOL3    RCONST(1e-08)

#define ATOLA    RCONST(1e-08) /* absolute tolerance for adjoint vars. */
#define ATOLQ    RCONST(1e-06) /* absolute tolerance for quadratures   */

#define T0       RCONST(0.0)   /* initial time                         */
#define TOUT     RCONST(4e10)  /* final time                           */

/* main() runs the backward problem twice: first starting from TB2 (the
 * final time), then re-initialized to start from TB1. */
#define TB1      RCONST(50.0)  /* start of the second backward run     */
#define TB2      TOUT          /* start of the first backward run      */

#define T1B      RCONST(49.0)  /* time argument passed to IDACalcICB   */

#define STEPS    100           /* number of steps between check points */

#define NP       3             /* number of problem parameters         */

#define ONE     RCONST(1.0)
#define ZERO    RCONST(0.0)
93 
94 
95 /* Type : UserData */
96 
97 typedef struct {
98   realtype p[3];
99 } *UserData;
100 
/* Prototypes of user-supplied functions */

/* Forward problem: DAE residual and its Jacobian. */
static int res(realtype t, N_Vector yy, N_Vector yp,
               N_Vector resval, void *user_data);
static int Jac(realtype t, realtype cj,
               N_Vector yy, N_Vector yp, N_Vector resvec,
               SUNMatrix J, void *user_data,
               N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);

/* Forward problem: quadrature integrand and error-weight function. */
static int rhsQ(realtype t, N_Vector yy, N_Vector yp, N_Vector qdot, void *user_data);
static int ewt(N_Vector y, N_Vector w, void *user_data);

/* Backward (adjoint) problem: residual. */
static int resB(realtype tt,
                N_Vector yy, N_Vector yp,
                N_Vector yyB, N_Vector ypB, N_Vector rrB,
                void *user_dataB);

/* Backward (adjoint) problem: Jacobian. */
static int JacB(realtype tt, realtype cjB,
                N_Vector yy, N_Vector yp,
                N_Vector yyB, N_Vector ypB, N_Vector rrB,
                SUNMatrix JB, void *user_data,
                N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);


/* Backward (adjoint) problem: quadrature integrand. */
static int rhsQB(realtype tt,
                 N_Vector yy, N_Vector yp,
                 N_Vector yyB, N_Vector ypB,
                 N_Vector rrQB, void *user_dataB);

/* Prototypes of private functions */
static void PrintOutput(realtype tfinal, N_Vector yB, N_Vector ypB, N_Vector qB);
static int check_retval(void *returnvalue, const char *funcname, int opt);
133 
134 /*
135  *--------------------------------------------------------------------
136  * MAIN PROGRAM
137  *--------------------------------------------------------------------
138  */
139 
main(int argc,char * argv[])140 int main(int argc, char *argv[])
141 {
142   UserData data;
143 
144   void *ida_mem;
145   SUNMatrix A, AB;
146   SUNLinearSolver LS, LSB;
147 
148   realtype reltolQ, abstolQ;
149   N_Vector yy, yp, q;
150   N_Vector yyTB1, ypTB1;
151   N_Vector id;
152 
153   int steps;
154 
155   int indexB;
156 
157   realtype reltolB, abstolB, abstolQB;
158   N_Vector yB, ypB, qB;
159   realtype time;
160   int retval, ncheck;
161 
162   IDAadjCheckPointRec *ckpnt;
163 
164   long int nst, nstB;
165 
166   data = NULL;
167   ckpnt = NULL;
168   ida_mem = NULL;
169   A = AB = NULL;
170   LS = LSB = NULL;
171   yy = yp = yB = qB = NULL;
172 
173   /* Print problem description */
174   printf("\nAdjoint Sensitivity Example for Chemical Kinetics\n");
175   printf("-------------------------------------------------\n\n");
176   printf("DAE: dy1/dt + p1*y1 - p2*y2*y3 = 0\n");
177   printf("     dy2/dt - p1*y1 + p2*y2*y3 + p3*(y2)^2 = 0\n");
178   printf("               y1  +  y2  +  y3 = 0\n\n");
179   printf("Find dG/dp for\n");
180   printf("     G = int_t0^tB0 g(t,p,y) dt\n");
181   printf("     g(t,p,y) = y3\n\n\n");
182 
183   /* User data structure */
184   data = (UserData) malloc(sizeof *data);
185   if (check_retval((void *)data, "malloc", 2)) return(1);
186   data->p[0] = RCONST(0.04);
187   data->p[1] = RCONST(1.0e4);
188   data->p[2] = RCONST(3.0e7);
189 
190   /* Initialize y */
191   yy = N_VNew_Serial(NEQ);
192   if (check_retval((void *)yy, "N_VNew_Serial", 0)) return(1);
193   Ith(yy,1) = ONE;
194   Ith(yy,2) = ZERO;
195   Ith(yy,3) = ZERO;
196 
197   /* Initialize yprime */
198   yp = N_VNew_Serial(NEQ);
199   if (check_retval((void *)yp, "N_VNew_Serial", 0)) return(1);
200   Ith(yp,1) = RCONST(-0.04);
201   Ith(yp,2) = RCONST( 0.04);
202   Ith(yp,3) = ZERO;
203 
204   /* Initialize q */
205   q = N_VNew_Serial(1);
206   if (check_retval((void *)q, "N_VNew_Serial", 0)) return(1);
207   Ith(q,1) = ZERO;
208 
209   /* Set the scalar realtive and absolute tolerances reltolQ and abstolQ */
210   reltolQ = RTOL;
211   abstolQ = ATOLQ;
212 
213   /* Create and allocate IDAS memory for forward run */
214   printf("Create and allocate IDAS memory for forward runs\n");
215 
216   ida_mem = IDACreate();
217   if (check_retval((void *)ida_mem, "IDACreate", 0)) return(1);
218 
219   retval = IDAInit(ida_mem, res, T0, yy, yp);
220   if (check_retval(&retval, "IDAInit", 1)) return(1);
221 
222   retval = IDAWFtolerances(ida_mem, ewt);
223   if (check_retval(&retval, "IDAWFtolerances", 1)) return(1);
224 
225   retval = IDASetUserData(ida_mem, data);
226   if (check_retval(&retval, "IDASetUserData", 1)) return(1);
227 
228   /* Create dense SUNMatrix for use in linear solves */
229   A = SUNDenseMatrix(NEQ, NEQ);
230   if(check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
231 
232   /* Create dense SUNLinearSolver object */
233   LS = SUNLinSol_Dense(yy, A);
234   if(check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
235 
236   /* Attach the matrix and linear solver */
237   retval = IDASetLinearSolver(ida_mem, LS, A);
238   if(check_retval(&retval, "IDASetLinearSolver", 1)) return(1);
239 
240   /* Set the user-supplied Jacobian routine */
241   retval = IDASetJacFn(ida_mem, Jac);
242   if(check_retval(&retval, "IDASetJacFn", 1)) return(1);
243 
244   /* Setup quadrature integration */
245   retval = IDAQuadInit(ida_mem, rhsQ, q);
246   if (check_retval(&retval, "IDAQuadInit", 1)) return(1);
247 
248   retval = IDAQuadSStolerances(ida_mem, reltolQ, abstolQ);
249   if (check_retval(&retval, "IDAQuadSStolerances", 1)) return(1);
250 
251   retval = IDASetQuadErrCon(ida_mem, SUNTRUE);
252   if (check_retval(&retval, "IDASetQuadErrCon", 1)) return(1);
253 
254   /* Call IDASetMaxNumSteps to set the maximum number of steps the
255    * solver will take in an attempt to reach the next output time
256    * during forward integration. */
257   retval = IDASetMaxNumSteps(ida_mem, 2500);
258   if (check_retval(&retval, "IDASetMaxNumSteps", 1)) return(1);
259 
260   /* Allocate global memory */
261 
262   steps = STEPS;
263   retval = IDAAdjInit(ida_mem, steps, IDA_HERMITE);
264   /*retval = IDAAdjInit(ida_mem, steps, IDA_POLYNOMIAL);*/
265   if (check_retval(&retval, "IDAAdjInit", 1)) return(1);
266 
267   /* Perform forward run */
268   printf("Forward integration ... ");
269 
270   /* Integrate till TB1 and get the solution (y, y') at that time. */
271   retval = IDASolveF(ida_mem, TB1, &time, yy, yp, IDA_NORMAL, &ncheck);
272   if (check_retval(&retval, "IDASolveF", 1)) return(1);
273 
274   yyTB1 = N_VClone(yy);
275   ypTB1 = N_VClone(yp);
276   /* Save the states at t=TB1. */
277   N_VScale(ONE, yy, yyTB1);
278   N_VScale(ONE, yp, ypTB1);
279 
280   /* Continue integrating till TOUT is reached. */
281   retval = IDASolveF(ida_mem, TOUT, &time, yy, yp, IDA_NORMAL, &ncheck);
282   if (check_retval(&retval, "IDASolveF", 1)) return(1);
283 
284   retval = IDAGetNumSteps(ida_mem, &nst);
285   if (check_retval(&retval, "IDAGetNumSteps", 1)) return(1);
286 
287   printf("done ( nst = %ld )\n",nst);
288 
289   retval = IDAGetQuad(ida_mem, &time, q);
290   if (check_retval(&retval, "IDAGetQuad", 1)) return(1);
291 
292   printf("--------------------------------------------------------\n");
293 #if defined(SUNDIALS_EXTENDED_PRECISION)
294   printf("G:          %12.4Le \n",Ith(q,1));
295 #elif defined(SUNDIALS_DOUBLE_PRECISION)
296   printf("G:          %12.4e \n",Ith(q,1));
297 #else
298   printf("G:          %12.4e \n",Ith(q,1));
299 #endif
300   printf("--------------------------------------------------------\n\n");
301 
302   /* Test check point linked list
303      (uncomment next block to print check point information) */
304 
305   /*
306   {
307     int i;
308 
309     printf("\nList of Check Points (ncheck = %d)\n\n", ncheck);
310     ckpnt = (IDAadjCheckPointRec *) malloc ( (ncheck+1)*sizeof(IDAadjCheckPointRec));
311     IDAGetAdjCheckPointsInfo(ida_mem, ckpnt);
312     for (i=0;i<=ncheck;i++) {
313       printf("Address:       %p\n",ckpnt[i].my_addr);
314       printf("Next:          %p\n",ckpnt[i].next_addr);
315       printf("Time interval: %le  %le\n",ckpnt[i].t0, ckpnt[i].t1);
316       printf("Step number:   %ld\n",ckpnt[i].nstep);
317       printf("Order:         %d\n",ckpnt[i].order);
318       printf("Step size:     %le\n",ckpnt[i].step);
319       printf("\n");
320     }
321 
322   }
323   */
324 
325 
326   /* Create BACKWARD problem. */
327 
328   /* Allocate yB (i.e. lambda_0). */
329   yB = N_VNew_Serial(NEQ);
330   if (check_retval((void *)yB, "N_VNew_Serial", 0)) return(1);
331 
332   /* Consistently initialize yB. */
333   Ith(yB,1) = ZERO;
334   Ith(yB,2) = ZERO;
335   Ith(yB,3) = ONE;
336 
337 
338   /* Allocate ypB (i.e. lambda'_0). */
339   ypB = N_VNew_Serial(NEQ);
340   if (check_retval((void *)ypB, "N_VNew_Serial", 0)) return(1);
341 
342   /* Consistently initialize ypB. */
343   Ith(ypB,1) = ONE;
344   Ith(ypB,2) = ONE;
345   Ith(ypB,3) = ZERO;
346 
347 
348   /* Set the scalar relative tolerance reltolB */
349   reltolB = RTOL;
350 
351   /* Set the scalar absolute tolerance abstolB */
352   abstolB = ATOLA;
353 
354   /* Set the scalar absolute tolerance abstolQB */
355   abstolQB = ATOLQ;
356 
357   /* Create and allocate IDAS memory for backward run */
358   printf("Create and allocate IDAS memory for backward run\n");
359 
360   retval = IDACreateB(ida_mem, &indexB);
361   if (check_retval(&retval, "IDACreateB", 1)) return(1);
362 
363   retval = IDAInitB(ida_mem, indexB, resB, TB2, yB, ypB);
364   if (check_retval(&retval, "IDAInitB", 1)) return(1);
365 
366   retval = IDASStolerancesB(ida_mem, indexB, reltolB, abstolB);
367   if (check_retval(&retval, "IDASStolerancesB", 1)) return(1);
368 
369   retval = IDASetUserDataB(ida_mem, indexB, data);
370   if (check_retval(&retval, "IDASetUserDataB", 1)) return(1);
371 
372   retval = IDASetMaxNumStepsB(ida_mem, indexB, 1000);
373   if (check_retval(&retval, "IDASetMaxNumStepsB", 1)) return(1);
374 
375   /* Create dense SUNMatrix for use in linear solves */
376   AB = SUNDenseMatrix(NEQ, NEQ);
377   if(check_retval((void *)AB, "SUNDenseMatrix", 0)) return(1);
378 
379   /* Create dense SUNLinearSolver object */
380   LSB = SUNLinSol_Dense(yB, AB);
381   if(check_retval((void *)LSB, "SUNLinSol_Dense", 0)) return(1);
382 
383   /* Attach the matrix and linear solver */
384   retval = IDASetLinearSolverB(ida_mem, indexB, LSB, AB);
385   if(check_retval(&retval, "IDASetLinearSolverB", 1)) return(1);
386 
387   /* Set the user-supplied Jacobian routine */
388   retval = IDASetJacFnB(ida_mem, indexB, JacB);
389   if(check_retval(&retval, "IDASetJacFnB", 1)) return(1);
390 
391   /* Quadrature for backward problem. */
392 
393   /* Initialize qB */
394   qB = N_VNew_Serial(NP);
395   if (check_retval((void *)qB, "N_VNew", 0)) return(1);
396   Ith(qB,1) = ZERO;
397   Ith(qB,2) = ZERO;
398   Ith(qB,3) = ZERO;
399 
400   retval = IDAQuadInitB(ida_mem, indexB, rhsQB, qB);
401   if (check_retval(&retval, "IDAQuadInitB", 1)) return(1);
402 
403   retval = IDAQuadSStolerancesB(ida_mem, indexB, reltolB, abstolQB);
404   if (check_retval(&retval, "IDAQuadSStolerancesB", 1)) return(1);
405 
406   /* Include quadratures in error control. */
407   retval = IDASetQuadErrConB(ida_mem, indexB, SUNTRUE);
408   if (check_retval(&retval, "IDASetQuadErrConB", 1)) return(1);
409 
410 
411   /* Backward Integration */
412   printf("Backward integration ... ");
413 
414   retval = IDASolveB(ida_mem, T0, IDA_NORMAL);
415   if (check_retval(&retval, "IDASolveB", 1)) return(1);
416 
417   IDAGetNumSteps(IDAGetAdjIDABmem(ida_mem, indexB), &nstB);
418   printf("done ( nst = %ld )\n", nstB);
419 
420   retval = IDAGetB(ida_mem, indexB, &time, yB, ypB);
421   if (check_retval(&retval, "IDAGetB", 1)) return(1);
422 
423   retval = IDAGetQuadB(ida_mem, indexB, &time, qB);
424   if (check_retval(&retval, "IDAGetB", 1)) return(1);
425 
426   PrintOutput(TB2, yB, ypB, qB);
427 
428 
429   /* Reinitialize backward phase and start from a different time (TB1). */
430   printf("Re-initialize IDAS memory for backward run\n");
431 
432   /* Both algebraic part from y and the entire y' are computed by IDACalcIC. */
433   Ith(yB,1) = ZERO;
434   Ith(yB,2) = ZERO;
435   Ith(yB,3) = RCONST(0.50); /* not consistent */
436 
437   /* Rough guess for ypB. */
438   Ith(ypB,1) = RCONST(0.80);
439   Ith(ypB,2) = RCONST(0.75);
440   Ith(ypB,3) = ZERO;
441 
442   /* Initialize qB */
443   Ith(qB,1) = ZERO;
444   Ith(qB,2) = ZERO;
445   Ith(qB,3) = ZERO;
446 
447   retval = IDAReInitB(ida_mem, indexB, TB1, yB, ypB);
448   if (check_retval(&retval, "IDAReInitB", 1)) return(1);
449 
450   /* Also reinitialize quadratures. */
451   retval = IDAQuadReInitB(ida_mem, indexB, qB);
452   if (check_retval(&retval, "IDAQuadReInitB", 1)) return(1);
453 
454   /* Use IDACalcICB to compute consistent initial conditions
455      for this backward problem. */
456 
457   id = N_VNew_Serial(NEQ);
458   Ith(id,1) = 1.0;
459   Ith(id,2) = 1.0;
460   Ith(id,3) = 0.0;
461 
462   /* Specify which variables are differential (1) and which algebraic (0).*/
463   retval = IDASetIdB(ida_mem, indexB, id);
464   if (check_retval(&retval, "IDASetId", 1)) return(1);
465 
466   retval = IDACalcICB(ida_mem, indexB, T1B, yyTB1, ypTB1);
467   if (check_retval(&retval, "IDACalcICB", 1)) return(1);
468 
469   /* Get the consistent IC found by IDAS. */
470   retval = IDAGetConsistentICB(ida_mem, indexB, yB, ypB);
471   if (check_retval(&retval, "IDAGetConsistentICB", 1)) return(1);
472 
473   printf("Backward integration ... ");
474 
475   retval = IDASolveB(ida_mem, T0, IDA_NORMAL);
476   if (check_retval(&retval, "IDASolveB", 1)) return(1);
477 
478   IDAGetNumSteps(IDAGetAdjIDABmem(ida_mem, indexB), &nstB);
479   printf("done ( nst = %ld )\n", nstB);
480 
481   retval = IDAGetB(ida_mem, indexB, &time, yB, ypB);
482   if (check_retval(&retval, "IDAGetB", 1)) return(1);
483 
484   retval = IDAGetQuadB(ida_mem, indexB, &time, qB);
485   if (check_retval(&retval, "IDAGetQuadB", 1)) return(1);
486 
487   PrintOutput(TB1, yB, ypB, qB);
488 
489   /* Free any memory used.*/
490 
491   printf("Free memory\n\n");
492 
493   IDAFree(&ida_mem);
494   SUNLinSolFree(LS);
495   SUNMatDestroy(A);
496   SUNLinSolFree(LSB);
497   SUNMatDestroy(AB);
498   N_VDestroy(yy);
499   N_VDestroy(yp);
500   N_VDestroy(q);
501   N_VDestroy(yB);
502   N_VDestroy(ypB);
503   N_VDestroy(qB);
504   N_VDestroy(id);
505   N_VDestroy(yyTB1);
506   N_VDestroy(ypTB1);
507 
508   if (ckpnt != NULL) free(ckpnt);
509   free(data);
510 
511   return(0);
512 
513 }
514 
515 /*
516  *--------------------------------------------------------------------
517  * FUNCTIONS CALLED BY IDAS
518  *--------------------------------------------------------------------
519  */
520 
521 /*
522  * f routine. Compute f(t,y).
523 */
524 
res(realtype t,N_Vector yy,N_Vector yp,N_Vector resval,void * user_data)525 static int res(realtype t, N_Vector yy, N_Vector yp, N_Vector resval, void *user_data)
526 {
527   realtype y1, y2, y3, yp1, yp2, *rval;
528   UserData data;
529   realtype p1, p2, p3;
530 
531   y1  = Ith(yy,1); y2  = Ith(yy,2); y3  = Ith(yy,3);
532   yp1 = Ith(yp,1); yp2 = Ith(yp,2);
533   rval = N_VGetArrayPointer(resval);
534 
535   data = (UserData) user_data;
536   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
537 
538   rval[0] = p1*y1-p2*y2*y3;
539   rval[1] = -rval[0] + p3*y2*y2 + yp2;
540   rval[0]+= yp1;
541   rval[2] = y1+y2+y3-1;
542 
543   return(0);
544 }
545 
546 /*
547  * Jacobian routine. Compute J(t,y).
548 */
549 
Jac(realtype t,realtype cj,N_Vector yy,N_Vector yp,N_Vector resvec,SUNMatrix J,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)550 static int Jac(realtype t, realtype cj,
551                N_Vector yy, N_Vector yp, N_Vector resvec,
552                SUNMatrix J, void *user_data,
553                N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
554 {
555   realtype y2, y3;
556   UserData data;
557   realtype p1, p2, p3;
558 
559   y2 = Ith(yy,2); y3 = Ith(yy,3);
560 
561   data = (UserData) user_data;
562   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
563 
564   IJth(J,1,1) = p1+cj;
565   IJth(J,2,1) = -p1;
566   IJth(J,3,1) = ONE;
567 
568   IJth(J,1,2) = -p2*y3;
569   IJth(J,2,2) = p2*y3+2*p3*y2+cj;
570   IJth(J,3,2) = ONE;
571 
572   IJth(J,1,3) = -p2*y2;
573   IJth(J,2,3) = p2*y2;
574   IJth(J,3,3) = ONE;
575 
576   return(0);
577 }
578 
579 /*
580  * rhsQ routine. Compute fQ(t,y).
581 */
582 
rhsQ(realtype t,N_Vector yy,N_Vector yp,N_Vector qdot,void * user_data)583 static int rhsQ(realtype t, N_Vector yy, N_Vector yp, N_Vector qdot, void *user_data)
584 {
585   Ith(qdot,1) = Ith(yy,3);
586   return(0);
587 }
588 
589 /*
590  * EwtSet function. Computes the error weights at the current solution.
591  */
592 
ewt(N_Vector y,N_Vector w,void * user_data)593 static int ewt(N_Vector y, N_Vector w, void *user_data)
594 {
595   int i;
596   realtype yy, ww, rtol, atol[3];
597 
598   rtol    = RTOL;
599   atol[0] = ATOL1;
600   atol[1] = ATOL2;
601   atol[2] = ATOL3;
602 
603   for (i=1; i<=3; i++) {
604     yy = Ith(y,i);
605     ww = rtol * SUNRabs(yy) + atol[i-1];
606     if (ww <= 0.0) return (-1);
607     Ith(w,i) = 1.0/ww;
608   }
609 
610   return(0);
611 }
612 
613 
614 /*
615  * resB routine.
616 */
617 
resB(realtype tt,N_Vector yy,N_Vector yp,N_Vector yyB,N_Vector ypB,N_Vector rrB,void * user_dataB)618 static int resB(realtype tt,
619                  N_Vector yy, N_Vector yp,
620                  N_Vector yyB, N_Vector ypB, N_Vector rrB,
621                  void *user_dataB)
622 {
623   UserData data;
624   realtype y2, y3;
625   realtype p1, p2, p3;
626   realtype l1, l2, l3;
627   realtype lp1, lp2;
628   realtype l21;
629 
630   data = (UserData) user_dataB;
631 
632   /* The p vector */
633   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
634 
635   /* The y  vector */
636   y2 = Ith(yy,2); y3 = Ith(yy,3);
637 
638   /* The lambda vector */
639   l1 = Ith(yyB,1); l2 = Ith(yyB,2); l3 = Ith(yyB,3);
640 
641   /* The lambda dot vector */
642   lp1 = Ith(ypB,1); lp2 = Ith(ypB,2);
643 
644   /* Temporary variables */
645   l21 = l2-l1;
646 
647   /* Load residual. */
648   Ith(rrB,1) = lp1 + p1*l21 - l3;
649   Ith(rrB,2) = lp2 - p2*y3*l21 - RCONST(2.0)*p3*y2*l2-l3;
650   Ith(rrB,3) = - p2*y2*l21 -l3 + RCONST(1.0);
651 
652   return(0);
653 }
654 
655 /*Jacobian for backward problem. */
JacB(realtype tt,realtype cj,N_Vector yy,N_Vector yp,N_Vector yyB,N_Vector ypB,N_Vector rrB,SUNMatrix JB,void * user_data,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)656 static int JacB(realtype tt, realtype cj,
657                 N_Vector yy, N_Vector yp,
658                 N_Vector yyB, N_Vector ypB, N_Vector rrB,
659                 SUNMatrix JB, void *user_data,
660                 N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B)
661 {
662   realtype y2, y3;
663   UserData data;
664   realtype p1, p2, p3;
665 
666   y2 = Ith(yy,2); y3 = Ith(yy,3);
667 
668   data = (UserData) user_data;
669   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
670 
671   IJth(JB,1,1) = -p1+cj;
672   IJth(JB,1,2) = p1;
673   IJth(JB,1,3) = -ONE;
674 
675   IJth(JB,2,1) = p2*y3;
676   IJth(JB,2,2) = -(p2*y3+RCONST(2.0)*p3*y2)+cj;
677   IJth(JB,2,3) = -ONE;
678 
679   IJth(JB,3,1) = p2*y2;
680   IJth(JB,3,2) = -p2*y2;
681   IJth(JB,3,3) = -ONE;
682 
683 
684   return(0);
685 }
686 
/*
 * Quadrature integrand of the backward problem. Stores one component per
 * problem parameter in rrQB, built from the forward solution yy and the
 * first two adjoint variables in yyB. PrintOutput() reports the negated
 * integrals as dG/dp. Always returns 0.
 */
static int rhsQB(realtype tt,
                 N_Vector yy, N_Vector yp,
                 N_Vector yyB, N_Vector ypB,
                 N_Vector rrQB, void *user_dataB)
{
  /* Forward states */
  realtype u1 = Ith(yy,1);
  realtype u2 = Ith(yy,2);
  realtype u3 = Ith(yy,3);

  /* Adjoint variables */
  realtype lam1 = Ith(yyB,1);
  realtype lam2 = Ith(yyB,2);

  realtype dlam = lam2-lam1;

  Ith(rrQB,1) = u1*dlam;
  Ith(rrQB,2) = -u3*u2*dlam;
  Ith(rrQB,3) = -u2*u2*lam2;

  return(0);
}
711 
712 
713 /*
714  *--------------------------------------------------------------------
715  * PRIVATE FUNCTIONS
716  *--------------------------------------------------------------------
717  */
718 
719 /*
720  * Print results after backward integration
721  */
722 
PrintOutput(realtype tfinal,N_Vector yB,N_Vector ypB,N_Vector qB)723 static void PrintOutput(realtype tfinal, N_Vector yB, N_Vector ypB, N_Vector qB)
724 {
725   printf("--------------------------------------------------------\n");
726 #if defined(SUNDIALS_EXTENDED_PRECISION)
727   printf("tB0:        %12.4Le\n",tfinal);
728   printf("dG/dp:      %12.4Le %12.4Le %12.4Le\n",
729          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
730   printf("lambda(t0): %12.4Le %12.4Le %12.4Le\n",
731          Ith(yB,1), Ith(yB,2), Ith(yB,3));
732 #elif defined(SUNDIALS_DOUBLE_PRECISION)
733   printf("tB0:        %12.4e\n",tfinal);
734   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
735          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
736   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
737          Ith(yB,1), Ith(yB,2), Ith(yB,3));
738 #else
739   printf("tB0:        %12.4e\n",tfinal);
740   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
741          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
742   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
743          Ith(yB,1), Ith(yB,2), Ith(yB,3));
744 #endif
745   printf("--------------------------------------------------------\n\n");
746 }
747 
/*
 * Check function return value.
 *    opt == 0 means SUNDIALS function allocates memory so check if
 *             returned NULL pointer
 *    opt == 1 means SUNDIALS function returns an integer value so check if
 *             retval < 0; returnvalue must point to that int. A NULL
 *             pointer is reported as a failure instead of being
 *             dereferenced.
 *    opt == 2 means function allocates memory so check if returned
 *             NULL pointer
 *
 * Returns 1 after printing a diagnostic to stderr when a failure is
 * detected, and 0 otherwise (including for any other value of opt).
 */

static int check_retval(void *returnvalue, const char *funcname, int opt)
{
  const int *retval;

  switch (opt) {

  /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
  case 0:
    if (returnvalue == NULL) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      return(1);
    }
    break;

  /* Check if retval < 0 */
  case 1:
    if (returnvalue == NULL) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - no return value to check\n\n",
              funcname);
      return(1);
    }
    retval = (const int *) returnvalue;
    if (*retval < 0) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
              funcname, *retval);
      return(1);
    }
    break;

  /* Check if function returned NULL pointer - no memory allocated */
  case 2:
    if (returnvalue == NULL) {
      fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      return(1);
    }
    break;

  default:
    break;
  }

  return(0);
}
784