1 /* -----------------------------------------------------------------
2  * Programmer(s): Radu Serban @ LLNL
3  * -----------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2020, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------
14  *
15  * Hessian through adjoint sensitivity example problem.
16  *
17  *        [ - p1 * y1^2 - y3 ]           [ 1 ]
18  *   y' = [    - y2          ]    y(0) = [ 1 ]
19  *        [ -p2^2 * y2 * y3  ]           [ 1 ]
20  *
21  *   p1 = 1.0
22  *   p2 = 2.0
23  *
24  *           2
25  *          /
26  *   G(p) = |  0.5 * ( y1^2 + y2^2 + y3^2 ) dt
27  *          /
28  *          0
29  *
30  * Compute the gradient (ASA) and Hessian (FSA over ASA) of G(p).
31  *
32  * See D.B. Ozyurt and P.I. Barton, SISC 26(5) 1725-1743, 2005.
33  *
34  * -----------------------------------------------------------------*/
35 
36 #include <stdio.h>
37 #include <stdlib.h>
38 
39 #include <cvodes/cvodes.h>              /* prototypes for CVODES fcts., consts. */
40 #include <nvector/nvector_serial.h>     /* access to serial N_Vector            */
41 #include <sunmatrix/sunmatrix_dense.h>  /* access to band SUNMatrix             */
42 #include <sunlinsol/sunlinsol_dense.h>  /* access to band SUNLinearSolver       */
43 #include <sundials/sundials_math.h>     /* definition of SUNRabs, SUNRexp, etc. */
44 
/* Ith(v,i) gives read/write access to component i of vector v using
   1-based indexing: it forwards to NV_Ith_S with index (i)-1.  Both
   arguments are parenthesized so that an arbitrary expression (for
   example a conditional) can be passed as the index. */
#define Ith(v,i)    NV_Ith_S((v),(i)-1)

