1 /* -----------------------------------------------------------------
2 * Programmer(s): Radu Serban @ LLNL
3 * -----------------------------------------------------------------
4 * SUNDIALS Copyright Start
5 * Copyright (c) 2002-2021, Lawrence Livermore National Security
6 * and Southern Methodist University.
7 * All rights reserved.
8 *
9 * See the top-level LICENSE and NOTICE files for details.
10 *
11 * SPDX-License-Identifier: BSD-3-Clause
12 * SUNDIALS Copyright End
13 * -----------------------------------------------------------------
14 *
15 * Hessian through adjoint sensitivity example problem.
16 *
 *   $$ y' = \begin{bmatrix} -p_1 y_1^2 - y_3 \\ -y_2 \\ -p_2^2 y_2 y_3 \end{bmatrix},
 *      \qquad
 *      y(0) = \begin{bmatrix} 1 \\ 1 \\ 1 \end{bmatrix} $$
 *
 *   p1 = 1.0
 *   p2 = 2.0
 *
 *   $$ G(p) = \int_0^2 \frac{1}{2} \left( y_1^2 + y_2^2 + y_3^2 \right) \, dt $$
29 *
30 * Compute the gradient (ASA) and Hessian (FSA over ASA) of G(p).
31 *
32 * See D.B. Ozyurt and P.I. Barton, SISC 26(5) 1725-1743, 2005.
33 *
34 * -----------------------------------------------------------------*/
35
36 #include <stdio.h>
37 #include <stdlib.h>
38
39 #include <cvodes/cvodes.h> /* prototypes for CVODES fcts., consts. */
40 #include <nvector/nvector_serial.h> /* access to serial N_Vector */
41 #include <sunmatrix/sunmatrix_dense.h> /* access to band SUNMatrix */
42 #include <sunlinsol/sunlinsol_dense.h> /* access to band SUNLinearSolver */
43 #include <sundials/sundials_math.h> /* definition of SUNRabs, SUNRexp, etc. */
44
45 #define Ith(v,i) NV_Ith_S(v,i-1)
46
47 #define ZERO RCONST(0.0)
48 #define ONE RCONST(1.0)
49
50 typedef struct {
51 realtype p1, p2;
52 } *UserData;
53
54 static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data);
55 static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data);
56 static int fS(int Ns, realtype t,
57 N_Vector y, N_Vector ydot,
58 N_Vector *yS, N_Vector *ySdot,
59 void *user_data,
60 N_Vector tmp1, N_Vector tmp2);
61 static int fQS(int Ns, realtype t,
62 N_Vector y, N_Vector *yS,
63 N_Vector yQdot, N_Vector *yQSdot,
64 void *user_data,
65 N_Vector tmp, N_Vector tmpQ);
66
67 static int fB1(realtype t, N_Vector y, N_Vector *yS,
68 N_Vector yB, N_Vector yBdot, void *user_dataB);
69 static int fQB1(realtype t, N_Vector y, N_Vector *yS,
70 N_Vector yB, N_Vector qBdot, void *user_dataB);
71
72
73 static int fB2(realtype t, N_Vector y, N_Vector *yS,
74 N_Vector yB, N_Vector yBdot, void *user_dataB);
75 static int fQB2(realtype t, N_Vector y, N_Vector *yS,
76 N_Vector yB, N_Vector qBdot, void *user_dataB);
77
78 int PrintFwdStats(void *cvode_mem);
79 int PrintBckStats(void *cvode_mem, int idx);
80
81 /* Private function to check function return values */
82
83 static int check_retval(void *returnvalue, const char *funcname, int opt);
84
85 /*
86 *--------------------------------------------------------------------
87 * MAIN PROGRAM
88 *--------------------------------------------------------------------
89 */
90
main(int argc,char * argv[])91 int main(int argc, char *argv[])
92 {
93 UserData data;
94
95 SUNMatrix A, AB1, AB2;
96 SUNLinearSolver LS, LSB1, LSB2;
97 void *cvode_mem;
98
99 sunindextype Neq, Np2;
100 int Np;
101
102 realtype t0, tf;
103
104 realtype reltol;
105 realtype abstol, abstolQ, abstolB, abstolQB;
106
107 N_Vector y, yQ;
108 N_Vector *yS, *yQS;
109 N_Vector yB1, yB2;
110 N_Vector yQB1, yQB2;
111
112 int steps, ncheck;
113 int indexB1, indexB2;
114
115 int retval;
116 realtype time;
117
118 realtype dp;
119 realtype G, Gp, Gm;
120 realtype grdG_fwd[2], grdG_bck[2], grdG_cntr[2];
121 realtype H11, H22;
122
123 data = NULL;
124 y = yQ = NULL;
125 yB1 = yB2 = NULL;
126 yQB1 = yQB2 = NULL;
127 A = AB1 = AB2 = NULL;
128 LS = LSB1 = LSB2 = NULL;
129 cvode_mem = NULL;
130
131 /* User data structure */
132
133 data = (UserData) malloc(sizeof *data);
134 data->p1 = RCONST(1.0);
135 data->p2 = RCONST(2.0);
136
137 /* Problem size, integration interval, and tolerances */
138
139 Neq = 3;
140 Np = 2;
141 Np2 = 2*Np;
142
143 t0 = 0.0;
144 tf = 2.0;
145
146 reltol = 1.0e-8;
147
148 abstol = 1.0e-8;
149 abstolQ = 1.0e-8;
150
151 abstolB = 1.0e-8;
152 abstolQB = 1.0e-8;
153
154 /* Initializations for forward problem */
155
156 y = N_VNew_Serial(Neq);
157 if (check_retval((void *)y, "N_VNew_Serial", 0)) return(1);
158 N_VConst(ONE, y);
159
160 yQ = N_VNew_Serial(1);
161 if (check_retval((void *)yQ, "N_VNew_Serial", 0)) return(1);
162 N_VConst(ZERO, yQ);
163
164 yS = N_VCloneVectorArray(Np, y);
165 if (check_retval((void *)yS, "N_VCloneVectorArray", 0)) return(1);
166 N_VConst(ZERO, yS[0]);
167 N_VConst(ZERO, yS[1]);
168
169 yQS = N_VCloneVectorArray(Np, yQ);
170 if (check_retval((void *)yQS, "N_VCloneVectorArray", 0)) return(1);
171 N_VConst(ZERO, yQS[0]);
172 N_VConst(ZERO, yQS[1]);
173
174 /* Create and initialize forward problem */
175
176 cvode_mem = CVodeCreate(CV_BDF);
177 if(check_retval((void *)cvode_mem, "CVodeCreate", 0)) return(1);
178
179 retval = CVodeInit(cvode_mem, f, t0, y);
180 if(check_retval(&retval, "CVodeInit", 1)) return(1);
181
182 retval = CVodeSStolerances(cvode_mem, reltol, abstol);
183 if(check_retval(&retval, "CVodeSStolerances", 1)) return(1);
184
185 retval = CVodeSetUserData(cvode_mem, data);
186 if(check_retval(&retval, "CVodeSetUserData", 1)) return(1);
187
188 /* Create a dense SUNMatrix */
189 A = SUNDenseMatrix(Neq, Neq);
190 if(check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
191
192 /* Create banded SUNLinearSolver for the forward problem */
193 LS = SUNLinSol_Dense(y, A);
194 if(check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
195
196 /* Attach the matrix and linear solver */
197 retval = CVodeSetLinearSolver(cvode_mem, LS, A);
198 if(check_retval(&retval, "CVodeSetLinearSolver", 1)) return(1);
199
200 retval = CVodeQuadInit(cvode_mem, fQ, yQ);
201 if(check_retval(&retval, "CVodeQuadInit", 1)) return(1);
202
203 retval = CVodeQuadSStolerances(cvode_mem, reltol, abstolQ);
204 if(check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
205
206 retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
207 if(check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
208
209 retval = CVodeSensInit(cvode_mem, Np, CV_SIMULTANEOUS, fS, yS);
210 if(check_retval(&retval, "CVodeSensInit", 1)) return(1);
211
212 retval = CVodeSensEEtolerances(cvode_mem);
213 if(check_retval(&retval, "CVodeSensEEtolerances", 1)) return(1);
214
215 retval = CVodeSetSensErrCon(cvode_mem, SUNTRUE);
216 if(check_retval(&retval, "CVodeSetSensErrCon", 1)) return(1);
217
218 retval = CVodeQuadSensInit(cvode_mem, fQS, yQS);
219 if(check_retval(&retval, "CVodeQuadSensInit", 1)) return(1);
220
221 retval = CVodeQuadSensEEtolerances(cvode_mem);
222 if(check_retval(&retval, "CVodeQuadSensEEtolerances", 1)) return(1);
223
224 retval = CVodeSetQuadSensErrCon(cvode_mem, SUNTRUE);
225 if(check_retval(&retval, "CVodeSetQuadSensErrCon", 1)) return(1);
226
227 /* Initialize ASA */
228
229 steps = 100;
230 retval = CVodeAdjInit(cvode_mem, steps, CV_POLYNOMIAL);
231 if(check_retval(&retval, "CVodeAdjInit", 1)) return(1);
232
233 /* Forward integration */
234
235 printf("-------------------\n");
236 printf("Forward integration\n");
237 printf("-------------------\n\n");
238
239 retval = CVodeF(cvode_mem, tf, y, &time, CV_NORMAL, &ncheck);
240 if(check_retval(&retval, "CVodeF", 1)) return(1);
241
242 retval = CVodeGetQuad(cvode_mem, &time, yQ);
243 if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
244
245 G = Ith(yQ,1);
246
247 retval = CVodeGetSens(cvode_mem, &time, yS);
248 if(check_retval(&retval, "CVodeGetSens", 1)) return(1);
249
250 retval = CVodeGetQuadSens(cvode_mem, &time, yQS);
251 if(check_retval(&retval, "CVodeGetQuadSens", 1)) return(1);
252
253 printf("ncheck = %d\n", ncheck);
254 printf("\n");
255 #if defined(SUNDIALS_EXTENDED_PRECISION)
256 printf(" y: %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
257 printf(" G: %12.4Le\n", Ith(yQ,1));
258 printf("\n");
259 printf(" yS1: %12.4Le %12.4Le %12.4Le\n", Ith(yS[0],1), Ith(yS[0],2), Ith(yS[0],3));
260 printf(" yS2: %12.4Le %12.4Le %12.4Le\n", Ith(yS[1],1), Ith(yS[1],2), Ith(yS[1],3));
261 printf("\n");
262 printf(" dG/dp: %12.4Le %12.4Le\n", Ith(yQS[0],1), Ith(yQS[1],1));
263 #else
264 printf(" y: %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
265 printf(" G: %12.4e\n", Ith(yQ,1));
266 printf("\n");
267 printf(" yS1: %12.4e %12.4e %12.4e\n", Ith(yS[0],1), Ith(yS[0],2), Ith(yS[0],3));
268 printf(" yS2: %12.4e %12.4e %12.4e\n", Ith(yS[1],1), Ith(yS[1],2), Ith(yS[1],3));
269 printf("\n");
270 printf(" dG/dp: %12.4e %12.4e\n", Ith(yQS[0],1), Ith(yQS[1],1));
271 #endif
272 printf("\n");
273
274 printf("Final Statistics for forward pb.\n");
275 printf("--------------------------------\n");
276 retval = PrintFwdStats(cvode_mem);
277 if (check_retval(&retval, "PrintFwdStats", 1)) return(1);
278
279 /* Initializations for backward problems */
280
281 yB1 = N_VNew_Serial(2*Neq);
282 if (check_retval((void *)yB1, "N_VNew_Serial", 0)) return(1);
283 N_VConst(ZERO, yB1);
284
285 yQB1 = N_VNew_Serial(Np2);
286 if (check_retval((void *)yQB1, "N_VNew_Serial", 0)) return(1);
287 N_VConst(ZERO, yQB1);
288
289 yB2 = N_VNew_Serial(2*Neq);
290 if (check_retval((void *)yB2, "N_VNew_Serial", 0)) return(1);
291 N_VConst(ZERO, yB2);
292
293 yQB2 = N_VNew_Serial(Np2);
294 if (check_retval((void *)yQB2, "N_VNew_Serial", 0)) return(1);
295 N_VConst(ZERO, yQB2);
296
297 /* Create and initialize backward problems (one for each column of the Hessian) */
298
299 /* -------------------------
300 First backward problem
301 -------------------------*/
302
303 retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB1);
304 if(check_retval(&retval, "CVodeCreateB", 1)) return(1);
305
306 retval = CVodeInitBS(cvode_mem, indexB1, fB1, tf, yB1);
307 if(check_retval(&retval, "CVodeInitBS", 1)) return(1);
308
309 retval = CVodeSStolerancesB(cvode_mem, indexB1, reltol, abstolB);
310 if(check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
311
312 retval = CVodeSetUserDataB(cvode_mem, indexB1, data);
313 if(check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
314
315 retval = CVodeQuadInitBS(cvode_mem, indexB1, fQB1, yQB1);
316 if(check_retval(&retval, "CVodeQuadInitBS", 1)) return(1);
317
318 retval = CVodeQuadSStolerancesB(cvode_mem, indexB1, reltol, abstolQB);
319 if(check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
320
321 retval = CVodeSetQuadErrConB(cvode_mem, indexB1, SUNTRUE);
322 if(check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
323
324 /* Create a dense SUNMatrix */
325 AB1 = SUNDenseMatrix(2*Neq, 2*Neq);
326 if(check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
327
328 /* Create dense SUNLinearSolver for the forward problem */
329 LSB1 = SUNLinSol_Dense(yB1, AB1);
330 if(check_retval((void *)LSB1, "SUNLinSol_Dense", 0)) return(1);
331
332 /* Attach the matrix and linear solver */
333 retval = CVodeSetLinearSolverB(cvode_mem, indexB1, LSB1, AB1);
334 if(check_retval(&retval, "CVodeSetLinearSolverB", 1)) return(1);
335
336 /* -------------------------
337 Second backward problem
338 -------------------------*/
339
340 retval = CVodeCreateB(cvode_mem, CV_BDF, &indexB2);
341 if(check_retval(&retval, "CVodeCreateB", 1)) return(1);
342
343 retval = CVodeInitBS(cvode_mem, indexB2, fB2, tf, yB2);
344 if(check_retval(&retval, "CVodeInitBS", 1)) return(1);
345
346 retval = CVodeSStolerancesB(cvode_mem, indexB2, reltol, abstolB);
347 if(check_retval(&retval, "CVodeSStolerancesB", 1)) return(1);
348
349 retval = CVodeSetUserDataB(cvode_mem, indexB2, data);
350 if(check_retval(&retval, "CVodeSetUserDataB", 1)) return(1);
351
352 retval = CVodeQuadInitBS(cvode_mem, indexB2, fQB2, yQB2);
353 if(check_retval(&retval, "CVodeQuadInitBS", 1)) return(1);
354
355 retval = CVodeQuadSStolerancesB(cvode_mem, indexB2, reltol, abstolQB);
356 if(check_retval(&retval, "CVodeQuadSStolerancesB", 1)) return(1);
357
358 retval = CVodeSetQuadErrConB(cvode_mem, indexB2, SUNTRUE);
359 if(check_retval(&retval, "CVodeSetQuadErrConB", 1)) return(1);
360
361 /* Create a dense SUNMatrix */
362 AB2 = SUNDenseMatrix(2*Neq, 2*Neq);
363 if(check_retval((void *)AB2, "SUNDenseMatrix", 0)) return(1);
364
365 /* Create dense SUNLinearSolver for the forward problem */
366 LSB2 = SUNLinSol_Dense(yB2, AB2);
367 if(check_retval((void *)LSB2, "SUNLinSol_Dense", 0)) return(1);
368
369 /* Attach the matrix and linear solver */
370 retval = CVodeSetLinearSolverB(cvode_mem, indexB2, LSB2, AB2);
371 if(check_retval(&retval, "CVodeSetLinearSolverB", 1)) return(1);
372
373 /* Backward integration */
374
375 printf("---------------------------------------------\n");
376 printf("Backward integration ... (2 adjoint problems)\n");
377 printf("---------------------------------------------\n\n");
378
379 retval = CVodeB(cvode_mem, t0, CV_NORMAL);
380 if(check_retval(&retval, "CVodeB", 1)) return(1);
381
382 retval = CVodeGetB(cvode_mem, indexB1, &time, yB1);
383 if(check_retval(&retval, "CVodeGetB", 1)) return(1);
384
385 retval = CVodeGetQuadB(cvode_mem, indexB1, &time, yQB1);
386 if(check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
387
388 retval = CVodeGetB(cvode_mem, indexB2, &time, yB2);
389 if(check_retval(&retval, "CVodeGetB", 1)) return(1);
390
391 retval = CVodeGetQuadB(cvode_mem, indexB2, &time, yQB2);
392 if(check_retval(&retval, "CVodeGetQuadB", 1)) return(1);
393
394 #if defined(SUNDIALS_EXTENDED_PRECISION)
395 printf(" dG/dp: %12.4Le %12.4Le (from backward pb. 1)\n", -Ith(yQB1,1), -Ith(yQB1,2));
396 printf(" %12.4Le %12.4Le (from backward pb. 2)\n", -Ith(yQB2,1), -Ith(yQB2,2));
397 #else
398 printf(" dG/dp: %12.4e %12.4e (from backward pb. 1)\n", -Ith(yQB1,1), -Ith(yQB1,2));
399 printf(" %12.4e %12.4e (from backward pb. 2)\n", -Ith(yQB2,1), -Ith(yQB2,2));
400 #endif
401 printf("\n");
402 printf(" H = d2G/dp2:\n");
403 printf(" (1) (2)\n");
404 #if defined(SUNDIALS_EXTENDED_PRECISION)
405 printf(" %12.4Le %12.4Le\n", -Ith(yQB1,3) , -Ith(yQB2,3));
406 printf(" %12.4Le %12.4Le\n", -Ith(yQB1,4) , -Ith(yQB2,4));
407 #else
408 printf(" %12.4e %12.4e\n", -Ith(yQB1,3) , -Ith(yQB2,3));
409 printf(" %12.4e %12.4e\n", -Ith(yQB1,4) , -Ith(yQB2,4));
410 #endif
411 printf("\n");
412
413 printf("Final Statistics for backward pb. 1\n");
414 printf("-----------------------------------\n");
415 retval = PrintBckStats(cvode_mem, indexB1);
416 if (check_retval(&retval, "PrintBckStats", 1)) return(1);
417
418 printf("Final Statistics for backward pb. 2\n");
419 printf("-----------------------------------\n");
420 retval = PrintBckStats(cvode_mem, indexB2);
421 if (check_retval(&retval, "PrintBckStats", 1)) return(1);
422
423 /* Free memory */
424
425 CVodeFree(&cvode_mem);
426 SUNLinSolFree(LS);
427 SUNMatDestroy(A);
428 SUNLinSolFree(LSB1);
429 SUNMatDestroy(AB1);
430 SUNLinSolFree(LSB2);
431 SUNMatDestroy(AB2);
432
433 /* Finite difference tests */
434
435 dp = RCONST(1.0e-2);
436
437 printf("-----------------------\n");
438 printf("Finite Difference tests\n");
439 printf("-----------------------\n\n");
440
441 #if defined(SUNDIALS_EXTENDED_PRECISION)
442 printf("del_p = %Lg\n\n",dp);
443 #else
444 printf("del_p = %g\n\n",dp);
445 #endif
446
447 cvode_mem = CVodeCreate(CV_BDF);
448
449 N_VConst(ONE, y);
450 N_VConst(ZERO, yQ);
451
452 retval = CVodeInit(cvode_mem, f, t0, y);
453 if(check_retval(&retval, "CVodeInit", 1)) return(1);
454
455 retval = CVodeSStolerances(cvode_mem, reltol, abstol);
456 if(check_retval(&retval, "CVodeSStolerances", 1)) return(1);
457
458 retval = CVodeSetUserData(cvode_mem, data);
459 if(check_retval(&retval, "CVodeSetUserData", 1)) return(1);
460
461 /* Create a dense SUNMatrix */
462 A = SUNDenseMatrix(Neq, Neq);
463 if(check_retval((void *)A, "SUNDenseMatrix", 0)) return(1);
464
465 /* Create dense SUNLinearSolver for the forward problem */
466 LS = SUNLinSol_Dense(y, A);
467 if(check_retval((void *)LS, "SUNLinSol_Dense", 0)) return(1);
468
469 /* Attach the matrix and linear solver */
470 retval = CVodeSetLinearSolver(cvode_mem, LS, A);
471 if(check_retval(&retval, "CVodeSetLinearSolver", 1)) return(1);
472
473 retval = CVodeQuadInit(cvode_mem, fQ, yQ);
474 if(check_retval(&retval, "CVodeQuadInit", 1)) return(1);
475
476 retval = CVodeQuadSStolerances(cvode_mem, reltol, abstolQ);
477 if(check_retval(&retval, "CVodeQuadSStolerances", 1)) return(1);
478
479 retval = CVodeSetQuadErrCon(cvode_mem, SUNTRUE);
480 if(check_retval(&retval, "CVodeSetQuadErrCon", 1)) return(1);
481
482 data->p1 += dp;
483
484 retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
485 if(check_retval(&retval, "CVode", 1)) return(1);
486
487 retval = CVodeGetQuad(cvode_mem, &time, yQ);
488 if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
489
490 Gp = Ith(yQ,1);
491
492 #if defined(SUNDIALS_EXTENDED_PRECISION)
493 printf("p1+ y: %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
494 printf(" G: %12.4Le\n",Ith(yQ,1));
495 #else
496 printf("p1+ y: %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
497 printf(" G: %12.4e\n",Ith(yQ,1));
498 #endif
499 data->p1 -= 2.0*dp;
500
501 N_VConst(ONE, y);
502 N_VConst(ZERO, yQ);
503
504 CVodeReInit(cvode_mem, t0, y);
505 CVodeQuadReInit(cvode_mem, yQ);
506
507 retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
508 if(check_retval(&retval, "CVode", 1)) return(1);
509
510 retval = CVodeGetQuad(cvode_mem, &time, yQ);
511 if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
512
513 Gm = Ith(yQ,1);
514 #if defined(SUNDIALS_EXTENDED_PRECISION)
515 printf("p1- y: %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
516 printf(" G: %12.4Le\n",Ith(yQ,1));
517 #else
518 printf("p1- y: %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
519 printf(" G: %12.4e\n",Ith(yQ,1));
520 #endif
521 data->p1 += dp;
522
523 grdG_fwd[0] = (Gp-G)/dp;
524 grdG_bck[0] = (G-Gm)/dp;
525 grdG_cntr[0] = (Gp-Gm)/(2.0*dp);
526 H11 = (Gp - 2.0*G + Gm) / (dp*dp);
527
528 data->p2 += dp;
529
530 N_VConst(ONE, y);
531 N_VConst(ZERO, yQ);
532
533 CVodeReInit(cvode_mem, t0, y);
534 CVodeQuadReInit(cvode_mem, yQ);
535
536 retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
537 if(check_retval(&retval, "CVode", 1)) return(1);
538
539 retval = CVodeGetQuad(cvode_mem, &time, yQ);
540 if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
541
542 Gp = Ith(yQ,1);
543 #if defined(SUNDIALS_EXTENDED_PRECISION)
544 printf("p2+ y: %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
545 printf(" G: %12.4Le\n",Ith(yQ,1));
546 #else
547 printf("p2+ y: %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
548 printf(" G: %12.4e\n",Ith(yQ,1));
549 #endif
550 data->p2 -= 2.0*dp;
551
552 N_VConst(ONE, y);
553 N_VConst(ZERO, yQ);
554
555 CVodeReInit(cvode_mem, t0, y);
556 CVodeQuadReInit(cvode_mem, yQ);
557
558 retval = CVode(cvode_mem, tf, y, &time, CV_NORMAL);
559 if(check_retval(&retval, "CVode", 1)) return(1);
560
561 retval = CVodeGetQuad(cvode_mem, &time, yQ);
562 if(check_retval(&retval, "CVodeGetQuad", 1)) return(1);
563
564 Gm = Ith(yQ,1);
565 #if defined(SUNDIALS_EXTENDED_PRECISION)
566 printf("p2- y: %12.4Le %12.4Le %12.4Le", Ith(y,1), Ith(y,2), Ith(y,3));
567 printf(" G: %12.4Le\n",Ith(yQ,1));
568 #else
569 printf("p2- y: %12.4e %12.4e %12.4e", Ith(y,1), Ith(y,2), Ith(y,3));
570 printf(" G: %12.4e\n",Ith(yQ,1));
571 #endif
572 data->p2 += dp;
573
574 grdG_fwd[1] = (Gp-G)/dp;
575 grdG_bck[1] = (G-Gm)/dp;
576 grdG_cntr[1] = (Gp-Gm)/(2.0*dp);
577 H22 = (Gp - 2.0*G + Gm) / (dp*dp);
578
579 printf("\n");
580
581 #if defined(SUNDIALS_EXTENDED_PRECISION)
582 printf(" dG/dp: %12.4Le %12.4Le (fwd FD)\n", grdG_fwd[0], grdG_fwd[1]);
583 printf(" %12.4Le %12.4Le (bck FD)\n", grdG_bck[0], grdG_bck[1]);
584 printf(" %12.4Le %12.4Le (cntr FD)\n", grdG_cntr[0], grdG_cntr[1]);
585 printf("\n");
586 printf(" H(1,1): %12.4Le\n", H11);
587 printf(" H(2,2): %12.4Le\n", H22);
588 #else
589 printf(" dG/dp: %12.4e %12.4e (fwd FD)\n", grdG_fwd[0], grdG_fwd[1]);
590 printf(" %12.4e %12.4e (bck FD)\n", grdG_bck[0], grdG_bck[1]);
591 printf(" %12.4e %12.4e (cntr FD)\n", grdG_cntr[0], grdG_cntr[1]);
592 printf("\n");
593 printf(" H(1,1): %12.4e\n", H11);
594 printf(" H(2,2): %12.4e\n", H22);
595 #endif
596
597 /* Free memory */
598
599 CVodeFree(&cvode_mem);
600 SUNLinSolFree(LS);
601 SUNMatDestroy(A);
602
603 N_VDestroy(y);
604 N_VDestroy(yQ);
605
606 N_VDestroyVectorArray(yS, Np);
607 N_VDestroyVectorArray(yQS, Np);
608
609 N_VDestroy(yB1);
610 N_VDestroy(yQB1);
611 N_VDestroy(yB2);
612 N_VDestroy(yQB2);
613
614 free(data);
615
616 return(0);
617
618 }
619
620 /*
621 *--------------------------------------------------------------------
622 * FUNCTIONS CALLED BY CVODES
623 *--------------------------------------------------------------------
624 */
625
626
/*
 * Forward ODE right-hand side ydot = f(t, y; p1, p2):
 *   ydot1 = -p1*y1^2 - y3
 *   ydot2 = -y2
 *   ydot3 = -p2^2*y2*y3
 * The parameters are read from the UserData passed through user_data.
 * Always returns 0.
 */
static int f(realtype t, N_Vector y, N_Vector ydot, void *user_data)
{
  UserData ud = (UserData) user_data;
  realtype u1 = Ith(y,1);
  realtype u2 = Ith(y,2);
  realtype u3 = Ith(y,3);

  Ith(ydot,1) = -ud->p1*u1*u1 - u3;
  Ith(ydot,2) = -u2;
  Ith(ydot,3) = -ud->p2*ud->p2*u2*u3;

  return(0);
}
647
/*
 * Quadrature integrand for G: qdot = 0.5 * (y1^2 + y2^2 + y3^2).
 * Does not depend on the parameters. Always returns 0.
 */
static int fQ(realtype t, N_Vector y, N_Vector qdot, void *user_data)
{
  realtype sumsq;

  sumsq = Ith(y,1)*Ith(y,1) + Ith(y,2)*Ith(y,2) + Ith(y,3)*Ith(y,3);
  Ith(qdot,1) = 0.5 * sumsq;

  return(0);
}
660
/*
 * Forward sensitivity right-hand sides, computed for both parameters.
 *
 * For sensitivity i (s = yS[i]):  sdot = (df/dy) s + df/dp_i, where
 *   df/dp1 = [ -y1^2, 0, 0 ]^T  and  df/dp2 = [ 0, 0, -2*p2*y2*y3 ]^T.
 * Exactly two sensitivities (yS[0], yS[1]) are filled; Ns is not consulted.
 * Always returns 0.
 */
static int fS(int Ns, realtype t,
              N_Vector y, N_Vector ydot,
              N_Vector *yS, N_Vector *ySdot,
              void *user_data,
              N_Vector tmp1, N_Vector tmp2)
{
  UserData ud = (UserData) user_data;
  realtype p1 = ud->p1;
  realtype p2 = ud->p2;
  realtype u1 = Ith(y,1);
  realtype u2 = Ith(y,2);
  realtype u3 = Ith(y,3);
  int k;

  for (k = 0; k < 2; k++) {
    realtype s1 = Ith(yS[k],1);
    realtype s2 = Ith(yS[k],2);
    realtype s3 = Ith(yS[k],3);

    /* Jacobian-vector product (df/dy) * s */
    realtype js1 = - 2.0*p1*u1 * s1 - s3;
    realtype js2 = - s2;
    realtype js3 = - p2*p2*u3 * s2 - p2*p2*u2 * s3;

    /* explicit parameter derivative df/dp_k */
    if (k == 0) {
      js1 -= u1*u1;
    } else {
      js3 -= 2.0*p2*u2*u3;
    }

    Ith(ySdot[k],1) = js1;
    Ith(ySdot[k],2) = js2;
    Ith(ySdot[k],3) = js3;
  }

  return(0);
}
711
/*
 * Quadrature sensitivity right-hand sides for both parameters.
 *
 * The integrand 0.5*|y|^2 has no explicit parameter dependence, so its
 * derivative along sensitivity s = yS[i] is y . s.
 * Exactly two sensitivities are filled; Ns is not consulted.
 * Always returns 0.
 */
static int fQS(int Ns, realtype t,
               N_Vector y, N_Vector *yS,
               N_Vector yQdot, N_Vector *yQSdot,
               void *user_data,
               N_Vector tmp, N_Vector tmpQ)
{
  realtype u1 = Ith(y,1);
  realtype u2 = Ith(y,2);
  realtype u3 = Ith(y,3);
  int k;

  for (k = 0; k < 2; k++) {
    Ith(yQSdot[k],1) = u1*Ith(yS[k],1) + u2*Ith(yS[k],2) + u3*Ith(yS[k],3);
  }

  return(0);
}
745
/*
 * Right-hand side of backward problem 1.
 *
 * yB = [lambda(1:3); mu(1:3)].
 *  - lambda' = -(df/dy)^T lambda - (d(0.5|y|^2)/dy)^T
 *  - mu' is the derivative of the lambda equations with respect to p1,
 *    using the forward sensitivity s = dy/dp1 (yS[0]).
 * Always returns 0.
 */
static int fB1(realtype t, N_Vector y, N_Vector *yS,
               N_Vector yB, N_Vector yBdot, void *user_dataB)
{
  UserData ud = (UserData) user_dataB;
  realtype p1 = ud->p1;
  realtype p2 = ud->p2;
  realtype u[3], s[3], lam[3], mu[3];
  int i;

  for (i = 0; i < 3; i++) {
    u[i]   = Ith(y, i+1);
    s[i]   = Ith(yS[0], i+1);
    lam[i] = Ith(yB, i+1);
    mu[i]  = Ith(yB, i+4);
  }

  /* lambda equations */
  Ith(yBdot,1) = 2.0*p1*u[0] * lam[0] - u[0];
  Ith(yBdot,2) = lam[1] + p2*p2*u[2] * lam[2] - u[1];
  Ith(yBdot,3) = lam[0] + p2*p2*u[1] * lam[2] - u[2];

  /* mu equations (d/dp1 of the lambda equations) */
  Ith(yBdot,4) = 2.0*p1*u[0] * mu[0] + lam[0] * 2.0*(u[0] + p1*s[0]) - s[0];
  Ith(yBdot,5) = mu[1] + p2*p2*u[2] * mu[2] + lam[2] * p2*p2*s[2] - s[1];
  Ith(yBdot,6) = mu[0] + p2*p2*u[1] * mu[2] + lam[2] * p2*p2*s[1] - s[2];

  return(0);
}
787
/*
 * Quadrature right-hand side of backward problem 1.
 *
 * Entries 1-2: lambda^T df/dp1 and lambda^T df/dp2, where
 *   df/dp1 = [-y1^2, 0, 0]^T and df/dp2 = [0, 0, -2*p2*y2*y3]^T.
 * Entries 3-4: derivatives of entries 1-2 with respect to p1, built from
 *   mu and the sensitivity s = yS[0].
 * Always returns 0.
 */
static int fQB1(realtype t, N_Vector y, N_Vector *yS,
                N_Vector yB, N_Vector qBdot, void *user_dataB)
{
  UserData ud = (UserData) user_dataB;
  realtype p2 = ud->p2;
  realtype u[3], s[3], lam[3], mu[3];
  int i;

  for (i = 0; i < 3; i++) {
    u[i]   = Ith(y, i+1);
    s[i]   = Ith(yS[0], i+1);
    lam[i] = Ith(yB, i+1);
    mu[i]  = Ith(yB, i+4);
  }

  /* gradient contributions */
  Ith(qBdot,1) = -u[0]*u[0] * lam[0];
  Ith(qBdot,2) = -2.0*p2*u[1]*u[2] * lam[2];

  /* Hessian column contributions (d/dp1) */
  Ith(qBdot,3) = -u[0]*u[0] * mu[0] - lam[0] * 2.0*u[0]*s[0];
  Ith(qBdot,4) = -2.0*p2*u[1]*u[2] * mu[2] - lam[2] * 2.0*(p2*u[2]*s[1] + p2*u[1]*s[2]);

  return(0);
}
824
825
826
827
/*
 * Right-hand side of backward problem 2.
 *
 * yB = [lambda(1:3); mu(1:3)].
 *  - The lambda equations are identical to those of backward problem 1.
 *  - mu' is the derivative of the lambda equations with respect to p2,
 *    using the forward sensitivity s = dy/dp2 (yS[1]).
 * Always returns 0.
 */
static int fB2(realtype t, N_Vector y, N_Vector *yS,
               N_Vector yB, N_Vector yBdot, void *user_dataB)
{
  UserData ud = (UserData) user_dataB;
  realtype p1 = ud->p1;
  realtype p2 = ud->p2;
  realtype u[3], s[3], lam[3], mu[3];
  int i;

  for (i = 0; i < 3; i++) {
    u[i]   = Ith(y, i+1);
    s[i]   = Ith(yS[1], i+1);
    lam[i] = Ith(yB, i+1);
    mu[i]  = Ith(yB, i+4);
  }

  /* lambda equations */
  Ith(yBdot,1) = 2.0*p1*u[0] * lam[0] - u[0];
  Ith(yBdot,2) = lam[1] + p2*p2*u[2] * lam[2] - u[1];
  Ith(yBdot,3) = lam[0] + p2*p2*u[1] * lam[2] - u[2];

  /* mu equations (d/dp2 of the lambda equations) */
  Ith(yBdot,4) = 2.0*p1*u[0] * mu[0] + lam[0] * 2.0*p1*s[0] - s[0];
  Ith(yBdot,5) = mu[1] + p2*p2*u[2] * mu[2] + lam[2] * (2.0*p2*u[2] + p2*p2*s[2]) - s[1];
  Ith(yBdot,6) = mu[0] + p2*p2*u[1] * mu[2] + lam[2] * (2.0*p2*u[1] + p2*p2*s[1]) - s[2];

  return(0);
}
869
870
/*
 * Quadrature right-hand side of backward problem 2.
 *
 * Entries 1-2: lambda^T df/dp1 and lambda^T df/dp2 (as in problem 1).
 * Entries 3-4: derivatives of entries 1-2 with respect to p2, built from
 *   mu and the sensitivity s = yS[1]. Entry 4 also carries the explicit
 *   p2-derivative of df/dp2, which is the extra y2*y3 term.
 * Always returns 0.
 */
static int fQB2(realtype t, N_Vector y, N_Vector *yS,
                N_Vector yB, N_Vector qBdot, void *user_dataB)
{
  UserData ud = (UserData) user_dataB;
  realtype p2 = ud->p2;
  realtype u[3], s[3], lam[3], mu[3];
  int i;

  for (i = 0; i < 3; i++) {
    u[i]   = Ith(y, i+1);
    s[i]   = Ith(yS[1], i+1);
    lam[i] = Ith(yB, i+1);
    mu[i]  = Ith(yB, i+4);
  }

  /* gradient contributions */
  Ith(qBdot,1) = -u[0]*u[0] * lam[0];
  Ith(qBdot,2) = -2.0*p2*u[1]*u[2] * lam[2];

  /* Hessian column contributions (d/dp2) */
  Ith(qBdot,3) = -u[0]*u[0] * mu[0] - lam[0] * 2.0*u[0]*s[0];
  Ith(qBdot,4) = -2.0*p2*u[1]*u[2] * mu[2] - lam[2] * 2.0*(p2*u[2]*s[1] + p2*u[1]*s[2] + u[1]*u[2]);

  return(0);
}
907
908
909 /*
910 *--------------------------------------------------------------------
911 * PRIVATE FUNCTIONS
912 *--------------------------------------------------------------------
913 */
914
PrintFwdStats(void * cvode_mem)915 int PrintFwdStats(void *cvode_mem)
916 {
917 long int nst, nfe, nsetups, nni, ncfn, netf;
918 long int nfQe, netfQ;
919 long int nfSe, nfeS, nsetupsS, netfS;
920 long int nfQSe, netfQS;
921
922 int qlast, qcur;
923 realtype h0u, hlast, hcur, tcur;
924
925 int retval;
926
927
928 retval = CVodeGetIntegratorStats(cvode_mem, &nst, &nfe, &nsetups, &netf,
929 &qlast, &qcur,
930 &h0u, &hlast, &hcur,
931 &tcur);
932
933 retval = CVodeGetNonlinSolvStats(cvode_mem, &nni, &ncfn);
934
935 retval = CVodeGetQuadStats(cvode_mem, &nfQe, &netfQ);
936
937 retval = CVodeGetSensStats(cvode_mem, &nfSe, &nfeS, &netfS, &nsetupsS);
938
939 retval = CVodeGetQuadSensStats(cvode_mem, &nfQSe, &netfQS);
940
941
942 printf(" Number steps: %5ld\n\n", nst);
943 printf(" Function evaluations:\n");
944 printf(" f: %5ld\n fQ: %5ld\n fS: %5ld\n fQS: %5ld\n",
945 nfe, nfQe, nfSe, nfQSe);
946 printf(" Error test failures:\n");
947 printf(" netf: %5ld\n netfQ: %5ld\n netfS: %5ld\n netfQS: %5ld\n",
948 netf, netfQ, netfS, netfQS);
949 printf(" Linear solver setups:\n");
950 printf(" nsetups: %5ld\n nsetupsS: %5ld\n", nsetups, nsetupsS);
951 printf(" Nonlinear iterations:\n");
952 printf(" nni: %5ld\n", nni);
953 printf(" Convergence failures:\n");
954 printf(" ncfn: %5ld\n", ncfn);
955
956 printf("\n");
957
958 return(retval);
959 }
960
961
PrintBckStats(void * cvode_mem,int idx)962 int PrintBckStats(void *cvode_mem, int idx)
963 {
964 void *cvode_mem_bck;
965
966 long int nst, nfe, nsetups, nni, ncfn, netf;
967 long int nfQe, netfQ;
968
969 int qlast, qcur;
970 realtype h0u, hlast, hcur, tcur;
971
972 int retval;
973
974 cvode_mem_bck = CVodeGetAdjCVodeBmem(cvode_mem, idx);
975
976 retval = CVodeGetIntegratorStats(cvode_mem_bck, &nst, &nfe, &nsetups, &netf,
977 &qlast, &qcur,
978 &h0u, &hlast, &hcur,
979 &tcur);
980
981 retval = CVodeGetNonlinSolvStats(cvode_mem_bck, &nni, &ncfn);
982
983 retval = CVodeGetQuadStats(cvode_mem_bck, &nfQe, &netfQ);
984
985 printf(" Number steps: %5ld\n\n", nst);
986 printf(" Function evaluations:\n");
987 printf(" f: %5ld\n fQ: %5ld\n", nfe, nfQe);
988 printf(" Error test failures:\n");
989 printf(" netf: %5ld\n netfQ: %5ld\n", netf, netfQ);
990 printf(" Linear solver setups:\n");
991 printf(" nsetups: %5ld\n", nsetups);
992 printf(" Nonlinear iterations:\n");
993 printf(" nni: %5ld\n", nni);
994 printf(" Convergence failures:\n");
995 printf(" ncfn: %5ld\n", ncfn);
996
997 printf("\n");
998
999 return(retval);
1000 }
1001
1002 /*
1003 * Check function return value...
1004 * opt == 0 means SUNDIALS function allocates memory so check if
1005 * returned NULL pointer
1006 * opt == 1 means SUNDIALS function returns an integer value so check if
1007 * retval < 0
1008 * opt == 2 means function allocates memory so check if returned
1009 * NULL pointer
1010 */
1011
/*
 * Report a failed call and return 1; otherwise return 0.
 *
 *   opt == 0 : returnvalue is a pointer from a SUNDIALS allocator;
 *              NULL signals failure.
 *   opt == 1 : returnvalue points to an int flag; a negative value
 *              signals failure.
 *   opt == 2 : returnvalue is a pointer from a generic allocator;
 *              NULL signals failure.
 * Any other opt value is treated as success.
 */
static int check_retval(void *returnvalue, const char *funcname, int opt)
{
  switch (opt) {
  case 0:
    if (returnvalue != NULL) break;
    fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);

  case 1: {
    int code = *(int *) returnvalue;
    if (code >= 0) break;
    fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
            funcname, code);
    return(1);
  }

  case 2:
    if (returnvalue != NULL) break;
    fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
            funcname);
    return(1);

  default:
    break;
  }

  return(0);
}
1038
1039