1 /* -----------------------------------------------------------------
2  * Programmer(s): Jimmy Almgren-Bell @ LLNL
3  * Based on prior version by: Radu Serban @ LLNL
4  * -----------------------------------------------------------------
5  * SUNDIALS Copyright Start
6  * Copyright (c) 2002-2020, Lawrence Livermore National Security
7  * and Southern Methodist University.
8  * All rights reserved.
9  *
10  * See the top-level LICENSE and NOTICE files for details.
11  *
12  * SPDX-License-Identifier: BSD-3-Clause
13  * SUNDIALS Copyright End
14  * -----------------------------------------------------------------
15  * Adjoint sensitivity example problem.
16  * The following is a simple example problem, with the coding
17  * needed for its solution by CVODES. The problem is from chemical
18  * kinetics, and consists of the following three rate equations.
19  *    dy1/dt = -p1*y1 + p2*y2*y3
20  *    dy2/dt =  p1*y1 - p2*y2*y3 - p3*(y2)^2
21  *    dy3/dt =  p3*(y2)^2
 * on the interval from t = 0.0 to t = 4.e7, with initial
23  * conditions: y1 = 1.0, y2 = y3 = 0. The reaction rates are:
24  * p1=0.04, p2=1e4, and p3=3e7. The problem is stiff.
25  * This program solves the problem with the BDF method, Newton
26  * iteration with the CVODE dense linear solver, and a user-supplied
27  * Jacobian routine.
 * Error weights for the forward problem are set by a user-supplied
 * function (ewt) built from a scalar relative tolerance and a vector
 * absolute tolerance.
30  * The constraint y_i >= 0 is posed for all components.
 * The forward problem is integrated in a single call to t = 4.e7, after
 * which the number of forward steps, the number of check points and the
 * value of G are printed.
33  *
34  * Optionally, CVODES can compute sensitivities with respect to
35  * the problem parameters p1, p2, and p3 of the following quantity:
36  *   G = int_t0^t1 g(t,p,y) dt
37  * where
38  *   g(t,p,y) = y3
39  *
 * The gradient dG/dp is obtained as:
 *   dG/dp = int_t0^t1 (g_p + lambda^T f_p ) dt + lambda^T(t0)*y0_p
 *         = - xi^T(t0) + lambda^T(t0)*y0_p
 * where lambda and xi are solutions of:
 *   d(lambda)/dt = - (f_y)^T * lambda - (g_y)^T
 *   lambda(t1) = 0
 * and
 *   d(xi)/dt = (f_p)^T * lambda + (g_p)^T
 *   xi(t1) = 0
 * In this problem g_p = 0 and y0_p = 0, so dG/dp = - xi^T(t0); xi is the
 * backward quadrature whose integrand is computed by fQB.
49  *
 * The value of G itself is computed during the forward integration as a
 * quadrature:
 *   G = phi(t1)
 * where
 *   d(phi)/dt = g(t,y,p)
 *   phi(t0) = 0
55  * -----------------------------------------------------------------*/
56 
57 #include <stdio.h>
58 #include <stdlib.h>
59 
60 #include <cvodes/cvodes.h>             /* prototypes for CVODE fcts., consts.  */
61 #include <nvector/nvector_serial.h>    /* access to serial N_Vector            */
62 #include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix            */
63 #include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver      */
64 #include <cvodes/cvodes_direct.h>      /* access to CVDls interface            */
65 #include <sundials/sundials_types.h>   /* defs. of realtype, sunindextype      */
66 #include <sundials/sundials_math.h>    /* defs. of SUNRabs, SUNRexp, etc.      */
67 
/* Accessor macros (1-based indices, matching the equations above).
 * Every argument is parenthesized so that expressions such as
 * Ith(v, k+1) or Ith(v, c ? 2 : 3) select the intended component. */

#define Ith(v,i)    NV_Ith_S((v),(i)-1)               /* i-th vector component, i=1..NEQ */
#define IJth(A,i,j) SM_ELEMENT_D((A),(i)-1,(j)-1)     /* (i,j)-th matrix el., i,j=1..NEQ */
72 
/* Problem Constants */

/* Integer problem sizes and counters */
enum {
  NEQ   = 3,    /* number of equations                  */
  NP    = 3,    /* number of problem parameters         */
  STEPS = 150   /* number of steps between check points */
};