/* Frequently used constants, built with RCONST */
#define ZERO RCONST(0.0)
#define ONE  RCONST(1.0)
49 
50 typedef struct {
51   realtype p1, p2;
52 } *UserData;
53 
54 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data);
55 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data);
56 static int fS(int Ns, realtype t,
57               N_Vector y, N_Vector ydot,
58               N_Vector *yS, N_Vector *ySdot,
59               void *user_data,
60               N_Vector tmp1, N_Vector tmp2);
61 static int fQS(int Ns, realtype t,
62                N_Vector y, N_Vector *yS,
63                N_Vector yQdot, N_Vector *yQSdot,
64                void *user_data,
65                N_Vector tmp, N_Vector tmpQ);
66 
67 static int fB1(realtype t, N_Vector y, N_Vector *yS,
68                N_Vector yB, N_Vector yBdot, void *user_dataB);
69 static int fQB1(realtype t, N_Vector y, N_Vector *yS,
70                 N_Vector yB, N_Vector qBdot, void *user_dataB);
71 
72 
73 static int fB2(realtype t, N_Vector y, N_Vector *yS,
74                N_Vector yB, N_Vector yBdot, void *user_dataB);
75 static int fQB2(realtype t, N_Vector y, N_Vector *yS,
76                 N_Vector yB, N_Vector qBdot, void *user_dataB);
77 
78 int PrintFwdStats(void *cvode_mem);
79 int PrintBckStats(void *cvode_mem, int idx);
80 
81 /* Private function to check function return values */
82 
83 static int check_retval(void *returnvalue, const char *funcname, int opt);
84 
85 /*
86  *--------------------------------------------------------------------
87  * MAIN PROGRAM
88  *--------------------------------------------------------------------
89  */
90 
main(int argc,char * argv[])91 int main(int argc, char *argv[])
92 {
93   UserData data;
94 
95   SUNMatrix A, AB1, AB2;
96   SUNLinearSolver LS, LSB1, LSB2;
97   void *cvode_mem;
98 
99   sunindextype Neq, Np2;
100   int Np;
101 
102   realtype t0, tf;
103 
104   realtype reltol;
105   realtype abstol, abstolQ, abstolB, abstolQB;
106 
107   N_Vector y, yQ;
108   N_Vector *yS, *yQS;
109   N_Vector yB1, yB2;
110   N_Vector yQB1, yQB2;
111 
112   int steps, ncheck;
113   int indexB1, indexB2;
114 
115   int retval;
116   realtype time;
117 
118   realtype dp;
119   realtype G, Gp, Gm;
120   realtype grdG_fwd[2], grdG_bck[2], grdG_cntr[2];
121   realtype H11, H22;
122 
123   data = NULL;
124   y = yQ = NULL;
125   yB1 = yB2 = NULL;
126   yQB1 = yQB2 = NULL;
127   A = AB1 = AB2 = NULL;
128   LS = LSB1 = LSB2 = NULL;
129   cvode_mem = NULL;
130 
131   /* User data structure */
132 
133   data = (UserData) malloc(sizeof *data);
134   data->p1 = RCONST(1.0);
135   data->p2 = RCONST(2.0);
136 
137   /* Problem size, integration interval, and tolerances */
138 
139   Neq = 3;
140   Np  = 2;
141   Np2 = 2*Np;
142 
143   t0 = 0.0;
144   tf = 2.0;
145 
146   reltol = 1.0e-8;
147 
148   abstol = 1.0e-8;
149   abstolQ = 1.0e-8;
150 
151   abstolB = 1.0e-8;
152   abstolQB = 1.0e-8;
153 
154   /* Initializations for forward problem */
155 
156   y = N_VNew_Serial(Neq);
157   if (check_retval((void *)y, "N_VNew_Serial", 0)) return(1);
158   N_VConst(ONE, y);
159 
160   yQ = N_VNew_Serial(1);
161   if (check_retval((void *)yQ, "N_VNew_Serial", 0)) return(1);
162   N_VConst(ZERO, yQ);
163 
164   yS = N_VCloneVectorArray(Np, y);
165   if (check_retval((void *)yS, "N_VCloneVectorArray", 0)) return(1);
166   N_VConst(ZERO, yS[0]);
167   N_VConst(ZERO, yS[1]);
168 
169   yQS = N_VCloneVectorArray(Np, yQ);
170   if (check_retval((void *)yQS, "N_VCloneVectorArray", 0)) return(1);
171   N_VConst(ZERO, yQS[0]);
172   N_VConst(ZERO, yQS[1]);
173 
174   /* Create and initialize forward problem */
175 
176   cvode_mem = CVodeCreate(CV_BDF);
177   if(check_retval((void *)cvode_mem, "CVodeCreate", 0)) return(1);
178 
179   retval = CVodeInit(cvode_mem, f, t0, y);
180   if(check_retval(&retval, "CVodeInit", 1)) return(1);
181 
182   retval = CVodeSStolerances(cvode_mem, reltol, abstol);
183   if(check_retval(&retval, "CVodeSStolerances", 1)) return(1);
184 
185   retval = CVodeSetUserData(cvode_mem, data);
186   if(check_retval(&retval, "CVodeSetUserData", 1)) return(1);
187 
188   /* Create a dense SUNMatrix */
189   A = SUNDenseMatrix(Neq, Neq);
190   if(check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
191 
192   /* Create banded SUNLinearSolver for the forward problem */
193   LS = SUNLinSol_Dense(y, A);
194   if(check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
195 
196   /* Attach the matrix and linear solver */
197   retval = CVodeSetLinearSolver(cvode_mem, LS, A);
198   if(check_retval(&retval, "CVodeSetLinearSolver", 1)) return(1);
199 
200   retval = CVodeQuadInit(cvode_mem, fQ, yQ);
201   if(check_retval(&retval, "CVodeQuadInit", 1)) return(1);
202 
203   retval = CVodeQuadSStolerances(cvode_mem, reltol, abstolQ);
204   if(check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
205 
206   retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
207   if(check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
208 
209   retval = CVodeSensInit(cvode_mem, Np, CV_SIMULTANEOUS, fS, yS);
210   if(check_retval(&retval, "CVodeSensInit", 1)) return(1);
211 
212   retval = CVodeSensEEtolerances(cvode_mem);
213   if(check_retval(&retval, "CVodeSensEEtolerances", 1)) return(1);
214 
215   retval = CVodeSetSensErrCon(cvode_mem, SUNTRUE);
216   if(check_retval(&retval, "CVodeSetSensErrCon", 1)) return(1);
217 
218   retval = CVodeQuadSensInit(cvode_mem, fQS, yQS);
219   if(check_retval(&retval, "CVodeQuadSensInit", 1)) return(1);
220 
221   retval = CVodeQuadSensEEtolerances(cvode_mem);
222   if(check_retval(&retval, "CVodeQuadSensEEtolerances", 1)) return(1);
223 
224   retval = CVodeSetQuadSensErrCon(cvode_mem, SUNTRUE);
225   if(check_retval(&retval, "CVodeSetQuadSensErrCon", 1)) return(1);
226 
227   /* Initialize ASA */
228 
229   steps = 100;
230   retval = CVodeAdjInit(cvode_mem, steps, CV_POLYNOMIAL);
231   if(check_retval(&retval, "CVodeAdjInit", 1)) return(1);
232 
233   /* Forward integration */
234 
235   printf("-------------------\n");
236   printf("Forward integration\n");
237   printf("-------------------\n\n");
238 
239   retval = CVodeF(cvode_mem, tf, y, &time, CV_NORMAL, &ncheck);
240   if(check_retval(&retval, "CVodeF", 1)) return(1);
241 
242   retval = CVodeGetQuad(cvode_mem, &time, yQ);
243   if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
244 
245   G = Ith(yQ,1);
246 
247   retval = CVodeGetSens(cvode_mem, &time, yS);
248   if(check_retval(&retval, "CVodeGetSens", 1)) return(1);
249 
250   retval = CVodeGetQuadSens(cvode_mem, &time, yQS);
251   if(check_retval(&retval, "CVodeGetQuadSens", 1)) return(1);
252 
253   printf("ncheck = %d\n", ncheck);
254   printf("\n");
255 #if defined(SUNDIALS_EXTENDED_PRECISION)
256   printf("     y:    %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
257   printf("     G:    %12.4Le\n", Ith(yQ,1));
258   printf("\n");
259   printf("     yS1:  %12.4Le %12.4Le %12.4Le\n", Ith(yS[0],1), Ith(yS[0],2), Ith(yS[0],3));
260   printf("     yS2:  %12.4Le %12.4Le %12.4Le\n", Ith(yS[1],1), Ith(yS[1],2), Ith(yS[1],3));
261   printf("\n");
262   printf("   dG/dp:  %12.4Le %12.4Le\n", Ith(yQS[0],1), Ith(yQS[1],1));
263 #else
264   printf("     y:    %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
265   printf("     G:    %12.4e\n", Ith(yQ,1));
266   printf("\n");
267   printf("     yS1:  %12.4e %12.4e %12.4e\n", Ith(yS[0],1), Ith(yS[0],2), Ith(yS[0],3));
268   printf("     yS2:  %12.4e %12.4e %12.4e\n", Ith(yS[1],1), Ith(yS[1],2), Ith(yS[1],3));
269   printf("\n");
270   printf("   dG/dp:  %12.4e %12.4e\n", Ith(yQS[0],1), Ith(yQS[1],1));
271 #endif
272   printf("\n");
273 
274   printf("Final Statistics for forward pb.\n");
275   printf("--------------------------------\n");
276   retval = PrintFwdStats(cvode_mem);
277   if (check_retval(&retval, "PrintFwdStats", 1)) return(1);
278 
279   /* Initializations for backward problems */
280 
281   yB1 = N_VNew_Serial(2*Neq);
282   if (check_retval((void *)yB1, "N_VNew_Serial", 0)) return(1);
283   N_VConst(ZERO, yB1);
284 
285   yQB1 = N_VNew_Serial(Np2);
286   if (check_retval((void *)yQB1, "N_VNew_Serial", 0)) return(1);
287   N_VConst(ZERO, yQB1);
288 
289   yB2 = N_VNew_Serial(2*Neq);
290   if (check_retval((void *)yB2, "N_VNew_Serial", 0)) return(1);
291   N_VConst(ZERO, yB2);
292 
293   yQB2 = N_VNew_Serial(Np2);
294   if (check_retval((void *)yQB2, "N_VNew_Serial", 0)) return(1);
295   N_VConst(ZERO, yQB2);
296 
297   /* Create and initialize backward problems (one for each column of the Hessian) */
298 
299   /* -------------------------
300      First backward problem
301      -------------------------*/
302 
303   retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB1);
304   if(check_retval(&retval, "CVodeCreateB", 1)) return(1);
305 
306   retval = CVodeInitBS(cvode_mem, indexB1, fB1, tf, yB1);
307   if(check_retval(&retval, "CVodeInitBS", 1)) return(1);
308 
309   retval = CVodeSStolerancesB(cvode_mem, indexB1, reltol, abstolB);
310   if(check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
311 
312   retval = CVodeSetUserDataB(cvode_mem, indexB1, data);
313   if(check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
314 
315   retval = CVodeQuadInitBS(cvode_mem, indexB1, fQB1, yQB1);
316   if(check_retval(&retval, "CVodeQuadInitBS", 1)) return(1);
317 
318   retval = CVodeQuadSStolerancesB(cvode_mem, indexB1, reltol, abstolQB);
319   if(check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
320 
321   retval = CVodeSetQuadErrConB(cvode_mem, indexB1, SUNTRUE);
322   if(check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
323 
324   /* Create a dense SUNMatrix */
325   AB1 = SUNDenseMatrix(2*Neq, 2*Neq);
326   if(check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
327 
328   /* Create dense SUNLinearSolver for the forward problem */
329   LSB1 = SUNLinSol_Dense(yB1, AB1);
330   if(check_retval((void *)LSB1, "SUNLinSol_Dense", 0)) return(1);
331 
332   /* Attach the matrix and linear solver */
333   retval = CVodeSetLinearSolverB(cvode_mem, indexB1, LSB1, AB1);
334   if(check_retval(&retval, "CVodeSetLinearSolverB", 1)) return(1);
335 
336   /* -------------------------
337      Second backward problem
338      -------------------------*/
339 
340   retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB2);
341   if(check_retval(&retval, "CVodeCreateB", 1)) return(1);
342 
343   retval = CVodeInitBS(cvode_mem, indexB2, fB2, tf, yB2);
344   if(check_retval(&retval, "CVodeInitBS", 1)) return(1);
345 
346   retval = CVodeSStolerancesB(cvode_mem, indexB2, reltol, abstolB);
347   if(check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
348 
349   retval = CVodeSetUserDataB(cvode_mem, indexB2, data);
350   if(check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
351 
352   retval = CVodeQuadInitBS(cvode_mem, indexB2, fQB2, yQB2);
353   if(check_retval(&retval, "CVodeQuadInitBS", 1)) return(1);
354 
355   retval = CVodeQuadSStolerancesB(cvode_mem, indexB2, reltol, abstolQB);
356   if(check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
357 
358   retval = CVodeSetQuadErrConB(cvode_mem, indexB2, SUNTRUE);
359   if(check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
360 
361   /* Create a dense SUNMatrix */
362   AB2 = SUNDenseMatrix(2*Neq, 2*Neq);
363   if(check_retval((void *)AB2, "SUNDenseMatrix", 0)) return(1);
364 
365   /* Create dense SUNLinearSolver for the forward problem */
366   LSB2 = SUNLinSol_Dense(yB2, AB2);
367   if(check_retval((void *)LSB2, "SUNLinSol_Dense", 0)) return(1);
368 
369   /* Attach the matrix and linear solver */
370   retval = CVodeSetLinearSolverB(cvode_mem, indexB2, LSB2, AB2);
371   if(check_retval(&retval, "CVodeSetLinearSolverB", 1)) return(1);
372 
373   /* Backward integration */
374 
375   printf("---------------------------------------------\n");
376   printf("Backward integration ... (2 adjoint problems)\n");
377   printf("---------------------------------------------\n\n");
378 
379   retval = CVodeB(cvode_mem, t0, CV_NORMAL);
380   if(check_retval(&retval, "CVodeB", 1)) return(1);
381 
382   retval = CVodeGetB(cvode_mem, indexB1, &time, yB1);
383   if(check_retval(&retval, "CVodeGetB", 1)) return(1);
384 
385   retval = CVodeGetQuadB(cvode_mem, indexB1, &time, yQB1);
386   if(check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
387 
388   retval = CVodeGetB(cvode_mem, indexB2, &time, yB2);
389   if(check_retval(&retval, "CVodeGetB", 1)) return(1);
390 
391   retval = CVodeGetQuadB(cvode_mem, indexB2, &time, yQB2);
392   if(check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
393 
394 #if defined(SUNDIALS_EXTENDED_PRECISION)
395   printf("   dG/dp:  %12.4Le %12.4Le   (from backward pb. 1)\n", -Ith(yQB1,1), -Ith(yQB1,2));
396   printf("           %12.4Le %12.4Le   (from backward pb. 2)\n", -Ith(yQB2,1), -Ith(yQB2,2));
397 #else
398   printf("   dG/dp:  %12.4e %12.4e   (from backward pb. 1)\n", -Ith(yQB1,1), -Ith(yQB1,2));
399   printf("           %12.4e %12.4e   (from backward pb. 2)\n", -Ith(yQB2,1), -Ith(yQB2,2));
400 #endif
401   printf("\n");
402   printf("   H = d2G/dp2:\n");
403   printf("        (1)            (2)\n");
404 #if defined(SUNDIALS_EXTENDED_PRECISION)
405   printf("  %12.4Le   %12.4Le\n", -Ith(yQB1,3) , -Ith(yQB2,3));
406   printf("  %12.4Le   %12.4Le\n", -Ith(yQB1,4) , -Ith(yQB2,4));
407 #else
408   printf("  %12.4e   %12.4e\n", -Ith(yQB1,3) , -Ith(yQB2,3));
409   printf("  %12.4e   %12.4e\n", -Ith(yQB1,4) , -Ith(yQB2,4));
410 #endif
411   printf("\n");
412 
413   printf("Final Statistics for backward pb. 1\n");
414   printf("-----------------------------------\n");
415   retval = PrintBckStats(cvode_mem, indexB1);
416   if (check_retval(&retval, "PrintBckStats", 1)) return(1);
417 
418   printf("Final Statistics for backward pb. 2\n");
419   printf("-----------------------------------\n");
420   retval = PrintBckStats(cvode_mem, indexB2);
421   if (check_retval(&retval, "PrintBckStats", 1)) return(1);
422 
423   /* Free memory */
424 
425   CVodeFree(&cvode_mem);
426   SUNLinSolFree(LS);
427   SUNMatDestroy(A);
428   SUNLinSolFree(LSB1);
429   SUNMatDestroy(AB1);
430   SUNLinSolFree(LSB2);
431   SUNMatDestroy(AB2);
432 
433   /* Finite difference tests */
434 
435   dp = RCONST(1.0e-2);
436 
437   printf("-----------------------\n");
438   printf("Finite Difference tests\n");
439   printf("-----------------------\n\n");
440 
441 #if defined(SUNDIALS_EXTENDED_PRECISION)
442   printf("del_p = %Lg\n\n",dp);
443 #else
444   printf("del_p = %g\n\n",dp);
445 #endif
446 
447   cvode_mem = CVodeCreate(CV_BDF);
448 
449   N_VConst(ONE, y);
450   N_VConst(ZERO, yQ);
451 
452   retval = CVodeInit(cvode_mem, f, t0, y);
453   if(check_retval(&retval, "CVodeInit", 1)) return(1);
454 
455   retval = CVodeSStolerances(cvode_mem, reltol, abstol);
456   if(check_retval(&retval, "CVodeSStolerances", 1)) return(1);
457 
458   retval = CVodeSetUserData(cvode_mem, data);
459   if(check_retval(&retval, "CVodeSetUserData", 1)) return(1);
460 
461   /* Create a dense SUNMatrix */
462   A = SUNDenseMatrix(Neq, Neq);
463   if(check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
464 
465   /* Create dense SUNLinearSolver for the forward problem */
466   LS = SUNLinSol_Dense(y, A);
467   if(check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
468 
469   /* Attach the matrix and linear solver */
470   retval = CVodeSetLinearSolver(cvode_mem, LS, A);
471   if(check_retval(&retval, "CVodeSetLinearSolver", 1)) return(1);
472 
473   retval = CVodeQuadInit(cvode_mem, fQ, yQ);
474   if(check_retval(&retval, "CVodeQuadInit", 1)) return(1);
475 
476   retval = CVodeQuadSStolerances(cvode_mem, reltol, abstolQ);
477   if(check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
478 
479   retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
480   if(check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
481 
482   data->p1 += dp;
483 
484   retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
485   if(check_retval(&retval, "CVode", 1)) return(1);
486 
487   retval = CVodeGetQuad(cvode_mem, &time, yQ);
488   if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
489 
490   Gp = Ith(yQ,1);
491 
492 #if defined(SUNDIALS_EXTENDED_PRECISION)
493   printf("p1+  y:   %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
494   printf("     G:   %12.4Le\n",Ith(yQ,1));
495 #else
496   printf("p1+  y:   %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
497   printf("     G:   %12.4e\n",Ith(yQ,1));
498 #endif
499   data->p1 -= 2.0*dp;
500 
501   N_VConst(ONE, y);
502   N_VConst(ZERO, yQ);
503 
504   CVodeReInit(cvode_mem, t0, y);
505   CVodeQuadReInit(cvode_mem, yQ);
506 
507   retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
508   if(check_retval(&retval, "CVode", 1)) return(1);
509 
510   retval = CVodeGetQuad(cvode_mem, &time, yQ);
511   if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
512 
513   Gm = Ith(yQ,1);
514 #if defined(SUNDIALS_EXTENDED_PRECISION)
515   printf("p1-  y:   %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
516   printf("     G:   %12.4Le\n",Ith(yQ,1));
517 #else
518   printf("p1-  y:   %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
519   printf("     G:   %12.4e\n",Ith(yQ,1));
520 #endif
521   data->p1 += dp;
522 
523   grdG_fwd[0] = (Gp-G)/dp;
524   grdG_bck[0] = (G-Gm)/dp;
525   grdG_cntr[0] = (Gp-Gm)/(2.0*dp);
526   H11 = (Gp - 2.0*G + Gm) / (dp*dp);
527 
528   data->p2 += dp;
529 
530   N_VConst(ONE, y);
531   N_VConst(ZERO, yQ);
532 
533   CVodeReInit(cvode_mem, t0, y);
534   CVodeQuadReInit(cvode_mem, yQ);
535 
536   retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
537   if(check_retval(&retval, "CVode", 1)) return(1);
538 
539   retval = CVodeGetQuad(cvode_mem, &time, yQ);
540   if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
541 
542   Gp = Ith(yQ,1);
543 #if defined(SUNDIALS_EXTENDED_PRECISION)
544   printf("p2+  y:   %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
545   printf("     G:   %12.4Le\n",Ith(yQ,1));
546 #else
547   printf("p2+  y:   %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
548   printf("     G:   %12.4e\n",Ith(yQ,1));
549 #endif
550   data->p2 -= 2.0*dp;
551 
552   N_VConst(ONE, y);
553   N_VConst(ZERO, yQ);
554 
555   CVodeReInit(cvode_mem, t0, y);
556   CVodeQuadReInit(cvode_mem, yQ);
557 
558   retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
559   if(check_retval(&retval, "CVode", 1)) return(1);
560 
561   retval = CVodeGetQuad(cvode_mem, &time, yQ);
562   if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
563 
564   Gm = Ith(yQ,1);
565 #if defined(SUNDIALS_EXTENDED_PRECISION)
566   printf("p2-  y:   %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
567   printf("     G:   %12.4Le\n",Ith(yQ,1));
568 #else
569   printf("p2-  y:   %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
570   printf("     G:   %12.4e\n",Ith(yQ,1));
571 #endif
572   data->p2 += dp;
573 
574   grdG_fwd[1] = (Gp-G)/dp;
575   grdG_bck[1] = (G-Gm)/dp;
576   grdG_cntr[1] = (Gp-Gm)/(2.0*dp);
577   H22 = (Gp - 2.0*G + Gm) / (dp*dp);
578 
579   printf("\n");
580 
581 #if defined(SUNDIALS_EXTENDED_PRECISION)
582   printf("   dG/dp:  %12.4Le %12.4Le   (fwd FD)\n", grdG_fwd[0], grdG_fwd[1]);
583   printf("           %12.4Le %12.4Le   (bck FD)\n", grdG_bck[0], grdG_bck[1]);
584   printf("           %12.4Le %12.4Le   (cntr FD)\n", grdG_cntr[0], grdG_cntr[1]);
585   printf("\n");
586   printf("  H(1,1):  %12.4Le\n", H11);
587   printf("  H(2,2):  %12.4Le\n", H22);
588 #else
589   printf("   dG/dp:  %12.4e %12.4e   (fwd FD)\n", grdG_fwd[0], grdG_fwd[1]);
590   printf("           %12.4e %12.4e   (bck FD)\n", grdG_bck[0], grdG_bck[1]);
591   printf("           %12.4e %12.4e   (cntr FD)\n", grdG_cntr[0], grdG_cntr[1]);
592   printf("\n");
593   printf("  H(1,1):  %12.4e\n", H11);
594   printf("  H(2,2):  %12.4e\n", H22);
595 #endif
596 
597   /* Free memory */
598 
599   CVodeFree(&cvode_mem);
600   SUNLinSolFree(LS);
601   SUNMatDestroy(A);
602 
603   N_VDestroy(y);
604   N_VDestroy(yQ);
605 
606   N_VDestroyVectorArray(yS, Np);
607   N_VDestroyVectorArray(yQS, Np);
608 
609   N_VDestroy(yB1);
610   N_VDestroy(yQB1);
611   N_VDestroy(yB2);
612   N_VDestroy(yQB2);
613 
614   free(data);
615 
616   return(0);
617 
618 }
619 
620 /*
621  *--------------------------------------------------------------------
622  * FUNCTIONS CALLED BY CVODES
623  *--------------------------------------------------------------------
624  */
625 
626 
/*
 * ODE right-hand side.
 *
 * Reads the three state components from y and the parameters p1, p2
 * from user_data (a UserData), and stores
 *   ydot = [ -p1*y1^2 - y3,  -y2,  -p2^2*y2*y3 ].
 * t is not used.  Always returns 0.
 */
static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data)
{
  const UserData prm = (UserData) user_data;
  const realtype p1 = prm->p1;
  const realtype p2 = prm->p2;
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);

  Ith(ydot,1) = -p1*y1*y1 - y3;
  Ith(ydot,2) = -y2;
  Ith(ydot,3) = -p2*p2*y2*y3;

  return(0);
}
647 
/*
 * Quadrature integrand: stores 0.5 * (y1^2 + y2^2 + y3^2) in the single
 * component of qdot.  t and user_data are not used.  Always returns 0.
 */
static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data)
{
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);

  Ith(qdot,1) = 0.5 * ( y1*y1 + y2*y2 + y3*y3 );

  return(0);
}
660 
/*
 * Sensitivity right-hand sides for both parameters.
 *
 * For each of yS[0] and yS[1] the same linear combination of the
 * sensitivity components is formed (the partial derivatives of f with
 * respect to y applied to that sensitivity); then the explicit
 * parameter derivative of f is added:
 *   - for ySdot[0]: -y1^2 in component 1        (from -p1*y1^2),
 *   - for ySdot[1]: -2*p2*y2*y3 in component 3  (from -p2^2*y2*y3).
 *
 * Ns, t, ydot, tmp1 and tmp2 are not used; exactly two sensitivities
 * are written.  Always returns 0.
 */
static int fS(int Ns, realtype t,
              N_Vector y, N_Vector ydot,
              N_Vector *yS, N_Vector *ySdot,
              void *user_data,
              N_Vector tmp1, N_Vector tmp2)
{
  const UserData prm = (UserData) user_data;
  const realtype p1 = prm->p1;
  const realtype p2 = prm->p2;
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);

  /* 1st sensitivity RHS */
  {
    const realtype s1 = Ith(yS[0],1);
    const realtype s2 = Ith(yS[0],2);
    const realtype s3 = Ith(yS[0],3);

    const realtype jv1 = - 2.0*p1*y1 * s1 - s3;
    const realtype jv2 = - s2;
    const realtype jv3 = - p2*p2*y3 * s2 - p2*p2*y2 * s3;

    Ith(ySdot[0],1) = jv1 - y1*y1;
    Ith(ySdot[0],2) = jv2;
    Ith(ySdot[0],3) = jv3;
  }

  /* 2nd sensitivity RHS */
  {
    const realtype s1 = Ith(yS[1],1);
    const realtype s2 = Ith(yS[1],2);
    const realtype s3 = Ith(yS[1],3);

    const realtype jv1 = - 2.0*p1*y1 * s1 - s3;
    const realtype jv2 = - s2;
    const realtype jv3 = - p2*p2*y3 * s2 - p2*p2*y2 * s3;

    Ith(ySdot[1],1) = jv1;
    Ith(ySdot[1],2) = jv2;
    Ith(ySdot[1],3) = jv3 - 2.0*p2*y2*y3;
  }

  return(0);
}
711 
/*
 * Quadrature sensitivity right-hand sides.
 *
 * For each of the two sensitivities, stores y1*s1 + y2*s2 + y3*s3 in
 * the single component of the matching yQSdot vector.
 *
 * Ns, t, yQdot, user_data, tmp and tmpQ are not used; exactly two
 * sensitivities are processed.  Always returns 0.
 */
static int fQS(int Ns, realtype t,
               N_Vector y, N_Vector *yS,
               N_Vector yQdot, N_Vector *yQSdot,
               void *user_data,
               N_Vector tmp, N_Vector tmpQ)
{
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);
  int is;

  /* is = 0: 1st sensitivity RHS, is = 1: 2nd sensitivity RHS */
  for (is = 0; is < 2; is++) {
    const realtype s1 = Ith(yS[is],1);
    const realtype s2 = Ith(yS[is],2);
    const realtype s3 = Ith(yS[is],3);

    Ith(yQSdot[is],1) = y1*s1 + y2*s2 + y3*s3;
  }

  return(0);
}
745 
/*
 * Right-hand side of the first backward problem.
 *
 * yB has six components: 1-3 are read as lambda and 4-6 as mu.  The
 * forward solution y, the first sensitivity yS[0] and the parameters
 * in user_dataB (a UserData) enter the expressions.  Components 1-3 of
 * yBdot depend on lambda and y only; components 4-6 additionally involve
 * mu and yS[0].  t is not used.  Always returns 0.
 */
static int fB1(realtype t, N_Vector y, N_Vector *yS,
               N_Vector yB, N_Vector yBdot, void *user_dataB)
{
  const UserData prm = (UserData) user_dataB;
  const realtype p1 = prm->p1;
  const realtype p2 = prm->p2;

  /* solution */
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);

  /* sensitivity 1 */
  const realtype s1 = Ith(yS[0],1);
  const realtype s2 = Ith(yS[0],2);
  const realtype s3 = Ith(yS[0],3);

  /* lambda */
  const realtype lam1 = Ith(yB,1);
  const realtype lam2 = Ith(yB,2);
  const realtype lam3 = Ith(yB,3);

  /* mu */
  const realtype mu1 = Ith(yB,4);
  const realtype mu2 = Ith(yB,5);
  const realtype mu3 = Ith(yB,6);

  Ith(yBdot,1) = 2.0*p1*y1 * lam1       - y1;
  Ith(yBdot,2) = lam2 + p2*p2*y3 * lam3 - y2;
  Ith(yBdot,3) = lam1 + p2*p2*y2 * lam3 - y3;

  Ith(yBdot,4) = 2.0*p1*y1 * mu1      + lam1 * 2.0*(y1 + p1*s1) - s1;
  Ith(yBdot,5) = mu2 + p2*p2*y3 * mu3 + lam3 * p2*p2*s3         - s2;
  Ith(yBdot,6) = mu1 + p2*p2*y2 * mu3 + lam3 * p2*p2*s2         - s3;

  return(0);
}
787 
/*
 * Quadrature integrand of the first backward problem.
 *
 * Writes four components of qBdot from the forward solution y, the
 * first sensitivity yS[0], components 1 and 3 of yB (lambda) and
 * components 4 and 6 of yB (mu).  main prints the negated integrals of
 * components 1-2 as dG/dp and of components 3-4 as Hessian entries.
 * Only p2 is read from user_dataB; t is not used.  Always returns 0.
 */
static int fQB1(realtype t, N_Vector y, N_Vector *yS,
                N_Vector yB, N_Vector qBdot, void *user_dataB)
{
  const UserData prm = (UserData) user_dataB;
  const realtype p2 = prm->p2;

  /* solution */
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);

  /* sensitivity 1 */
  const realtype s1 = Ith(yS[0],1);
  const realtype s2 = Ith(yS[0],2);
  const realtype s3 = Ith(yS[0],3);

  /* lambda */
  const realtype lam1 = Ith(yB,1);
  const realtype lam3 = Ith(yB,3);

  /* mu */
  const realtype mu1 = Ith(yB,4);
  const realtype mu3 = Ith(yB,6);

  Ith(qBdot,1) = -y1*y1 * lam1;
  Ith(qBdot,2) = -2.0*p2*y2*y3 * lam3;

  Ith(qBdot,3) = -y1*y1 * mu1        - lam1 * 2.0*y1*s1;
  Ith(qBdot,4) = -2.0*p2*y2*y3 * mu3 - lam3 * 2.0*(p2*y3*s2 + p2*y2*s3);

  return(0);
}
824 
825 
826 
827 
/*
 * Right-hand side of the second backward problem.
 *
 * Same layout as the first backward problem (yB components 1-3 read as
 * lambda, 4-6 as mu), but the second sensitivity yS[1] is used and the
 * lambda-dependent terms in components 4-6 differ accordingly.
 * Parameters come from user_dataB (a UserData); t is not used.
 * Always returns 0.
 */
static int fB2(realtype t, N_Vector y, N_Vector *yS,
               N_Vector yB, N_Vector yBdot, void *user_dataB)
{
  const UserData prm = (UserData) user_dataB;
  const realtype p1 = prm->p1;
  const realtype p2 = prm->p2;

  /* solution */
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);

  /* sensitivity 2 */
  const realtype s1 = Ith(yS[1],1);
  const realtype s2 = Ith(yS[1],2);
  const realtype s3 = Ith(yS[1],3);

  /* lambda */
  const realtype lam1 = Ith(yB,1);
  const realtype lam2 = Ith(yB,2);
  const realtype lam3 = Ith(yB,3);

  /* mu */
  const realtype mu1 = Ith(yB,4);
  const realtype mu2 = Ith(yB,5);
  const realtype mu3 = Ith(yB,6);

  Ith(yBdot,1) = 2.0*p1*y1 * lam1       - y1;
  Ith(yBdot,2) = lam2 + p2*p2*y3 * lam3 - y2;
  Ith(yBdot,3) = lam1 + p2*p2*y2 * lam3 - y3;

  Ith(yBdot,4) = 2.0*p1*y1 * mu1      + lam1 * 2.0*p1*s1              - s1;
  Ith(yBdot,5) = mu2 + p2*p2*y3 * mu3 + lam3 * (2.0*p2*y3 + p2*p2*s3) - s2;
  Ith(yBdot,6) = mu1 + p2*p2*y2 * mu3 + lam3 * (2.0*p2*y2 + p2*p2*s2) - s3;

  return(0);
}
869 
870 
/*
 * Quadrature integrand of the second backward problem.
 *
 * Writes four components of qBdot from the forward solution y, the
 * second sensitivity yS[1], components 1 and 3 of yB (lambda) and
 * components 4 and 6 of yB (mu).  Compared with the first backward
 * problem, component 4 carries an extra y2*y3 term multiplying lambda3.
 * Only p2 is read from user_dataB; t is not used.  Always returns 0.
 */
static int fQB2(realtype t, N_Vector y, N_Vector *yS,
                N_Vector yB, N_Vector qBdot, void *user_dataB)
{
  const UserData prm = (UserData) user_dataB;
  const realtype p2 = prm->p2;

  /* solution */
  const realtype y1 = Ith(y,1);
  const realtype y2 = Ith(y,2);
  const realtype y3 = Ith(y,3);

  /* sensitivity 2 */
  const realtype s1 = Ith(yS[1],1);
  const realtype s2 = Ith(yS[1],2);
  const realtype s3 = Ith(yS[1],3);

  /* lambda */
  const realtype lam1 = Ith(yB,1);
  const realtype lam3 = Ith(yB,3);

  /* mu */
  const realtype mu1 = Ith(yB,4);
  const realtype mu3 = Ith(yB,6);

  Ith(qBdot,1) = -y1*y1 * lam1;
  Ith(qBdot,2) = -2.0*p2*y2*y3 * lam3;

  Ith(qBdot,3) = -y1*y1 * mu1        - lam1 * 2.0*y1*s1;
  Ith(qBdot,4) = -2.0*p2*y2*y3 * mu3 - lam3 * 2.0*(p2*y3*s2 + p2*y2*s3 + y2*y3);

  return(0);
}
907 
908 
909 /*
910  *--------------------------------------------------------------------
911  * PRIVATE FUNCTIONS
912  *--------------------------------------------------------------------
913  */
914 
PrintFwdStats(void * cvode_mem)915 int PrintFwdStats(void *cvode_mem)
916 {
917   long int nst, nfe, nsetups, nni, ncfn, netf;
918   long int nfQe, netfQ;
919   long int nfSe, nfeS, nsetupsS, netfS;
920   long int nfQSe, netfQS;
921 
922   int qlast, qcur;
923   realtype h0u, hlast, hcur, tcur;
924 
925   int retval;
926 
927 
928   retval = CVodeGetIntegratorStats(cvode_mem, &nst, &nfe, &nsetups, &netf,
929                                  &qlast, &qcur,
930                                  &h0u, &hlast, &hcur,
931                                  &tcur);
932 
933   retval = CVodeGetNonlinSolvStats(cvode_mem, &nni, &ncfn);
934 
935   retval = CVodeGetQuadStats(cvode_mem, &nfQe, &netfQ);
936 
937   retval = CVodeGetSensStats(cvode_mem, &nfSe, &nfeS, &netfS, &nsetupsS);
938 
939   retval = CVodeGetQuadSensStats(cvode_mem, &nfQSe, &netfQS);
940 
941 
942   printf(" Number steps: %5ld\n\n", nst);
943   printf(" Function evaluations:\n");
944   printf("  f:        %5ld\n  fQ:       %5ld\n  fS:       %5ld\n  fQS:      %5ld\n",
945          nfe, nfQe, nfSe, nfQSe);
946   printf(" Error test failures:\n");
947   printf("  netf:     %5ld\n  netfQ:    %5ld\n  netfS:    %5ld\n  netfQS:   %5ld\n",
948          netf, netfQ, netfS, netfQS);
949   printf(" Linear solver setups:\n");
950   printf("  nsetups:  %5ld\n  nsetupsS: %5ld\n", nsetups, nsetupsS);
951   printf(" Nonlinear iterations:\n");
952   printf("  nni:      %5ld\n", nni);
953   printf(" Convergence failures:\n");
954   printf("  ncfn:     %5ld\n", ncfn);
955 
956   printf("\n");
957 
958   return(retval);
959 }
960 
961 
PrintBckStats(void * cvode_mem,int idx)962 int PrintBckStats(void *cvode_mem, int idx)
963 {
964   void *cvode_mem_bck;
965 
966   long int nst, nfe, nsetups, nni, ncfn, netf;
967   long int nfQe, netfQ;
968 
969   int qlast, qcur;
970   realtype h0u, hlast, hcur, tcur;
971 
972   int retval;
973 
974   cvode_mem_bck = CVodeGetAdjCVodeBmem(cvode_mem, idx);
975 
976   retval = CVodeGetIntegratorStats(cvode_mem_bck, &nst, &nfe, &nsetups, &netf,
977                                  &qlast, &qcur,
978                                  &h0u, &hlast, &hcur,
979                                  &tcur);
980 
981   retval = CVodeGetNonlinSolvStats(cvode_mem_bck, &nni, &ncfn);
982 
983   retval = CVodeGetQuadStats(cvode_mem_bck, &nfQe, &netfQ);
984 
985   printf(" Number steps: %5ld\n\n", nst);
986   printf(" Function evaluations:\n");
987   printf("  f:        %5ld\n  fQ:       %5ld\n", nfe, nfQe);
988   printf(" Error test failures:\n");
989   printf("  netf:     %5ld\n  netfQ:    %5ld\n", netf, netfQ);
990   printf(" Linear solver setups:\n");
991   printf("  nsetups:  %5ld\n", nsetups);
992   printf(" Nonlinear iterations:\n");
993   printf("  nni:      %5ld\n", nni);
994   printf(" Convergence failures:\n");
995   printf("  ncfn:     %5ld\n", ncfn);
996 
997   printf("\n");
998 
999   return(retval);
1000 }
1001 
1002 /*
1003  * Check function return value...
1004  *   opt == 0 means SUNDIALS function allocates memory so check if
1005  *            returned NULL pointer
1006  *   opt == 1 means SUNDIALS function returns an integer value so check if
1007  *            retval < 0
1008  *   opt == 2 means function allocates memory so check if returned
1009  *            NULL pointer
1010  */
1011 
/*
 * Inspect the result of a call and report a failure on stderr.
 *
 *   returnvalue  for opt 0 and 2: the pointer returned by the call;
 *                for opt 1: the address of the int status it returned
 *   funcname     name of the called function, used in the message
 *   opt          0 - SUNDIALS allocation: fail if returnvalue is NULL
 *                1 - integer status: fail if *returnvalue < 0
 *                2 - plain allocation: fail if returnvalue is NULL
 *
 * Returns 1 if a failure was detected (after printing a message) and 0
 * otherwise.  Any other opt value performs no check and returns 0.
 * With opt 1 a NULL returnvalue is itself reported as a failure instead
 * of being dereferenced.
 */
static int check_retval(void *returnvalue, const char *funcname, int opt)
{
  int *retval;

  /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
  if (opt == 0 && returnvalue == NULL) {
    fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);
  }

  /* Check if retval < 0 */
  else if (opt == 1) {
    if (returnvalue == NULL) {
      /* No status to inspect: treat as failure rather than dereference NULL */
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - no return value available\n\n",
              funcname);
      return(1);
    }
    retval = (int *) returnvalue;
    if (*retval < 0) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
              funcname, *retval);
      return(1);
    }
  }

  /* Check if function returned NULL pointer - no memory allocated */
  else if (opt == 2 && returnvalue == NULL) {
    fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);
  }

  return(0);
}
1038 
1039