1 /* -----------------------------------------------------------------
2  * Programmer(s): Ting Yan @ SMU
3  *      Based on cvsRoberts_ASAi_dns.c and modified to use SuperLUMT
4  * -----------------------------------------------------------------
5  * SUNDIALS Copyright Start
6  * Copyright (c) 2002-2020, Lawrence Livermore National Security
7  * and Southern Methodist University.
8  * All rights reserved.
9  *
10  * See the top-level LICENSE and NOTICE files for details.
11  *
12  * SPDX-License-Identifier: BSD-3-Clause
13  * SUNDIALS Copyright End
14  *-----------------------------------------------------------------
15  * Adjoint sensitivity example problem.
16  * The following is a simple example problem, with the coding
17  * needed for its solution by CVODES. The problem is from chemical
18  * kinetics, and consists of the following three rate equations.
19  *    dy1/dt = -p1*y1 + p2*y2*y3
20  *    dy2/dt =  p1*y1 - p2*y2*y3 - p3*(y2)^2
21  *    dy3/dt =  p3*(y2)^2
 * on the interval from t = 0.0 to t = 4.e7, with initial
23  * conditions: y1 = 1.0, y2 = y3 = 0. The reaction rates are:
24  * p1=0.04, p2=1e4, and p3=3e7. The problem is stiff.
25  * This program solves the problem with the BDF method, Newton
26  * iteration with the SUPERLU_MT linear solver, and a user-supplied
27  * Jacobian routine.
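 * It uses a scalar relative tolerance and a vector absolute
 * tolerance, both applied through a user-supplied error weight function.
 * The forward problem is integrated to t = 4.e7; the adjoint problem is
 * then integrated backward twice, from tB0 = 4.e7 and from tB0 = 50.0,
 * with output at t = 40.0 and t = 0.0.
 * The number of steps is printed after each forward and backward run.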
32  *
 * In addition, CVODES computes the sensitivities with respect to
 * the problem parameters p1, p2, and p3 of the following quantity:
 *   G = int_t0^t1 g(t,p,y) dt
 * where
 *   g(t,p,y) = y3
 *
 * The gradient dG/dp is obtained as:
 *   dG/dp = int_t0^t1 (g_p + lambda^T f_p ) dt + lambda^T(t0)*y0_p
 *         = - xi^T(t0) + lambda^T(t0)*y0_p
 * where lambda and xi are solutions of:
 *   d(lambda)/dt = - (f_y)^T * lambda - (g_y)^T
 *   lambda(t1) = 0
 * and
 *   d(xi)/dt = (f_p)^T * lambda + (g_p)^T
 *   xi(t1) = 0
 * Here g_p = 0 and y0_p = 0, so dG/dp = - xi^T(t0).
 *
 * G itself is evaluated during the forward integration as the
 * quadrature q(t), where
 *   dq/dt = g(t,p,y)
 *   q(t0) = 0
 * so that G = q(t1).
54  * -----------------------------------------------------------------*/
55 
56 #include <stdio.h>
57 #include <stdlib.h>
58 
59 #include <cvodes/cvodes.h>                 /* prototypes for CVODE fcts., consts.  */
60 #include <nvector/nvector_serial.h>        /* access to serial N_Vector            */
61 #include <sunmatrix/sunmatrix_sparse.h>    /* access to sparse SUNMatrix           */
62 #include <sunlinsol/sunlinsol_superlumt.h> /* access to SuperLUMT SUNLinearSolver  */
63 #include <sundials/sundials_types.h>       /* defs. of realtype, sunindextype      */
64 #include <sundials/sundials_math.h>        /* defs. of SUNRabs, SUNRexp, etc.      */
65 
/* Accessor macros */
/* These macros let the code below match the mathematical problem
   description given above (1-based component indices). */

/* i-th vector component, i=1..NEQ. The index argument is parenthesized so
   that an expression argument such as Ith(v, c ? 3 : 2) is offset as a whole
   rather than only its last operand. */
#define Ith(v,i)    NV_Ith_S(v,(i)-1)

/* Problem Constants */

#define NEQ      3             /* number of equations                  */

#define RTOL     RCONST(1e-6)  /* scalar relative tolerance (forward,
                                  adjoint and quadrature problems)     */

#define ATOL1    RCONST(1e-8)  /* vector absolute tolerance components */
#define ATOL2    RCONST(1e-14)
#define ATOL3    RCONST(1e-6)

#define ATOLl    RCONST(1e-8)  /* absolute tolerance for adjoint vars. */
#define ATOLq    RCONST(1e-6)  /* absolute tolerance for quadratures   */

#define T0       RCONST(0.0)   /* initial time                         */
#define TOUT     RCONST(4e7)   /* final time of the forward problem    */

#define TB1      RCONST(4e7)   /* first starting point for adjoint     */
#define TB2      RCONST(50.0)  /* second starting point for adjoint    */
#define TBout1   RCONST(40.0)  /* intermediate t for adjoint problem   */

#define STEPS    150           /* number of steps between check points */

#define NP       3             /* number of problem parameters         */

#define ZERO     RCONST(0.0)

98 
/* Type : UserData
   Pointer to the reaction-rate parameters shared by the forward and backward
   callbacks: p[0] = p1, p[1] = p2, p[2] = p3 (values are set in main). */

typedef struct {
  realtype p[3];
} *UserData;
104 
/* Prototypes of user-supplied functions */

/* Forward problem: right-hand side, sparse (CSC) Jacobian df/dy,
   quadrature integrand g = y3, and error-weight function. */
static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data);
static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
               void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data);
static int ewt(N_Vector y, N_Vector w, void *user_data);

/* Backward (adjoint) problem: right-hand side for lambda, its sparse
   Jacobian with respect to yB, and the integrand of the backward quadratures
   that yield dG/dp. */
static int fB(realtype t, N_Vector y,
              N_Vector yB, N_Vector yBdot, void *user_dataB);
static int JacB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix JB,
                void *user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);
static int fQB(realtype t, N_Vector y, N_Vector yB,
               N_Vector qBdot, void *user_dataB);


/* Prototypes of private functions */

/* Output helpers and the return-value checker used after every call. */
static void PrintHead(realtype tB0);
static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB);
static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB);
static int check_retval(void *returnvalue, const char *funcname, int opt);
127 
128 /*
129  *--------------------------------------------------------------------
130  * MAIN PROGRAM
131  *--------------------------------------------------------------------
132  */
133 
main(int argc,char * argv[])134 int main(int argc, char *argv[])
135 {
136   UserData data;
137 
138   SUNMatrix A, AB;
139   SUNLinearSolver LS, LSB;
140   void *cvode_mem;
141 
142   realtype reltolQ, abstolQ;
143   N_Vector y, q;
144 
145   int steps;
146 
147   int indexB;
148 
149   realtype reltolB, abstolB, abstolQB;
150   N_Vector yB, qB;
151 
152   realtype time;
153   int retval, nthreads, nnz, ncheck;
154 
155   long int nst, nstB;
156 
157   CVadjCheckPointRec *ckpnt;
158 
159   data = NULL;
160   A = AB = NULL;
161   LS = LSB = NULL;
162   cvode_mem = NULL;
163   ckpnt = NULL;
164   y = yB = qB = NULL;
165 
166   /* Print problem description */
167   printf("\nAdjoint Sensitivity Example for Chemical Kinetics\n");
168   printf("-------------------------------------------------\n\n");
169   printf("ODE: dy1/dt = -p1*y1 + p2*y2*y3\n");
170   printf("     dy2/dt =  p1*y1 - p2*y2*y3 - p3*(y2)^2\n");
171   printf("     dy3/dt =  p3*(y2)^2\n\n");
172   printf("Find dG/dp for\n");
173   printf("     G = int_t0^tB0 g(t,p,y) dt\n");
174   printf("     g(t,p,y) = y3\n\n\n");
175 
176   /* User data structure */
177   data = (UserData) malloc(sizeof *data);
178   if (check_retval((void *)data, "malloc", 2)) return(1);
179   data->p[0] = RCONST(0.04);
180   data->p[1] = RCONST(1.0e4);
181   data->p[2] = RCONST(3.0e7);
182 
183   /* Initialize y */
184   y = N_VNew_Serial(NEQ);
185   if (check_retval((void *)y, "N_VNew_Serial", 0)) return(1);
186   Ith(y,1) = RCONST(1.0);
187   Ith(y,2) = ZERO;
188   Ith(y,3) = ZERO;
189 
190   /* Initialize q */
191   q = N_VNew_Serial(1);
192   if (check_retval((void *)q, "N_VNew_Serial", 0)) return(1);
193   Ith(q,1) = ZERO;
194 
195   /* Set the scalar realtive and absolute tolerances reltolQ and abstolQ */
196   reltolQ = RTOL;
197   abstolQ = ATOLq;
198 
199   /* Create and allocate CVODES memory for forward run */
200   printf("Create and allocate CVODES memory for forward runs\n");
201 
202   /* Call CVodeCreate to create the solver memory and specify the
203      Backward Differentiation Formula */
204   cvode_mem = CVodeCreate(CV_BDF);
205   if (check_retval((void *)cvode_mem, "CVodeCreate", 0)) return(1);
206 
207   /* Call CVodeInit to initialize the integrator memory and specify the
208      user's right hand side function in y'=f(t,y), the initial time T0, and
209      the initial dependent variable vector y. */
210   retval = CVodeInit(cvode_mem, f, T0, y);
211   if (check_retval(&retval, "CVodeInit", 1)) return(1);
212 
213   /* Call CVodeWFtolerances to specify a user-supplied function ewt that sets
214      the multiplicative error weights w_i for use in the weighted RMS norm */
215   retval = CVodeWFtolerances(cvode_mem, ewt);
216   if (check_retval(&retval, "CVodeWFtolerances", 1)) return(1);
217 
218   /* Attach user data */
219   retval = CVodeSetUserData(cvode_mem, data);
220   if (check_retval(&retval, "CVodeSetUserData", 1)) return(1);
221 
222   /* Create sparse SUNMatrix for use in linear solves */
223   nnz = NEQ * NEQ; /* max no. of nonzeros entries in the Jac */
224   A = SUNSparseMatrix(NEQ, NEQ, nnz, CSC_MAT);
225   if (check_retval((void *)A, "SUNSparseMatrix", 0)) return(1);
226 
227   /* Create SuperLUMT SUNLinearSolver object */
228   nthreads = 1; /* no. of threads use when factoring the system */
229   LS = SUNLinSol_SuperLUMT(y, A, nthreads);
230   if (check_retval((void *)LS, "SUNLinSol_SuperLUMT", 0)) return(1);
231 
232   /* Attach the matrix and linear solver for the forward problem */
233   retval = CVodeSetLinearSolver(cvode_mem, LS, A);
234   if (check_retval(&retval, "CVodeSetLinearSolver", 1)) return(1);
235 
236   /* Set the user-supplied Jacobian routine for the forward problem */
237   retval = CVodeSetJacFn(cvode_mem, Jac);
238   if (check_retval(&retval, "CVodeSetJacFn", 1)) return(1);
239 
240   /* Call CVodeQuadInit to allocate initernal memory and initialize
241      quadrature integration*/
242   retval = CVodeQuadInit(cvode_mem, fQ, q);
243   if (check_retval(&retval, "CVodeQuadInit", 1)) return(1);
244 
245   /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
246      are to be used in the step size control mechanism within CVODES. Call
247      CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
248      tolerances for the quadrature variables. */
249   retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
250   if (check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
251 
252   /* Call CVodeQuadSStolerances to specify scalar relative and absolute
253      tolerances. */
254   retval = CVodeQuadSStolerances(cvode_mem, reltolQ, abstolQ);
255   if (check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
256 
257   /* Allocate global memory */
258 
259   /* Call CVodeAdjInit to update CVODES memory block by allocting the internal
260      memory needed for backward integration.*/
261   steps = STEPS; /* no. of integration steps between two consecutive ckeckpoints*/
262   retval = CVodeAdjInit(cvode_mem, steps, CV_HERMITE);
263   /*
264   retval = CVodeAdjInit(cvode_mem, steps, CV_POLYNOMIAL);
265   */
266   if (check_retval(&retval, "CVodeAdjInit", 1)) return(1);
267 
268   /* Perform forward run */
269   printf("Forward integration ... ");
270 
271   /* Call CVodeF to integrate the forward problem over an interval in time and
272      saves checkpointing data */
273   retval = CVodeF(cvode_mem, TOUT, y, &time, CV_NORMAL, &ncheck);
274   if (check_retval(&retval, "CVodeF", 1)) return(1);
275   retval = CVodeGetNumSteps(cvode_mem, &nst);
276   if (check_retval(&retval, "CVodeGetNumSteps", 1)) return(1);
277 
278   printf("done ( nst = %ld )\n",nst);
279   printf("\nncheck = %d\n\n", ncheck);
280 
281   retval = CVodeGetQuad(cvode_mem, &time, q);
282   if (check_retval(&retval, "CVodeGetQuad", 1)) return(1);
283 
284   printf("--------------------------------------------------------\n");
285 #if defined(SUNDIALS_EXTENDED_PRECISION)
286   printf("G:          %12.4Le \n",Ith(q,1));
287 #elif defined(SUNDIALS_DOUBLE_PRECISION)
288   printf("G:          %12.4e \n",Ith(q,1));
289 #else
290   printf("G:          %12.4e \n",Ith(q,1));
291 #endif
292   printf("--------------------------------------------------------\n\n");
293 
294   /* Test check point linked list
295      (uncomment next block to print check point information) */
296 
297   /*
298   {
299     int i;
300 
301     printf("\nList of Check Points (ncheck = %d)\n\n", ncheck);
302     ckpnt = (CVadjCheckPointRec *) malloc ( (ncheck+1)*sizeof(CVadjCheckPointRec));
303     CVodeGetAdjCheckPointsInfo(cvode_mem, ckpnt);
304     for (i=0;i<=ncheck;i++) {
305       printf("Address:       %p\n",ckpnt[i].my_addr);
306       printf("Next:          %p\n",ckpnt[i].next_addr);
307       printf("Time interval: %le  %le\n",ckpnt[i].t0, ckpnt[i].t1);
308       printf("Step number:   %ld\n",ckpnt[i].nstep);
309       printf("Order:         %d\n",ckpnt[i].order);
310       printf("Step size:     %le\n",ckpnt[i].step);
311       printf("\n");
312     }
313 
314   }
315   */
316 
317   /* Initialize yB */
318   yB = N_VNew_Serial(NEQ);
319   if (check_retval((void *)yB, "N_VNew_Serial", 0)) return(1);
320   Ith(yB,1) = ZERO;
321   Ith(yB,2) = ZERO;
322   Ith(yB,3) = ZERO;
323 
324   /* Initialize qB */
325   qB = N_VNew_Serial(NP);
326   if (check_retval((void *)qB, "N_VNew", 0)) return(1);
327   Ith(qB,1) = ZERO;
328   Ith(qB,2) = ZERO;
329   Ith(qB,3) = ZERO;
330 
331   /* Set the scalar relative tolerance reltolB */
332   reltolB = RTOL;
333 
334   /* Set the scalar absolute tolerance abstolB */
335   abstolB = ATOLl;
336 
337   /* Set the scalar absolute tolerance abstolQB */
338   abstolQB = ATOLq;
339 
340   /* Create and allocate CVODES memory for backward run */
341   printf("Create and allocate CVODES memory for backward run\n");
342 
343   /* Call CVodeCreateB to specify the solution method for the backward
344      problem. */
345   retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB);
346   if (check_retval(&retval, "CVodeCreateB", 1)) return(1);
347 
348   /* Call CVodeInitB to allocate internal memory and initialize the
349      backward problem. */
350   retval = CVodeInitB(cvode_mem, indexB, fB, TB1, yB);
351   if (check_retval(&retval, "CVodeInitB", 1)) return(1);
352 
353   /* Set the scalar relative and absolute tolerances. */
354   retval = CVodeSStolerancesB(cvode_mem, indexB, reltolB, abstolB);
355   if (check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
356 
357   /* Attach the user data for backward problem. */
358   retval = CVodeSetUserDataB(cvode_mem, indexB, data);
359   if (check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
360 
361 /* Create sparse SUNMatrix for use in linear solves */
362   AB = SUNSparseMatrix(NEQ, NEQ, nnz, CSC_MAT);
363   if (check_retval((void *)A, "SUNSparseMatrix", 0)) return(1);
364 
365   /* Create SuperLUMT SUNLinearSolver object */
366   LSB = SUNLinSol_SuperLUMT(yB, AB, nthreads);
367   if (check_retval((void *)LSB, "SUNLinSol_SuperLUMT", 0)) return(1);
368 
369   /* Attach the matrix and linear solver for the backward problem */
370   retval = CVodeSetLinearSolverB(cvode_mem, indexB, LSB, AB);
371   if (check_retval(&retval, "CVodeSetLinearSolverB", 1)) return(1);
372 
373   /* Set the user-supplied Jacobian routine for the backward problem */
374   retval = CVodeSetJacFnB(cvode_mem, indexB, JacB);
375   if (check_retval(&retval, "CVodeSetJacFnB", 1)) return(1);
376 
377   /* Call CVodeQuadInitB to allocate internal memory and initialize backward
378      quadrature integration. */
379   retval = CVodeQuadInitB(cvode_mem, indexB, fQB, qB);
380   if (check_retval(&retval, "CVodeQuadInitB", 1)) return(1);
381 
382   /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
383      are to be used in the step size control mechanism within CVODES. Call
384      CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
385      tolerances for the quadrature variables. */
386   retval = CVodeSetQuadErrConB(cvode_mem, indexB, SUNTRUE);
387   if (check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
388 
389   /* Call CVodeQuadSStolerancesB to specify the scalar relative and absolute tolerances
390      for the backward problem. */
391   retval = CVodeQuadSStolerancesB(cvode_mem, indexB, reltolB, abstolQB);
392   if (check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
393 
394   /* Backward Integration */
395 
396   PrintHead(TB1);
397 
398   /* First get results at t = TBout1 */
399 
400   /* Call CVodeB to integrate the backward ODE problem. */
401   retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
402   if (check_retval(&retval, "CVodeB", 1)) return(1);
403 
404   /* Call CVodeGetB to get yB of the backward ODE problem. */
405   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
406   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
407 
408   /* Call CVodeGetAdjY to get the interpolated value of the forward solution
409      y during a backward integration. */
410   retval = CVodeGetAdjY(cvode_mem, TBout1, y);
411   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
412 
413   PrintOutput1(time, TBout1, y, yB);
414 
415   /* Then at t = T0 */
416 
417   retval = CVodeB(cvode_mem, T0, CV_NORMAL);
418   if (check_retval(&retval, "CVodeB", 1)) return(1);
419   CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
420   printf("Done ( nst = %ld )\n", nstB);
421 
422   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
423   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
424 
425   /* Call CVodeGetQuadB to get the quadrature solution vector after a
426      successful return from CVodeB. */
427   retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
428   if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
429 
430   retval = CVodeGetAdjY(cvode_mem, T0, y);
431   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
432 
433   PrintOutput(time, y, yB, qB);
434 
435   /* Reinitialize backward phase (new tB0) */
436 
437   Ith(yB,1) = ZERO;
438   Ith(yB,2) = ZERO;
439   Ith(yB,3) = ZERO;
440 
441   Ith(qB,1) = ZERO;
442   Ith(qB,2) = ZERO;
443   Ith(qB,3) = ZERO;
444 
445   printf("Re-initialize CVODES memory for backward run\n");
446 
447   retval = CVodeReInitB(cvode_mem, indexB, TB2, yB);
448   if (check_retval(&retval, "CVodeReInitB", 1)) return(1);
449 
450   retval = CVodeQuadReInitB(cvode_mem, indexB, qB);
451   if (check_retval(&retval, "CVodeQuadReInitB", 1)) return(1);
452 
453   PrintHead(TB2);
454 
455   /* First get results at t = TBout1 */
456 
457   retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
458   if (check_retval(&retval, "CVodeB", 1)) return(1);
459 
460   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
461   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
462 
463   retval = CVodeGetAdjY(cvode_mem, TBout1, y);
464   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
465 
466   PrintOutput1(time, TBout1, y, yB);
467 
468   /* Then at t = T0 */
469 
470   retval = CVodeB(cvode_mem, T0, CV_NORMAL);
471   if (check_retval(&retval, "CVodeB", 1)) return(1);
472   CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
473   printf("Done ( nst = %ld )\n", nstB);
474 
475   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
476   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
477 
478   retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
479   if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
480 
481   retval = CVodeGetAdjY(cvode_mem, T0, y);
482   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
483 
484   PrintOutput(time, y, yB, qB);
485 
486   /* Free memory */
487   printf("Free memory\n\n");
488 
489   CVodeFree(&cvode_mem);
490   N_VDestroy(y);
491   N_VDestroy(q);
492   N_VDestroy(yB);
493   N_VDestroy(qB);
494   SUNLinSolFree(LS);
495   SUNMatDestroy(A);
496   SUNLinSolFree(LSB);
497   SUNMatDestroy(AB);
498 
499   if (ckpnt != NULL) free(ckpnt);
500   free(data);
501 
502   return(0);
503 
504 }
505 
506 /*
507  *--------------------------------------------------------------------
508  * FUNCTIONS CALLED BY CVODES
509  *--------------------------------------------------------------------
510  */
511 
512 /*
513  * f routine. Compute f(t,y).
514  */
515 
f(realtype t,N_Vector y,N_Vector ydot,void * user_data)516 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data)
517 {
518   realtype y1, y2, y3, yd1, yd3;
519   UserData data;
520   realtype p1, p2, p3;
521 
522   y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
523   data = (UserData) user_data;
524   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
525 
526   yd1 = Ith(ydot,1) = -p1*y1 + p2*y2*y3;
527   yd3 = Ith(ydot,3) = p3*y2*y2;
528         Ith(ydot,2) = -yd1 - yd3;
529 
530   return(0);
531 }
532 
533 /*
534  * Jacobian routine. Compute J(t,y).
535  */
536 
Jac(realtype t,N_Vector y,N_Vector fy,SUNMatrix J,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)537 static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
538                void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
539 {
540   realtype *yval;
541   sunindextype *colptrs = SUNSparseMatrix_IndexPointers(J);
542   sunindextype *rowvals = SUNSparseMatrix_IndexValues(J);
543   realtype *data = SUNSparseMatrix_Data(J);
544   UserData userdata;
545   realtype p1, p2, p3;
546 
547   yval = N_VGetArrayPointer(y);
548 
549   userdata = (UserData) user_data;
550   p1 = userdata->p[0]; p2 = userdata->p[1]; p3 = userdata->p[2];
551 
552   SUNMatZero(J);
553 
554   colptrs[0] = 0;
555   colptrs[1] = 3;
556   colptrs[2] = 6;
557   colptrs[3] = 9;
558 
559   data[0] = -p1;
560   rowvals[0] = 0;
561   data[1] = p1;
562   rowvals[1] = 1;
563   data[2] = ZERO;
564   rowvals[2] = 2;
565 
566   data[3] = p2*yval[2];
567   rowvals[3] = 0;
568   data[4] = -p2*yval[2]-2*p3*yval[1];
569   rowvals[4] = 1;
570   data[5] = 2*yval[1];
571   rowvals[5] = 2;
572 
573   data[6] = p2*yval[1];
574   rowvals[6] = 0;
575   data[7] = -p2*yval[1];
576   rowvals[7] = 1;
577   data[8] = ZERO;
578   rowvals[8] = 2;
579 
580   return(0);
581 }
582 
583 /*
584  * fQ routine. Compute fQ(t,y).
585  */
586 
fQ(realtype t,N_Vector y,N_Vector qdot,void * user_data)587 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data)
588 {
589   Ith(qdot,1) = Ith(y,3);
590 
591   return(0);
592 }
593 
594 /*
595  * EwtSet function. Computes the error weights at the current solution.
596  */
597 
ewt(N_Vector y,N_Vector w,void * user_data)598 static int ewt(N_Vector y, N_Vector w, void *user_data)
599 {
600   int i;
601   realtype yy, ww, rtol, atol[3];
602 
603   rtol    = RTOL;
604   atol[0] = ATOL1;
605   atol[1] = ATOL2;
606   atol[2] = ATOL3;
607 
608   for (i=1; i<=3; i++) {
609     yy = Ith(y,i);
610     ww = rtol * SUNRabs(yy) + atol[i-1];
611     if (ww <= 0.0) return (-1);
612     Ith(w,i) = 1.0/ww;
613   }
614 
615   return(0);
616 }
617 
618 /*
619  * fB routine. Compute fB(t,y,yB).
620  */
621 
fB(realtype t,N_Vector y,N_Vector yB,N_Vector yBdot,void * user_dataB)622 static int fB(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
623 {
624   UserData data;
625   realtype y2, y3;
626   realtype p1, p2, p3;
627   realtype l1, l2, l3;
628   realtype l21, l32;
629 
630   data = (UserData) user_dataB;
631 
632   /* The p vector */
633   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
634 
635   /* The y vector */
636   y2 = Ith(y,2); y3 = Ith(y,3);
637 
638   /* The lambda vector */
639   l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
640 
641   /* Temporary variables */
642   l21 = l2-l1;
643   l32 = l3-l2;
644 
645   /* Load yBdot */
646   Ith(yBdot,1) = - p1*l21;
647   Ith(yBdot,2) = p2*y3*l21 - RCONST(2.0)*p3*y2*l32;
648   Ith(yBdot,3) = p2*y2*l21 - RCONST(1.0);
649 
650   return(0);
651 }
652 
653 /*
654  * JacB routine. Compute JB(t,y,yB).
655  */
656 
JacB(realtype t,N_Vector y,N_Vector yB,N_Vector fyB,SUNMatrix JB,void * user_dataB,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)657 static int JacB(realtype t,
658                 N_Vector y, N_Vector yB, N_Vector fyB,
659                 SUNMatrix JB, void *user_dataB,
660                 N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B)
661 {
662   realtype *yvalB;
663   sunindextype *colptrsB = SUNSparseMatrix_IndexPointers(JB);
664   sunindextype *rowvalsB = SUNSparseMatrix_IndexValues(JB);
665   realtype *dataB = SUNSparseMatrix_Data(JB);
666   UserData userdata;
667   realtype p1, p2, p3;
668 
669   yvalB = N_VGetArrayPointer(y);
670 
671   userdata = (UserData) user_dataB;
672   p1 = userdata->p[0]; p2 = userdata->p[1]; p3 = userdata->p[2];
673 
674   SUNMatZero(JB);
675 
676   colptrsB[0] = 0;
677   colptrsB[1] = 3;
678   colptrsB[2] = 6;
679   colptrsB[3] = 9;
680 
681   dataB[0] = p1;
682   rowvalsB[0] = 0;
683   dataB[1] = -p2*yvalB[2];
684   rowvalsB[1] = 1;
685   dataB[2] = -p2*yvalB[1];
686   rowvalsB[2] = 2;
687 
688   dataB[3] =-p1;
689   rowvalsB[3] = 0;
690   dataB[4] = p2*yvalB[2]+2*p3*yvalB[1];
691   rowvalsB[4] = 1;
692   dataB[5] = p2*yvalB[1];
693   rowvalsB[5] = 2;
694 
695   dataB[6] = ZERO;
696   rowvalsB[6] = 0;
697   dataB[7] = RCONST(-2.0)*p3*yvalB[1];
698   rowvalsB[7] = 1;
699   dataB[8] = ZERO;
700   rowvalsB[8] = 2;
701 
702   return(0);
703 }
704 
705 /*
706  * fQB routine. Compute integrand for quadratures
707  */
708 
fQB(realtype t,N_Vector y,N_Vector yB,N_Vector qBdot,void * user_dataB)709 static int fQB(realtype t, N_Vector y, N_Vector yB,
710                N_Vector qBdot, void *user_dataB)
711 {
712   realtype y1, y2, y3;
713   realtype l1, l2, l3;
714   realtype l21, l32, y23;
715 
716   /* The y vector */
717   y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
718 
719   /* The lambda vector */
720   l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
721 
722   /* Temporary variables */
723   l21 = l2-l1;
724   l32 = l3-l2;
725   y23 = y2*y3;
726 
727   Ith(qBdot,1) = y1*l21;
728   Ith(qBdot,2) = - y23*l21;
729   Ith(qBdot,3) = y2*y2*l32;
730 
731   return(0);
732 }
733 
734 /*
735  *--------------------------------------------------------------------
736  * PRIVATE FUNCTIONS
737  *--------------------------------------------------------------------
738  */
739 
740 /*
741  * Print heading for backward integration
742  */
743 
PrintHead(realtype tB0)744 static void PrintHead(realtype tB0)
745 {
746 #if defined(SUNDIALS_EXTENDED_PRECISION)
747   printf("Backward integration from tB0 = %12.4Le\n\n",tB0);
748 #elif defined(SUNDIALS_DOUBLE_PRECISION)
749   printf("Backward integration from tB0 = %12.4e\n\n",tB0);
750 #else
751   printf("Backward integration from tB0 = %12.4e\n\n",tB0);
752 #endif
753 }
754 
755 /*
756  * Print intermediate results during backward integration
757  */
758 
PrintOutput1(realtype time,realtype t,N_Vector y,N_Vector yB)759 static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB)
760 {
761   printf("--------------------------------------------------------\n");
762 #if defined(SUNDIALS_EXTENDED_PRECISION)
763   printf("returned t: %12.4Le\n",time);
764   printf("tout:       %12.4Le\n",t);
765   printf("lambda(t):  %12.4Le %12.4Le %12.4Le\n",
766          Ith(yB,1), Ith(yB,2), Ith(yB,3));
767   printf("y(t):       %12.4Le %12.4Le %12.4Le\n",
768          Ith(y,1), Ith(y,2), Ith(y,3));
769 #elif defined(SUNDIALS_DOUBLE_PRECISION)
770   printf("returned t: %12.4e\n",time);
771   printf("tout:       %12.4e\n",t);
772   printf("lambda(t):  %12.4e %12.4e %12.4e\n",
773          Ith(yB,1), Ith(yB,2), Ith(yB,3));
774   printf("y(t):       %12.4e %12.4e %12.4e\n",
775          Ith(y,1), Ith(y,2), Ith(y,3));
776 #else
777   printf("returned t: %12.4e\n",time);
778   printf("tout:       %12.4e\n",t);
779   printf("lambda(t):  %12.4e %12.4e %12.4e\n",
780          Ith(yB,1), Ith(yB,2), Ith(yB,3));
781   printf("y(t)      : %12.4e %12.4e %12.4e\n",
782          Ith(y,1), Ith(y,2), Ith(y,3));
783 #endif
784   printf("--------------------------------------------------------\n\n");
785 }
786 
787 /*
788  * Print final results of backward integration
789  */
790 
PrintOutput(realtype tfinal,N_Vector y,N_Vector yB,N_Vector qB)791 static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB)
792 {
793   printf("--------------------------------------------------------\n");
794 #if defined(SUNDIALS_EXTENDED_PRECISION)
795   printf("returned t: %12.4Le\n",tfinal);
796   printf("lambda(t0): %12.4Le %12.4Le %12.4Le\n",
797          Ith(yB,1), Ith(yB,2), Ith(yB,3));
798   printf("y(t0):      %12.4Le %12.4Le %12.4Le\n",
799          Ith(y,1), Ith(y,2), Ith(y,3));
800   printf("dG/dp:      %12.4Le %12.4Le %12.4Le\n",
801          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
802 #elif defined(SUNDIALS_DOUBLE_PRECISION)
803   printf("returned t: %12.4e\n",tfinal);
804   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
805          Ith(yB,1), Ith(yB,2), Ith(yB,3));
806   printf("y(t0):      %12.4e %12.4e %12.4e\n",
807          Ith(y,1), Ith(y,2), Ith(y,3));
808   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
809          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
810 #else
811   printf("returned t: %12.4e\n",tfinal);
812   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
813          Ith(yB,1), Ith(yB,2), Ith(yB,3));
814   printf("y(t0)     : %12.4e %12.4e %12.4e\n",
815          Ith(y,1), Ith(y,2), Ith(y,3));
816   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
817          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
818 #endif
819   printf("--------------------------------------------------------\n\n");
820 }
821 
822 /*
823  * Check function return value.
824  *    opt == 0 means SUNDIALS function allocates memory so check if
825  *             returned NULL pointer
826  *    opt == 1 means SUNDIALS function returns an integer value so check if
827  *             retval < 0
828  *    opt == 2 means function allocates memory so check if returned
829  *             NULL pointer
830  */
831 
check_retval(void * returnvalue,const char * funcname,int opt)832 static int check_retval(void *returnvalue, const char *funcname, int opt)
833 {
834   int *retval;
835 
836   /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
837   if (opt == 0 && returnvalue == NULL) {
838     fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
839 	    funcname);
840     return(1); }
841 
842   /* Check if retval < 0 */
843   else if (opt == 1) {
844     retval = (int *) returnvalue;
845     if (*retval < 0) {
846       fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
847 	      funcname, *retval);
848       return(1); }}
849 
850   /* Check if function returned NULL pointer - no memory allocated */
851   else if (opt == 2 && returnvalue == NULL) {
852     fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
853 	    funcname);
854     return(1); }
855 
856   return(0);
857 }
858