/* Tolerances */
#define RTOL     RCONST(1e-4)  /* scalar relative tolerance                 */
#define ATOL1    RCONST(1e-4)  /* vector absolute tolerance, component 1    */
#define ATOL2    RCONST(1e-8)  /* vector absolute tolerance, component 2    */
#define ATOL3    RCONST(1e-4)  /* vector absolute tolerance, component 3    */
#define ATOLl    RCONST(1e-8)  /* absolute tolerance for adjoint variables  */
#define ATOLq    RCONST(1e-6)  /* absolute tolerance for quadratures        */

/* Time points */
#define T0       RCONST(0.0)   /* initial time of the forward run           */
#define TOUT     RCONST(4e7)   /* final time of the forward run             */
#define TB1      RCONST(4e7)   /* starting point of the first adjoint run   */
#define TB2      RCONST(50.0)  /* starting point of the second adjoint run  */
#define TBout1   RCONST(40.0)  /* intermediate output time for adjoint runs */

/* Frequently used real constants */
#define ZERO     RCONST(0.0)
#define ONE      RCONST(1.0)
99 
100 /* Type : UserData */
101 
102 typedef struct {
103   realtype p[3];
104 } *UserData;
105 
106 /* Prototypes of user-supplied functions */
107 
108 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data);
109 static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
110                void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
111 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data);
112 static int ewt(N_Vector y, N_Vector w, void *user_data);
113 
114 static int fB(realtype t, N_Vector y,
115               N_Vector yB, N_Vector yBdot, void *user_dataB);
116 static int JacB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix JB,
117                 void *user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);
118 static int fQB(realtype t, N_Vector y, N_Vector yB,
119                N_Vector qBdot, void *user_dataB);
120 
121 
122 /* Prototypes of private functions */
123 
124 static void PrintHead(realtype tB0);
125 static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB);
126 static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB);
127 static int check_retval(void *returnvalue, const char *funcname, int opt);
128 
129 /*
130  *--------------------------------------------------------------------
131  * MAIN PROGRAM
132  *--------------------------------------------------------------------
133  */
134 
main(int argc,char * argv[])135 int main(int argc, char *argv[])
136 {
137   UserData data;
138 
139   SUNMatrix A, AB;
140   SUNLinearSolver LS, LSB;
141   void *cvode_mem;
142 
143   realtype reltolQ, abstolQ;
144   N_Vector y, q, constraints;
145 
146   int steps;
147 
148   int indexB;
149 
150   realtype reltolB, abstolB, abstolQB;
151   N_Vector yB, qB, constraintsB;
152 
153   realtype time;
154   int retval, ncheck;
155 
156   long int nst, nstB;
157 
158   CVadjCheckPointRec *ckpnt;
159 
160   data = NULL;
161   A = AB = NULL;
162   LS = LSB = NULL;
163   cvode_mem = NULL;
164   ckpnt = NULL;
165   y = yB = qB = NULL;
166   constraints = NULL;
167   constraintsB = NULL;
168 
169   /* Print problem description */
170   printf("\nAdjoint Sensitivity Example for Chemical Kinetics\n");
171   printf("-------------------------------------------------\n\n");
172   printf("ODE: dy1/dt = -p1*y1 + p2*y2*y3\n");
173   printf("     dy2/dt =  p1*y1 - p2*y2*y3 - p3*(y2)^2\n");
174   printf("     dy3/dt =  p3*(y2)^2\n\n");
175   printf("Find dG/dp for\n");
176   printf("     G = int_t0^tB0 g(t,p,y) dt\n");
177   printf("     g(t,p,y) = y3\n\n\n");
178 
179   /* User data structure */
180   data = (UserData) malloc(sizeof *data);
181   if (check_retval((void *)data, "malloc", 2)) return(1);
182   data->p[0] = RCONST(0.04);
183   data->p[1] = RCONST(1.0e4);
184   data->p[2] = RCONST(3.0e7);
185 
186   /* Initialize y */
187   y = N_VNew_Serial(NEQ);
188   if (check_retval((void *)y, "N_VNew_Serial", 0)) return(1);
189   Ith(y,1) = RCONST(1.0);
190   Ith(y,2) = ZERO;
191   Ith(y,3) = ZERO;
192 
193   /* Set constraints to all 1's for nonnegative solution values. */
194   constraints = N_VNew_Serial(NEQ);
195   if(check_retval((void *)constraints, "N_VNew_Serial", 0)) return(1);
196   N_VConst(ONE, constraints);
197 
198   /* Initialize q */
199   q = N_VNew_Serial(1);
200   if (check_retval((void *)q, "N_VNew_Serial", 0)) return(1);
201   Ith(q,1) = ZERO;
202 
203   /* Set the scalar realtive and absolute tolerances reltolQ and abstolQ */
204   reltolQ = RTOL;
205   abstolQ = ATOLq;
206 
207   /* Create and allocate CVODES memory for forward run */
208   printf("Create and allocate CVODES memory for forward runs\n");
209 
210   /* Call CVodeCreate to create the solver memory and specify the
211      Backward Differentiation Formula */
212   cvode_mem = CVodeCreate(CV_BDF);
213   if (check_retval((void *)cvode_mem, "CVodeCreate", 0)) return(1);
214 
215   /* Call CVodeInit to initialize the integrator memory and specify the
216      user's right hand side function in y'=f(t,y), the initial time T0, and
217      the initial dependent variable vector y. */
218   retval = CVodeInit(cvode_mem, f, T0, y);
219   if (check_retval(&retval, "CVodeInit", 1)) return(1);
220 
221   /* Call CVodeWFtolerances to specify a user-supplied function ewt that sets
222      the multiplicative error weights w_i for use in the weighted RMS norm */
223   retval = CVodeWFtolerances(cvode_mem, ewt);
224   if (check_retval(&retval, "CVodeWFtolerances", 1)) return(1);
225 
226   /* Attach user data */
227   retval = CVodeSetUserData(cvode_mem, data);
228   if (check_retval(&retval, "CVodeSetUserData", 1)) return(1);
229 
230   /* Call CVodeSetConstraints to initialize constraints */
231   retval = CVodeSetConstraints(cvode_mem, constraints);
232   if (check_retval(&retval, "CVODESetConstraints", 1)) return(1);
233   N_VDestroy(constraints);
234 
235   /* Create dense SUNMatrix for use in linear solves */
236   A = SUNDenseMatrix(NEQ, NEQ);
237   if (check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
238 
239   /* Create dense SUNLinearSolver object */
240   LS = SUNLinSol_Dense(y, A);
241   if (check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
242 
243   /* Attach the matrix and linear solver */
244   retval = CVDlsSetLinearSolver(cvode_mem, LS, A);
245   if (check_retval(&retval, "CVDlsSetLinearSolver", 1)) return(1);
246 
247   /* Set the user-supplied Jacobian routine Jac */
248   retval = CVDlsSetJacFn(cvode_mem, Jac);
249   if (check_retval(&retval, "CVDlsSetJacFn", 1)) return(1);
250 
251   /* Call CVodeQuadInit to allocate initernal memory and initialize
252      quadrature integration*/
253   retval = CVodeQuadInit(cvode_mem, fQ, q);
254   if (check_retval(&retval, "CVodeQuadInit", 1)) return(1);
255 
256   /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
257      are to be used in the step size control mechanism within CVODES. Call
258      CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
259      tolerances for the quadrature variables. */
260   retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
261   if (check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
262 
263   /* Call CVodeQuadSStolerances to specify scalar relative and absolute
264      tolerances. */
265   retval = CVodeQuadSStolerances(cvode_mem, reltolQ, abstolQ);
266   if (check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
267 
268   /* Allocate global memory */
269 
270   /* Call CVodeAdjInit to update CVODES memory block by allocting the internal
271      memory needed for backward integration.*/
272   steps = STEPS; /* no. of integration steps between two consecutive ckeckpoints*/
273   retval = CVodeAdjInit(cvode_mem, steps, CV_HERMITE);
274   /*
275   retval = CVodeAdjInit(cvode_mem, steps, CV_POLYNOMIAL);
276   */
277   if (check_retval(&retval, "CVodeAdjInit", 1)) return(1);
278 
279   /* Perform forward run */
280   printf("Forward integration ... ");
281 
282   /* Call CVodeF to integrate the forward problem over an interval in time and
283      saves checkpointing data */
284   retval = CVodeF(cvode_mem, TOUT, y, &time, CV_NORMAL, &ncheck);
285   if (check_retval(&retval, "CVodeF", 1)) return(1);
286   retval = CVodeGetNumSteps(cvode_mem, &nst);
287   if (check_retval(&retval, "CVodeGetNumSteps", 1)) return(1);
288 
289   printf("done ( nst = %ld )\n",nst);
290   printf("\nncheck = %d\n\n", ncheck);
291 
292   retval = CVodeGetQuad(cvode_mem, &time, q);
293   if (check_retval(&retval, "CVodeGetQuad", 1)) return(1);
294 
295   printf("--------------------------------------------------------\n");
296 #if defined(SUNDIALS_EXTENDED_PRECISION)
297   printf("G:          %12.4Le \n",Ith(q,1));
298 #elif defined(SUNDIALS_DOUBLE_PRECISION)
299   printf("G:          %12.4e \n",Ith(q,1));
300 #else
301   printf("G:          %12.4e \n",Ith(q,1));
302 #endif
303   printf("--------------------------------------------------------\n\n");
304 
305   /* Test check point linked list
306      (uncomment next block to print check point information) */
307 
308   /*
309   {
310     int i;
311 
312     printf("\nList of Check Points (ncheck = %d)\n\n", ncheck);
313     ckpnt = (CVadjCheckPointRec *) malloc ( (ncheck+1)*sizeof(CVadjCheckPointRec));
314     CVodeGetAdjCheckPointsInfo(cvode_mem, ckpnt);
315     for (i=0;i<=ncheck;i++) {
316       printf("Address:       %p\n",ckpnt[i].my_addr);
317       printf("Next:          %p\n",ckpnt[i].next_addr);
318       printf("Time interval: %le  %le\n",ckpnt[i].t0, ckpnt[i].t1);
319       printf("Step number:   %ld\n",ckpnt[i].nstep);
320       printf("Order:         %d\n",ckpnt[i].order);
321       printf("Step size:     %le\n",ckpnt[i].step);
322       printf("\n");
323     }
324 
325   }
326   */
327 
328   /* Initialize yB */
329   yB = N_VNew_Serial(NEQ);
330   if (check_retval((void *)yB, "N_VNew_Serial", 0)) return(1);
331   Ith(yB,1) = ZERO;
332   Ith(yB,2) = ZERO;
333   Ith(yB,3) = ZERO;
334 
335   /* Initialize qB */
336   qB = N_VNew_Serial(NP);
337   if (check_retval((void *)qB, "N_VNew", 0)) return(1);
338   Ith(qB,1) = ZERO;
339   Ith(qB,2) = ZERO;
340   Ith(qB,3) = ZERO;
341 
342   /* Set the scalar relative tolerance reltolB */
343   reltolB = RTOL;
344 
345   /* Set the scalar absolute tolerance abstolB */
346   abstolB = ATOLl;
347 
348   /* Set the scalar absolute tolerance abstolQB */
349   abstolQB = ATOLq;
350 
351   /* Set constraints to all 1's for nonnegative solution values. */
352   constraintsB = N_VNew_Serial(NEQ);
353   if(check_retval((void *)constraintsB, "N_VNew_Serial", 0)) return(1);
354   N_VConst(ONE, constraintsB);
355 
356   /* Create and allocate CVODES memory for backward run */
357   printf("Create and allocate CVODES memory for backward run\n");
358 
359   /* Call CVodeCreateB to specify the solution method for the backward
360      problem. */
361   retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB);
362   if (check_retval(&retval, "CVodeCreateB", 1)) return(1);
363 
364   /* Call CVodeInitB to allocate internal memory and initialize the
365      backward problem. */
366   retval = CVodeInitB(cvode_mem, indexB, fB, TB1, yB);
367   if (check_retval(&retval, "CVodeInitB", 1)) return(1);
368 
369   /* Set the scalar relative and absolute tolerances. */
370   retval = CVodeSStolerancesB(cvode_mem, indexB, reltolB, abstolB);
371   if (check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
372 
373   /* Attach the user data for backward problem. */
374   retval = CVodeSetUserDataB(cvode_mem, indexB, data);
375   if (check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
376 
377   /* Call CVodeSetConstraintsB to initialize constraints */
378   retval = CVodeSetConstraintsB(cvode_mem, indexB, constraintsB);
379   if(check_retval(&retval, "CVodeSetConstraintsB", 1)) return(1);
380   N_VDestroy(constraintsB);
381 
382   /* Create dense SUNMatrix for use in linear solves */
383   AB = SUNDenseMatrix(NEQ, NEQ);
384   if (check_retval((void *)AB, "SUNDenseMatrix", 0)) return(1);
385 
386   /* Create dense SUNLinearSolver object */
387   LSB = SUNLinSol_Dense(yB, AB);
388   if (check_retval((void *)LSB, "SUNLinSol_Dense", 0)) return(1);
389 
390   /* Attach the matrix and linear solver */
391   retval = CVDlsSetLinearSolverB(cvode_mem, indexB, LSB, AB);
392   if (check_retval(&retval, "CVDlsSetLinearSolverB", 1)) return(1);
393 
394   /* Set the user-supplied Jacobian routine JacB */
395   retval = CVDlsSetJacFnB(cvode_mem, indexB, JacB);
396   if (check_retval(&retval, "CVDlsSetJacFnB", 1)) return(1);
397 
398   /* Call CVodeQuadInitB to allocate internal memory and initialize backward
399      quadrature integration. */
400   retval = CVodeQuadInitB(cvode_mem, indexB, fQB, qB);
401   if (check_retval(&retval, "CVodeQuadInitB", 1)) return(1);
402 
403   /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables
404      are to be used in the step size control mechanism within CVODES. Call
405      CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration
406      tolerances for the quadrature variables. */
407   retval = CVodeSetQuadErrConB(cvode_mem, indexB, SUNTRUE);
408   if (check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
409 
410   /* Call CVodeQuadSStolerancesB to specify the scalar relative and absolute tolerances
411      for the backward problem. */
412   retval = CVodeQuadSStolerancesB(cvode_mem, indexB, reltolB, abstolQB);
413   if (check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
414 
415   /* Backward Integration */
416 
417   PrintHead(TB1);
418 
419   /* First get results at t = TBout1 */
420 
421   /* Call CVodeB to integrate the backward ODE problem. */
422   retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
423   if (check_retval(&retval, "CVodeB", 1)) return(1);
424 
425   /* Call CVodeGetB to get yB of the backward ODE problem. */
426   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
427   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
428 
429   /* Call CVodeGetAdjY to get the interpolated value of the forward solution
430      y during a backward integration. */
431   retval = CVodeGetAdjY(cvode_mem, TBout1, y);
432   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
433 
434   PrintOutput1(time, TBout1, y, yB);
435 
436   /* Then at t = T0 */
437 
438   retval = CVodeB(cvode_mem, T0, CV_NORMAL);
439   if (check_retval(&retval, "CVodeB", 1)) return(1);
440   CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
441   printf("Done ( nst = %ld )\n", nstB);
442 
443   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
444   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
445 
446   /* Call CVodeGetQuadB to get the quadrature solution vector after a
447      successful return from CVodeB. */
448   retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
449   if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
450 
451   retval = CVodeGetAdjY(cvode_mem, T0, y);
452   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
453 
454   PrintOutput(time, y, yB, qB);
455 
456   /* Reinitialize backward phase (new tB0) */
457 
458   Ith(yB,1) = ZERO;
459   Ith(yB,2) = ZERO;
460   Ith(yB,3) = ZERO;
461 
462   Ith(qB,1) = ZERO;
463   Ith(qB,2) = ZERO;
464   Ith(qB,3) = ZERO;
465 
466   printf("Re-initialize CVODES memory for backward run\n");
467 
468   retval = CVodeReInitB(cvode_mem, indexB, TB2, yB);
469   if (check_retval(&retval, "CVodeReInitB", 1)) return(1);
470 
471   retval = CVodeQuadReInitB(cvode_mem, indexB, qB);
472   if (check_retval(&retval, "CVodeQuadReInitB", 1)) return(1);
473 
474   PrintHead(TB2);
475 
476   /* First get results at t = TBout1 */
477 
478   retval = CVodeB(cvode_mem, TBout1, CV_NORMAL);
479   if (check_retval(&retval, "CVodeB", 1)) return(1);
480 
481   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
482   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
483 
484   retval = CVodeGetAdjY(cvode_mem, TBout1, y);
485   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
486 
487   PrintOutput1(time, TBout1, y, yB);
488 
489   /* Then at t = T0 */
490 
491   retval = CVodeB(cvode_mem, T0, CV_NORMAL);
492   if (check_retval(&retval, "CVodeB", 1)) return(1);
493   CVodeGetNumSteps(CVodeGetAdjCVodeBmem(cvode_mem, indexB), &nstB);
494   printf("Done ( nst = %ld )\n", nstB);
495 
496   retval = CVodeGetB(cvode_mem, indexB, &time, yB);
497   if (check_retval(&retval, "CVodeGetB", 1)) return(1);
498 
499   retval = CVodeGetQuadB(cvode_mem, indexB, &time, qB);
500   if (check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
501 
502   retval = CVodeGetAdjY(cvode_mem, T0, y);
503   if (check_retval(&retval, "CVodeGetAdjY", 1)) return(1);
504 
505   PrintOutput(time, y, yB, qB);
506 
507   /* Free memory */
508   printf("Free memory\n\n");
509 
510   CVodeFree(&cvode_mem);
511   N_VDestroy(y);
512   N_VDestroy(q);
513   N_VDestroy(yB);
514   N_VDestroy(qB);
515   SUNLinSolFree(LS);
516   SUNMatDestroy(A);
517   SUNLinSolFree(LSB);
518   SUNMatDestroy(AB);
519 
520   if (ckpnt != NULL) free(ckpnt);
521   free(data);
522 
523   return(0);
524 
525 }
526 
527 /*
528  *--------------------------------------------------------------------
529  * FUNCTIONS CALLED BY CVODES
530  *--------------------------------------------------------------------
531  */
532 
533 /*
534  * f routine. Compute f(t,y).
535 */
536 
f(realtype t,N_Vector y,N_Vector ydot,void * user_data)537 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data)
538 {
539   realtype y1, y2, y3, yd1, yd3;
540   UserData data;
541   realtype p1, p2, p3;
542 
543   y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
544   data = (UserData) user_data;
545   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
546 
547   yd1 = Ith(ydot,1) = -p1*y1 + p2*y2*y3;
548   yd3 = Ith(ydot,3) = p3*y2*y2;
549         Ith(ydot,2) = -yd1 - yd3;
550 
551   return(0);
552 }
553 
554 /*
555  * Jacobian routine. Compute J(t,y).
556 */
557 
Jac(realtype t,N_Vector y,N_Vector fy,SUNMatrix J,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)558 static int Jac(realtype t, N_Vector y, N_Vector fy, SUNMatrix J,
559                void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
560 {
561   realtype y2, y3;
562   UserData data;
563   realtype p1, p2, p3;
564 
565   y2 = Ith(y,2); y3 = Ith(y,3);
566   data = (UserData) user_data;
567   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
568 
569   IJth(J,1,1) = -p1;  IJth(J,1,2) = p2*y3;          IJth(J,1,3) = p2*y2;
570   IJth(J,2,1) =  p1;  IJth(J,2,2) = -p2*y3-2*p3*y2; IJth(J,2,3) = -p2*y2;
571   IJth(J,3,1) = ZERO; IJth(J,3,2) = 2*p3*y2;        IJth(J,3,3) = ZERO;
572 
573   return(0);
574 }
575 
576 /*
577  * fQ routine. Compute fQ(t,y).
578 */
579 
fQ(realtype t,N_Vector y,N_Vector qdot,void * user_data)580 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data)
581 {
582   Ith(qdot,1) = Ith(y,3);
583 
584   return(0);
585 }
586 
587 /*
588  * EwtSet function. Computes the error weights at the current solution.
589  */
590 
ewt(N_Vector y,N_Vector w,void * user_data)591 static int ewt(N_Vector y, N_Vector w, void *user_data)
592 {
593   int i;
594   realtype yy, ww, rtol, atol[3];
595 
596   rtol    = RTOL;
597   atol[0] = ATOL1;
598   atol[1] = ATOL2;
599   atol[2] = ATOL3;
600 
601   for (i=1; i<=3; i++) {
602     yy = Ith(y,i);
603     ww = rtol * SUNRabs(yy) + atol[i-1];
604     if (ww <= 0.0) return (-1);
605     Ith(w,i) = 1.0/ww;
606   }
607 
608   return(0);
609 }
610 
611 /*
612  * fB routine. Compute fB(t,y,yB).
613 */
614 
fB(realtype t,N_Vector y,N_Vector yB,N_Vector yBdot,void * user_dataB)615 static int fB(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
616 {
617   UserData data;
618   realtype y2, y3;
619   realtype p1, p2, p3;
620   realtype l1, l2, l3;
621   realtype l21, l32;
622 
623   data = (UserData) user_dataB;
624 
625   /* The p vector */
626   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
627 
628   /* The y vector */
629   y2 = Ith(y,2); y3 = Ith(y,3);
630 
631   /* The lambda vector */
632   l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
633 
634   /* Temporary variables */
635   l21 = l2-l1;
636   l32 = l3-l2;
637 
638   /* Load yBdot */
639   Ith(yBdot,1) = - p1*l21;
640   Ith(yBdot,2) = p2*y3*l21 - RCONST(2.0)*p3*y2*l32;
641   Ith(yBdot,3) = p2*y2*l21 - RCONST(1.0);
642 
643   return(0);
644 }
645 
646 /*
647  * JacB routine. Compute JB(t,y,yB).
648  */
649 
JacB(realtype t,N_Vector y,N_Vector yB,N_Vector fyB,SUNMatrix JB,void * user_dataB,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)650 static int JacB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix JB,
651                 void *user_dataB, N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B)
652 {
653   UserData data;
654   realtype y2, y3;
655   realtype p1, p2, p3;
656 
657   data = (UserData) user_dataB;
658 
659   /* The p vector */
660   p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
661 
662   /* The y vector */
663   y2 = Ith(y,2); y3 = Ith(y,3);
664 
665   /* Load JB */
666   IJth(JB,1,1) = p1;     IJth(JB,1,2) = -p1;             IJth(JB,1,3) = ZERO;
667   IJth(JB,2,1) = -p2*y3; IJth(JB,2,2) = p2*y3+2.0*p3*y2; IJth(JB,2,3) = RCONST(-2.0)*p3*y2;
668   IJth(JB,3,1) = -p2*y2; IJth(JB,3,2) = p2*y2;           IJth(JB,3,3) = ZERO;
669 
670   return(0);
671 }
672 
673 /*
674  * fQB routine. Compute integrand for quadratures
675 */
676 
fQB(realtype t,N_Vector y,N_Vector yB,N_Vector qBdot,void * user_dataB)677 static int fQB(realtype t, N_Vector y, N_Vector yB,
678                N_Vector qBdot, void *user_dataB)
679 {
680   realtype y1, y2, y3;
681   realtype l1, l2, l3;
682   realtype l21, l32, y23;
683 
684   /* The y vector */
685   y1 = Ith(y,1); y2 = Ith(y,2); y3 = Ith(y,3);
686 
687   /* The lambda vector */
688   l1 = Ith(yB,1); l2 = Ith(yB,2); l3 = Ith(yB,3);
689 
690   /* Temporary variables */
691   l21 = l2-l1;
692   l32 = l3-l2;
693   y23 = y2*y3;
694 
695   Ith(qBdot,1) = y1*l21;
696   Ith(qBdot,2) = - y23*l21;
697   Ith(qBdot,3) = y2*y2*l32;
698 
699   return(0);
700 }
701 
702 /*
703  *--------------------------------------------------------------------
704  * PRIVATE FUNCTIONS
705  *--------------------------------------------------------------------
706  */
707 
708 /*
709  * Print heading for backward integration
710  */
711 
PrintHead(realtype tB0)712 static void PrintHead(realtype tB0)
713 {
714 #if defined(SUNDIALS_EXTENDED_PRECISION)
715   printf("Backward integration from tB0 = %12.4Le\n\n",tB0);
716 #elif defined(SUNDIALS_DOUBLE_PRECISION)
717   printf("Backward integration from tB0 = %12.4e\n\n",tB0);
718 #else
719   printf("Backward integration from tB0 = %12.4e\n\n",tB0);
720 #endif
721 }
722 
723 /*
724  * Print intermediate results during backward integration
725  */
726 
PrintOutput1(realtype time,realtype t,N_Vector y,N_Vector yB)727 static void PrintOutput1(realtype time, realtype t, N_Vector y, N_Vector yB)
728 {
729   printf("--------------------------------------------------------\n");
730 #if defined(SUNDIALS_EXTENDED_PRECISION)
731   printf("returned t: %12.4Le\n",time);
732   printf("tout:       %12.4Le\n",t);
733   printf("lambda(t):  %12.4Le %12.4Le %12.4Le\n",
734          Ith(yB,1), Ith(yB,2), Ith(yB,3));
735   printf("y(t):       %12.4Le %12.4Le %12.4Le\n",
736          Ith(y,1), Ith(y,2), Ith(y,3));
737 #elif defined(SUNDIALS_DOUBLE_PRECISION)
738   printf("returned t: %12.4e\n",time);
739   printf("tout:       %12.4e\n",t);
740   printf("lambda(t):  %12.4e %12.4e %12.4e\n",
741          Ith(yB,1), Ith(yB,2), Ith(yB,3));
742   printf("y(t):       %12.4e %12.4e %12.4e\n",
743          Ith(y,1), Ith(y,2), Ith(y,3));
744 #else
745   printf("returned t: %12.4e\n",time);
746   printf("tout:       %12.4e\n",t);
747   printf("lambda(t):  %12.4e %12.4e %12.4e\n",
748          Ith(yB,1), Ith(yB,2), Ith(yB,3));
749   printf("y(t)      : %12.4e %12.4e %12.4e\n",
750          Ith(y,1), Ith(y,2), Ith(y,3));
751 #endif
752   printf("--------------------------------------------------------\n\n");
753 }
754 
755 /*
756  * Print final results of backward integration
757  */
758 
PrintOutput(realtype tfinal,N_Vector y,N_Vector yB,N_Vector qB)759 static void PrintOutput(realtype tfinal, N_Vector y, N_Vector yB, N_Vector qB)
760 {
761   printf("--------------------------------------------------------\n");
762 #if defined(SUNDIALS_EXTENDED_PRECISION)
763   printf("returned t: %12.4Le\n",tfinal);
764   printf("lambda(t0): %12.4Le %12.4Le %12.4Le\n",
765          Ith(yB,1), Ith(yB,2), Ith(yB,3));
766   printf("y(t0):      %12.4Le %12.4Le %12.4Le\n",
767          Ith(y,1), Ith(y,2), Ith(y,3));
768   printf("dG/dp:      %12.4Le %12.4Le %12.4Le\n",
769          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
770 #elif defined(SUNDIALS_DOUBLE_PRECISION)
771   printf("returned t: %12.4e\n",tfinal);
772   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
773          Ith(yB,1), Ith(yB,2), Ith(yB,3));
774   printf("y(t0):      %12.4e %12.4e %12.4e\n",
775          Ith(y,1), Ith(y,2), Ith(y,3));
776   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
777          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
778 #else
779   printf("returned t: %12.4e\n",tfinal);
780   printf("lambda(t0): %12.4e %12.4e %12.4e\n",
781          Ith(yB,1), Ith(yB,2), Ith(yB,3));
782   printf("y(t0)     : %12.4e %12.4e %12.4e\n",
783          Ith(y,1), Ith(y,2), Ith(y,3));
784   printf("dG/dp:      %12.4e %12.4e %12.4e\n",
785          -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
786 #endif
787   printf("--------------------------------------------------------\n\n");
788 }
789 
/*
 * Check a function's return value.
 *    opt == 0 : returnvalue is a pointer returned by a SUNDIALS function
 *               that allocates memory; failure means it is NULL.
 *    opt == 1 : returnvalue points to the int returned by a SUNDIALS
 *               function; failure means that value is negative.
 *    opt == 2 : returnvalue is a pointer returned by another allocator
 *               (e.g. malloc); failure means it is NULL.
 * On failure, a message naming funcname is written to stderr and 1 is
 * returned.  Otherwise, including for any other opt value, 0 is returned.
 */
static int check_retval(void *returnvalue, const char *funcname, int opt)
{
  int failed = 0;

  switch (opt) {
  case 0:
    if (returnvalue == NULL) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      failed = 1;
    }
    break;

  case 1: {
    int code = *((int *) returnvalue);

    if (code < 0) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
              funcname, code);
      failed = 1;
    }
    break;
  }

  case 2:
    if (returnvalue == NULL) {
      fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      failed = 1;
    }
    break;

  default:
    break;
  }

  return failed;
}
826