1 /* -----------------------------------------------------------------
2 * Programmer(s): Ting Yan @ SMU
3 * Based on idasRoberts_ASAi_dns.c and modified to use SuperLUMT
4 * -----------------------------------------------------------------
5 * SUNDIALS Copyright Start
6 * Copyright (c) 2002-2021, Lawrence Livermore National Security
7 * and Southern Methodist University.
8 * All rights reserved.
9 *
10 * See the top-level LICENSE and NOTICE files for details.
11 *
12 * SPDX-License-Identifier: BSD-3-Clause
13 * SUNDIALS Copyright End
14 * -----------------------------------------------------------------
15 * Adjoint sensitivity example problem.
16 *
17 * This simple example problem for IDAS, due to Robertson,
18 * is from chemical kinetics, and consists of the following three
19 * equations:
20 *
21 * dy1/dt + p1*y1 - p2*y2*y3 = 0
22 * dy2/dt - p1*y1 + p2*y2*y3 + p3*y2**2 = 0
23 * y1 + y2 + y3 - 1 = 0
24 *
25 * on the interval from t = 0.0 to t = 4.e10, with initial
26 * conditions: y1 = 1, y2 = y3 = 0. The reaction rates are: p1=0.04,
27 * p2=1e4, and p3=3e7.
28 *
29 * It uses a scalar relative tolerance and a vector absolute
30 * tolerance.
31 *
32 * IDAS can also compute sensitivities with respect to
33 * the problem parameters p1, p2, and p3 of the following quantity:
34 * G = int_t0^t1 g(t,p,y) dt
35 * where
36 * g(t,p,y) = y3
37 *
38 * The gradient dG/dp is obtained as:
39 * dG/dp = int_t0^t1 (g_p - lambda^T F_p ) dt -
40 * lambda^T*F_y'*y_p | _t0^t1
41 * = int_t0^t1 (lambda^T*F_p) dt
42 * where lambda is the solution of the adjoint system:
43 * d(lambda^T * F_y' )/dt -lambda^T F_y = -g_y
44 *
45 * During the backward integration, IDAS also evaluates G as
46 * G = - phi(t0)
47 * where
48 * d(phi)/dt = g(t,y,p)
49 * phi(t1) = 0
50 * -----------------------------------------------------------------*/
51
52 #include <stdio.h>
53 #include <stdlib.h>
54
55 #include <idas/idas.h> /* prototypes for IDA fcts., consts. */
56 #include <nvector/nvector_serial.h> /* access to serial N_Vector */
57 #include <sunmatrix/sunmatrix_sparse.h> /* access to sparse SUNMatrix */
58 #include <sunlinsol/sunlinsol_superlumt.h> /* access to SuperLUMT linear solver */
59 #include <sundials/sundials_types.h> /* defs. of realtype, sunindextype */
60 #include <sundials/sundials_math.h> /* defs. of SUNRabs, SUNRexp, etc. */
61
/* Accessor macros */

/* Ith(v,i) references component i of the serial vector v using 1-based
 * indexing (i = 1..NEQ). The index argument is parenthesized so that a
 * compound expression (e.g. a conditional) is shifted as a whole rather
 * than having "-1" bind to only part of it. */
#define Ith(v,i) NV_Ith_S(v,(i)-1)

/* Problem Constants */

#define NEQ 3 /* number of equations */

#define RTOL RCONST(1e-06) /* scalar relative tolerance */

#define ATOL1 RCONST(1e-08) /* vector absolute tolerance components */
#define ATOL2 RCONST(1e-12)
#define ATOL3 RCONST(1e-08)

#define ATOLA RCONST(1e-08) /* absolute tolerance for adjoint vars. */
#define ATOLQ RCONST(1e-06) /* absolute tolerance for quadratures */

#define T0 RCONST(0.0) /* initial time */
#define TOUT RCONST(4e10) /* final time */

#define TB1 RCONST(50.0) /* starting point for the second adjoint run */
#define TB2 TOUT /* starting point for the first adjoint run */

#define T1B RCONST(49.0) /* for IDACalcICB */

#define STEPS 100 /* number of steps between check points */

#define NP 3 /* number of problem parameters */

#define ONE RCONST(1.0)
#define ZERO RCONST(0.0)
93
94
95 /* Type : UserData */
96
97 typedef struct {
98 realtype p[3];
99 } *UserData;
100
101 /* Prototypes of user-supplied functions */
102
103 static int res(realtype t, N_Vector yy, N_Vector yp,
104 N_Vector resval, void *user_data);
105 static int Jac(realtype t, realtype cj,
106 N_Vector yy, N_Vector yp, N_Vector resvec,
107 SUNMatrix JJ, void *user_data,
108 N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
109
110 static int rhsQ(realtype t, N_Vector yy, N_Vector yp, N_Vector qdot, void *user_data);
111 static int ewt(N_Vector y, N_Vector w, void *user_data);
112
113 static int resB(realtype tt,
114 N_Vector yy, N_Vector yp,
115 N_Vector yyB, N_Vector ypB, N_Vector rrB,
116 void *user_dataB);
117
118 static int JacB(realtype tt, realtype cjB,
119 N_Vector yy, N_Vector yp,
120 N_Vector yyB, N_Vector ypB, N_Vector rrB,
121 SUNMatrix JB, void *user_data,
122 N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);
123
124
125 static int rhsQB(realtype tt,
126 N_Vector yy, N_Vector yp,
127 N_Vector yyB, N_Vector ypB,
128 N_Vector rrQB, void *user_dataB);
129
130 /* Prototypes of private functions */
131 static void PrintOutput(realtype tfinal, N_Vector yB, N_Vector ypB, N_Vector qB);
132 static int check_retval(void *returnvalue, char *funcname, int opt);
133
134 /*
135 *--------------------------------------------------------------------
136 * MAIN PROGRAM
137 *--------------------------------------------------------------------
138 */
139
main(int argc,char * argv[])140 int main(int argc, char *argv[])
141 {
142 UserData data;
143
144 void *ida_mem;
145 SUNMatrix A, AB;
146 SUNLinearSolver LS, LSB;
147
148 realtype reltolQ, abstolQ;
149 N_Vector yy, yp, q;
150 N_Vector yyTB1, ypTB1;
151 N_Vector id;
152
153 int steps;
154
155 int indexB;
156
157 realtype reltolB, abstolB, abstolQB;
158 N_Vector yB, ypB, qB;
159 realtype time;
160 int retval, nthreads, nnz, ncheck;
161
162 IDAadjCheckPointRec *ckpnt;
163
164 long int nst, nstB;
165
166 data = NULL;
167 ckpnt = NULL;
168 ida_mem = NULL;
169 yy = yp = yB = qB = NULL;
170 A = AB = NULL;
171 LS = LSB = NULL;
172
173 /* Print problem description */
174 printf("\nAdjoint Sensitivity Example for Chemical Kinetics\n");
175 printf("-------------------------------------------------\n\n");
176 printf("DAE: dy1/dt + p1*y1 - p2*y2*y3 = 0\n");
177 printf(" dy2/dt - p1*y1 + p2*y2*y3 + p3*(y2)^2 = 0\n");
178 printf(" y1 + y2 + y3 = 0\n\n");
179 printf("Find dG/dp for\n");
180 printf(" G = int_t0^tB0 g(t,p,y) dt\n");
181 printf(" g(t,p,y) = y3\n\n\n");
182
183 /* User data structure */
184 data = (UserData) malloc(sizeof *data);
185 if (check_retval((void *)data, "malloc", 2)) return(1);
186 data->p[0] = RCONST(0.04);
187 data->p[1] = RCONST(1.0e4);
188 data->p[2] = RCONST(3.0e7);
189
190 /* Initialize y */
191 yy = N_VNew_Serial(NEQ);
192 if (check_retval((void *)yy, "N_VNew_Serial", 0)) return(1);
193 Ith(yy,1) = ONE;
194 Ith(yy,2) = ZERO;
195 Ith(yy,3) = ZERO;
196
197 /* Initialize yprime */
198 yp = N_VNew_Serial(NEQ);
199 if (check_retval((void *)yp, "N_VNew_Serial", 0)) return(1);
200 Ith(yp,1) = RCONST(-0.04);
201 Ith(yp,2) = RCONST( 0.04);
202 Ith(yp,3) = ZERO;
203
204 /* Initialize q */
205 q = N_VNew_Serial(1);
206 if (check_retval((void *)q, "N_VNew_Serial", 0)) return(1);
207 Ith(q,1) = ZERO;
208
209 /* Set the scalar realtive and absolute tolerances reltolQ and abstolQ */
210 reltolQ = RTOL;
211 abstolQ = ATOLQ;
212
213 /* Create and allocate IDAS memory for forward run */
214 printf("Create and allocate IDAS memory for forward runs\n");
215
216 ida_mem = IDACreate();
217 if (check_retval((void *)ida_mem, "IDACreate", 0)) return(1);
218
219 retval = IDAInit(ida_mem, res, T0, yy, yp);
220 if (check_retval(&retval, "IDAInit", 1)) return(1);
221
222 retval = IDAWFtolerances(ida_mem, ewt);
223 if (check_retval(&retval, "IDAWFtolerances", 1)) return(1);
224
225 retval = IDASetUserData(ida_mem, data);
226 if (check_retval(&retval, "IDASetUserData", 1)) return(1);
227
228 /* Create sparse SUNMatrix for use in linear solves */
229 nnz = NEQ * NEQ;
230 A = SUNSparseMatrix(NEQ, NEQ, nnz, CSC_MAT);
231 if(check_retval((void *)A, "SUNSparseMatrix", 0)) return(1);
232
233 /* Create SuperLUMT SUNLinearSolver object (one thread) */
234 nthreads = 1;
235 LS = SUNLinSol_SuperLUMT(yy, A, nthreads);
236 if(check_retval((void *)LS, "SUNLinSol_SuperLUMT", 0)) return(1);
237
238 /* Attach the matrix and linear solver */
239 retval = IDASetLinearSolver(ida_mem, LS, A);
240 if(check_retval(&retval, "IDASetLinearSolver", 1)) return(1);
241
242 /* Set the user-supplied Jacobian routine */
243 retval = IDASetJacFn(ida_mem, Jac);
244 if(check_retval(&retval, "IDASetJacFn", 1)) return(1);
245
246 /* Setup quadrature integration */
247 retval = IDAQuadInit(ida_mem, rhsQ, q);
248 if (check_retval(&retval, "IDAQuadInit", 1)) return(1);
249
250 retval = IDAQuadSStolerances(ida_mem, reltolQ, abstolQ);
251 if (check_retval(&retval, "IDAQuadSStolerances", 1)) return(1);
252
253 retval = IDASetQuadErrCon(ida_mem, SUNTRUE);
254 if (check_retval(&retval, "IDASetQuadErrCon", 1)) return(1);
255
256 /* Call IDASetMaxNumSteps to set the maximum number of steps the
257 * solver will take in an attempt to reach the next output time
258 * during forward integration. */
259 retval = IDASetMaxNumSteps(ida_mem, 2500);
260 if (check_retval(&retval, "IDASetMaxNumSteps", 1)) return(1);
261
262 /* Allocate global memory */
263
264 steps = STEPS;
265 retval = IDAAdjInit(ida_mem, steps, IDA_HERMITE);
266 /*retval = IDAAdjInit(ida_mem, steps, IDA_POLYNOMIAL);*/
267 if (check_retval(&retval, "IDAAdjInit", 1)) return(1);
268
269 /* Perform forward run */
270 printf("Forward integration ... ");
271
272 /* Integrate till TB1 and get the solution (y, y') at that time. */
273 retval = IDASolveF(ida_mem, TB1, &time, yy, yp, IDA_NORMAL, &ncheck);
274 if (check_retval(&retval, "IDASolveF", 1)) return(1);
275
276 yyTB1 = N_VClone(yy);
277 ypTB1 = N_VClone(yp);
278 /* Save the states at t=TB1. */
279 N_VScale(ONE, yy, yyTB1);
280 N_VScale(ONE, yp, ypTB1);
281
282 /* Continue integrating till TOUT is reached. */
283 retval = IDASolveF(ida_mem, TOUT, &time, yy, yp, IDA_NORMAL, &ncheck);
284 if (check_retval(&retval, "IDASolveF", 1)) return(1);
285
286 retval = IDAGetNumSteps(ida_mem, &nst);
287 if (check_retval(&retval, "IDAGetNumSteps", 1)) return(1);
288
289 printf("done ( nst = %ld )\n",nst);
290
291 retval = IDAGetQuad(ida_mem, &time, q);
292 if (check_retval(&retval, "IDAGetQuad", 1)) return(1);
293
294 printf("--------------------------------------------------------\n");
295 #if defined(SUNDIALS_EXTENDED_PRECISION)
296 printf("G: %12.4Le \n",Ith(q,1));
297 #elif defined(SUNDIALS_DOUBLE_PRECISION)
298 printf("G: %12.4e \n",Ith(q,1));
299 #else
300 printf("G: %12.4e \n",Ith(q,1));
301 #endif
302 printf("--------------------------------------------------------\n\n");
303
304 /* Test check point linked list
305 (uncomment next block to print check point information) */
306
307 /*
308 {
309 int i;
310
311 printf("\nList of Check Points (ncheck = %d)\n\n", ncheck);
312 ckpnt = (IDAadjCheckPointRec *) malloc ( (ncheck+1)*sizeof(IDAadjCheckPointRec));
313 IDAGetAdjCheckPointsInfo(ida_mem, ckpnt);
314 for (i=0;i<=ncheck;i++) {
315 printf("Address: %p\n",ckpnt[i].my_addr);
316 printf("Next: %p\n",ckpnt[i].next_addr);
317 printf("Time interval: %le %le\n",ckpnt[i].t0, ckpnt[i].t1);
318 printf("Step number: %ld\n",ckpnt[i].nstep);
319 printf("Order: %d\n",ckpnt[i].order);
320 printf("Step size: %le\n",ckpnt[i].step);
321 printf("\n");
322 }
323
324 }
325 */
326
327
328 /* Create BACKWARD problem. */
329
330 /* Allocate yB (i.e. lambda_0). */
331 yB = N_VNew_Serial(NEQ);
332 if (check_retval((void *)yB, "N_VNew_Serial", 0)) return(1);
333
334 /* Consistently initialize yB. */
335 Ith(yB,1) = ZERO;
336 Ith(yB,2) = ZERO;
337 Ith(yB,3) = ONE;
338
339
340 /* Allocate ypB (i.e. lambda'_0). */
341 ypB = N_VNew_Serial(NEQ);
342 if (check_retval((void *)ypB, "N_VNew_Serial", 0)) return(1);
343
344 /* Consistently initialize ypB. */
345 Ith(ypB,1) = ONE;
346 Ith(ypB,2) = ONE;
347 Ith(ypB,3) = ZERO;
348
349
350 /* Set the scalar relative tolerance reltolB */
351 reltolB = RTOL;
352
353 /* Set the scalar absolute tolerance abstolB */
354 abstolB = ATOLA;
355
356 /* Set the scalar absolute tolerance abstolQB */
357 abstolQB = ATOLQ;
358
359 /* Create and allocate IDAS memory for backward run */
360 printf("Create and allocate IDAS memory for backward run\n");
361
362 retval = IDACreateB(ida_mem, &indexB);
363 if (check_retval(&retval, "IDACreateB", 1)) return(1);
364
365 retval = IDAInitB(ida_mem, indexB, resB, TB2, yB, ypB);
366 if (check_retval(&retval, "IDAInitB", 1)) return(1);
367
368 retval = IDASStolerancesB(ida_mem, indexB, reltolB, abstolB);
369 if (check_retval(&retval, "IDASStolerancesB", 1)) return(1);
370
371 retval = IDASetUserDataB(ida_mem, indexB, data);
372 if (check_retval(&retval, "IDASetUserDataB", 1)) return(1);
373
374 retval = IDASetMaxNumStepsB(ida_mem, indexB, 1000);
375 if (check_retval(&retval, "IDASetMaxNumStepsB", 1)) return(1);
376
377 /* Create sparse SUNMatrix for use in linear solves */
378 AB = SUNSparseMatrix(NEQ, NEQ, nnz, CSC_MAT);
379 if(check_retval((void *)AB, "SUNSparseMatrix", 0)) return(1);
380
381 /* Create SuperLUMT SUNLinearSolver object (one thread) */
382 LSB = SUNLinSol_SuperLUMT(yB, AB, nthreads);
383 if(check_retval((void *)LSB, "SUNLinSol_SuperLUMT", 0)) return(1);
384
385 /* Attach the matrix and linear solver */
386 retval = IDASetLinearSolverB(ida_mem, indexB, LSB, AB);
387 if(check_retval(&retval, "IDASetLinearSolverB", 1)) return(1);
388
389 /* Set the user-supplied Jacobian routine */
390 retval = IDASetJacFnB(ida_mem, indexB, JacB);
391 if(check_retval(&retval, "IDASetJacFnB", 1)) return(1);
392
393 /* Quadrature for backward problem. */
394
395 /* Initialize qB */
396 qB = N_VNew_Serial(NP);
397 if (check_retval((void *)qB, "N_VNew", 0)) return(1);
398 Ith(qB,1) = ZERO;
399 Ith(qB,2) = ZERO;
400 Ith(qB,3) = ZERO;
401
402 retval = IDAQuadInitB(ida_mem, indexB, rhsQB, qB);
403 if (check_retval(&retval, "IDAQuadInitB", 1)) return(1);
404
405 retval = IDAQuadSStolerancesB(ida_mem, indexB, reltolB, abstolQB);
406 if (check_retval(&retval, "IDAQuadSStolerancesB", 1)) return(1);
407
408 /* Include quadratures in error control. */
409 retval = IDASetQuadErrConB(ida_mem, indexB, SUNTRUE);
410 if (check_retval(&retval, "IDASetQuadErrConB", 1)) return(1);
411
412
413 /* Backward Integration */
414 printf("Backward integration ... ");
415
416 retval = IDASolveB(ida_mem, T0, IDA_NORMAL);
417 if (check_retval(&retval, "IDASolveB", 1)) return(1);
418
419 IDAGetNumSteps(IDAGetAdjIDABmem(ida_mem, indexB), &nstB);
420 printf("done ( nst = %ld )\n", nstB);
421
422 retval = IDAGetB(ida_mem, indexB, &time, yB, ypB);
423 if (check_retval(&retval, "IDAGetB", 1)) return(1);
424
425 retval = IDAGetQuadB(ida_mem, indexB, &time, qB);
426 if (check_retval(&retval, "IDAGetB", 1)) return(1);
427
428 PrintOutput(TB2, yB, ypB, qB);
429
430
431 /* Reinitialize backward phase and start from a different time (TB1). */
432 printf("Re-initialize IDAS memory for backward run\n");
433
434 /* Both algebraic part from y and the entire y' are computed by IDACalcIC. */
435 Ith(yB,1) = ZERO;
436 Ith(yB,2) = ZERO;
437 Ith(yB,3) = RCONST(0.50); /* not consistent */
438
439 /* Rough guess for ypB. */
440 Ith(ypB,1) = RCONST(0.80);
441 Ith(ypB,2) = RCONST(0.75);
442 Ith(ypB,3) = ZERO;
443
444 /* Initialize qB */
445 Ith(qB,1) = ZERO;
446 Ith(qB,2) = ZERO;
447 Ith(qB,3) = ZERO;
448
449 retval = IDAReInitB(ida_mem, indexB, TB1, yB, ypB);
450 if (check_retval(&retval, "IDAReInitB", 1)) return(1);
451
452 /* Also reinitialize quadratures. */
453 retval = IDAQuadReInitB(ida_mem, indexB, qB);
454 if (check_retval(&retval, "IDAQuadReInitB", 1)) return(1);
455
456 /* Use IDACalcICB to compute consistent initial conditions
457 for this backward problem. */
458
459 id = N_VNew_Serial(NEQ);
460 Ith(id,1) = 1.0;
461 Ith(id,2) = 1.0;
462 Ith(id,3) = 0.0;
463
464 /* Specify which variables are differential (1) and which algebraic (0).*/
465 retval = IDASetIdB(ida_mem, indexB, id);
466 if (check_retval(&retval, "IDASetId", 1)) return(1);
467
468 retval = IDACalcICB(ida_mem, indexB, T1B, yyTB1, ypTB1);
469 if (check_retval(&retval, "IDACalcICB", 1)) return(1);
470
471 /* Get the consistent IC found by IDAS. */
472 retval = IDAGetConsistentICB(ida_mem, indexB, yB, ypB);
473 if (check_retval(&retval, "IDAGetConsistentICB", 1)) return(1);
474
475 printf("Backward integration ... ");
476
477 retval = IDASolveB(ida_mem, T0, IDA_NORMAL);
478 if (check_retval(&retval, "IDASolveB", 1)) return(1);
479
480 IDAGetNumSteps(IDAGetAdjIDABmem(ida_mem, indexB), &nstB);
481 printf("done ( nst = %ld )\n", nstB);
482
483 retval = IDAGetB(ida_mem, indexB, &time, yB, ypB);
484 if (check_retval(&retval, "IDAGetB", 1)) return(1);
485
486 retval = IDAGetQuadB(ida_mem, indexB, &time, qB);
487 if (check_retval(&retval, "IDAGetQuadB", 1)) return(1);
488
489 PrintOutput(TB1, yB, ypB, qB);
490
491 /* Free any memory used.*/
492
493 printf("Free memory\n\n");
494
495 IDAFree(&ida_mem);
496 SUNLinSolFree(LS);
497 SUNMatDestroy(A);
498 SUNLinSolFree(LSB);
499 SUNMatDestroy(AB);
500 N_VDestroy(yy);
501 N_VDestroy(yp);
502 N_VDestroy(q);
503 N_VDestroy(yB);
504 N_VDestroy(ypB);
505 N_VDestroy(qB);
506 N_VDestroy(id);
507 N_VDestroy(yyTB1);
508 N_VDestroy(ypTB1);
509
510 if (ckpnt != NULL) free(ckpnt);
511 free(data);
512
513 return(0);
514
515 }
516
517 /*
518 *--------------------------------------------------------------------
519 * FUNCTIONS CALLED BY IDAS
520 *--------------------------------------------------------------------
521 */
522
523 /*
524 * f routine. Compute f(t,y).
525 */
526
res(realtype t,N_Vector yy,N_Vector yp,N_Vector resval,void * user_data)527 static int res(realtype t, N_Vector yy, N_Vector yp, N_Vector resval, void *user_data)
528 {
529 realtype y1, y2, y3,yp1, yp2, *rval;
530 UserData data;
531 realtype p1, p2, p3;
532
533 y1 = Ith(yy,1); y2 = Ith(yy,2); y3 = Ith(yy,3);
534 yp1 = Ith(yp,1); yp2 = Ith(yp,2);
535 rval = N_VGetArrayPointer(resval);
536
537 data = (UserData) user_data;
538 p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
539
540 rval[0] = p1*y1-p2*y2*y3;
541 rval[1] = -rval[0] + p3*y2*y2 + yp2;
542 rval[0]+= yp1;
543 rval[2] = y1+y2+y3-1;
544
545 return(0);
546 }
547
548 /*
549 * Jacobian routine. Compute J(t,y).
550 */
551
Jac(realtype t,realtype cj,N_Vector yy,N_Vector yp,N_Vector resvec,SUNMatrix JJ,void * user_data,N_Vector tmp1,N_Vector tmp2,N_Vector tmp3)552 static int Jac(realtype t, realtype cj,
553 N_Vector yy, N_Vector yp, N_Vector resvec,
554 SUNMatrix JJ, void *user_data,
555 N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
556 {
557 realtype *yval;
558 sunindextype *colptrs = SUNSparseMatrix_IndexPointers(JJ);
559 sunindextype *rowvals = SUNSparseMatrix_IndexValues(JJ);
560 realtype *data = SUNSparseMatrix_Data(JJ);
561
562 UserData userdata;
563 realtype p1, p2, p3;
564
565 yval = N_VGetArrayPointer(yy);
566
567 userdata = (UserData) user_data;
568 p1 = userdata->p[0]; p2 = userdata->p[1]; p3 = userdata->p[2];
569
570 SUNMatZero(JJ);
571
572 colptrs[0] = 0;
573 colptrs[1] = 3;
574 colptrs[2] = 6;
575 colptrs[3] = 9;
576
577 /* column 0 */
578 data[0] = p1+cj;
579 rowvals[0] = 0;
580 data[1] = -p1;
581 rowvals[1] = 1;
582 data[2] = ONE;
583 rowvals[2] = 2;
584
585 /* column 1 */
586 data[3] = -p2*yval[2];
587 rowvals[3] = 0;
588 data[4] = p2*yval[2]+2*p3*yval[1]+cj;
589 rowvals[4] = 1;
590 data[5] = ONE;
591 rowvals[5] = 2;
592
593 /* column 2 */
594 data[6] = -p2*yval[1];
595 rowvals[6] = 0;
596 data[7] = p2*yval[1];
597 rowvals[7] = 1;
598 data[8] = ONE;
599 rowvals[8] = 2;
600
601 return(0);
602 }
603
604 /*
605 * rhsQ routine. Compute fQ(t,y).
606 */
607
rhsQ(realtype t,N_Vector yy,N_Vector yp,N_Vector qdot,void * user_data)608 static int rhsQ(realtype t, N_Vector yy, N_Vector yp, N_Vector qdot, void *user_data)
609 {
610 Ith(qdot,1) = Ith(yy,3);
611 return(0);
612 }
613
614 /*
615 * EwtSet function. Computes the error weights at the current solution.
616 */
617
ewt(N_Vector y,N_Vector w,void * user_data)618 static int ewt(N_Vector y, N_Vector w, void *user_data)
619 {
620 int i;
621 realtype yy, ww, rtol, atol[3];
622
623 rtol = RTOL;
624 atol[0] = ATOL1;
625 atol[1] = ATOL2;
626 atol[2] = ATOL3;
627
628 for (i=1; i<=3; i++) {
629 yy = Ith(y,i);
630 ww = rtol * SUNRabs(yy) + atol[i-1];
631 if (ww <= 0.0) return (-1);
632 Ith(w,i) = 1.0/ww;
633 }
634
635 return(0);
636 }
637
638
639 /*
640 * resB routine.
641 */
642
resB(realtype tt,N_Vector yy,N_Vector yp,N_Vector yyB,N_Vector ypB,N_Vector rrB,void * user_dataB)643 static int resB(realtype tt,
644 N_Vector yy, N_Vector yp,
645 N_Vector yyB, N_Vector ypB, N_Vector rrB,
646 void *user_dataB)
647 {
648 UserData data;
649 realtype y2, y3;
650 realtype p1, p2, p3;
651 realtype l1, l2, l3;
652 realtype lp1, lp2;
653 realtype l21;
654
655 data = (UserData) user_dataB;
656
657 /* The p vector */
658 p1 = data->p[0]; p2 = data->p[1]; p3 = data->p[2];
659
660 /* The y vector */
661 y2 = Ith(yy,2); y3 = Ith(yy,3);
662
663 /* The lambda vector */
664 l1 = Ith(yyB,1); l2 = Ith(yyB,2); l3 = Ith(yyB,3);
665
666 /* The lambda dot vector */
667 lp1 = Ith(ypB,1); lp2 = Ith(ypB,2);
668
669 /* Temporary variables */
670 l21 = l2-l1;
671
672 /* Load residual. */
673 Ith(rrB,1) = lp1 + p1*l21 - l3;
674 Ith(rrB,2) = lp2 - p2*y3*l21 - RCONST(2.0)*p3*y2*l2-l3;
675 Ith(rrB,3) = - p2*y2*l21 -l3 + RCONST(1.0);
676
677 return(0);
678 }
679
680 /*Jacobian for backward problem. */
JacB(realtype tt,realtype cjB,N_Vector yy,N_Vector yp,N_Vector yyB,N_Vector ypB,N_Vector rrB,SUNMatrix JB,void * user_data,N_Vector tmp1B,N_Vector tmp2B,N_Vector tmp3B)681 static int JacB(realtype tt, realtype cjB,
682 N_Vector yy, N_Vector yp,
683 N_Vector yyB, N_Vector ypB, N_Vector rrB,
684 SUNMatrix JB, void *user_data,
685 N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B)
686 {
687 realtype *yvalB;
688 sunindextype *colptrsB = SUNSparseMatrix_IndexPointers(JB);
689 sunindextype *rowvalsB = SUNSparseMatrix_IndexValues(JB);
690 realtype *dataB = SUNSparseMatrix_Data(JB);
691
692 UserData userdata;
693 realtype p1, p2, p3;
694
695 yvalB = N_VGetArrayPointer(yy);
696
697 userdata = (UserData) user_data;
698 p1 = userdata->p[0]; p2 = userdata->p[1]; p3 = userdata->p[2];
699
700 SUNMatZero(JB);
701
702 colptrsB[0] = 0;
703 colptrsB[1] = 3;
704 colptrsB[2] = 6;
705 colptrsB[3] = 9;
706
707 /* column 0 */
708 dataB[0] = -p1+cjB;
709 rowvalsB[0] = 0;
710 dataB[1] = p2*yvalB[2];
711 rowvalsB[1] = 1;
712 dataB[2] = p2*yvalB[1];
713 rowvalsB[2] = 2;
714
715 /* column 1 */
716 dataB[3] = p1;
717 rowvalsB[3] = 0;
718 dataB[4] = -(p2*yvalB[2]+RCONST(2.0)*p3*yvalB[1])+cjB;
719 rowvalsB[4] = 1;
720 dataB[5] = -p2*yvalB[1];
721 rowvalsB[5] = 2;
722
723 /* column 2 */
724 dataB[6] = -ONE;
725 rowvalsB[6] = 0;
726 dataB[7] = -ONE;
727 rowvalsB[7] = 1;
728 dataB[8] = -ONE;
729 rowvalsB[8] = 2;
730
731 return(0);
732 }
733
/*
 * Quadrature integrand for the backward problem. From the forward states
 * y1..y3 (yy) and the adjoint variables l1, l2 (yyB) it fills rrQB with
 *   rrQB1 =  y1*(l2-l1)
 *   rrQB2 = -y3*y2*(l2-l1)
 *   rrQB3 = -y2*y2*l2
 * Always returns 0.
 */
static int rhsQB(realtype tt,
                 N_Vector yy, N_Vector yp,
                 N_Vector yyB, N_Vector ypB,
                 N_Vector rrQB, void *user_dataB)
{
  /* Forward solution */
  const realtype y1 = NV_Ith_S(yy,0);
  const realtype y2 = NV_Ith_S(yy,1);
  const realtype y3 = NV_Ith_S(yy,2);

  /* Adjoint variables */
  const realtype lam1 = NV_Ith_S(yyB,0);
  const realtype lam2 = NV_Ith_S(yyB,1);

  /* Difference shared by the first two integrands */
  const realtype diff = lam2-lam1;

  NV_Ith_S(rrQB,0) = y1*diff;
  NV_Ith_S(rrQB,1) = -y3*y2*diff;
  NV_Ith_S(rrQB,2) = -y2*y2*lam2;

  return(0);
}
758
759
760 /*
761 *--------------------------------------------------------------------
762 * PRIVATE FUNCTIONS
763 *--------------------------------------------------------------------
764 */
765
766 /*
767 * Print results after backward integration
768 */
769
PrintOutput(realtype tfinal,N_Vector yB,N_Vector ypB,N_Vector qB)770 static void PrintOutput(realtype tfinal, N_Vector yB, N_Vector ypB, N_Vector qB)
771 {
772 printf("--------------------------------------------------------\n");
773 #if defined(SUNDIALS_EXTENDED_PRECISION)
774 printf("tB0: %12.4Le\n",tfinal);
775 printf("dG/dp: %12.4Le %12.4Le %12.4Le\n",
776 -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
777 printf("lambda(t0): %12.4Le %12.4Le %12.4Le\n",
778 Ith(yB,1), Ith(yB,2), Ith(yB,3));
779 #elif defined(SUNDIALS_DOUBLE_PRECISION)
780 printf("tB0: %12.4e\n",tfinal);
781 printf("dG/dp: %12.4e %12.4e %12.4e\n",
782 -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
783 printf("lambda(t0): %12.4e %12.4e %12.4e\n",
784 Ith(yB,1), Ith(yB,2), Ith(yB,3));
785 #else
786 printf("tB0: %12.4e\n",tfinal);
787 printf("dG/dp: %12.4e %12.4e %12.4e\n",
788 -Ith(qB,1), -Ith(qB,2), -Ith(qB,3));
789 printf("lambda(t0): %12.4e %12.4e %12.4e\n",
790 Ith(yB,1), Ith(yB,2), Ith(yB,3));
791 #endif
792 printf("--------------------------------------------------------\n\n");
793 }
794
/*
 * Check a function's return value and report failures on stderr.
 *   opt == 0 : returnvalue is a pointer from a SUNDIALS allocation;
 *              fails when it is NULL
 *   opt == 1 : returnvalue points to an int status; fails when the
 *              status is negative
 *   opt == 2 : returnvalue is a pointer from a generic allocation;
 *              fails when it is NULL
 * Any other opt is ignored. Returns 1 on failure and 0 otherwise.
 */

static int check_retval(void *returnvalue, char *funcname, int opt)
{
  switch (opt) {

  case 0:
    /* SUNDIALS allocation: NULL means no memory was allocated */
    if (returnvalue == NULL) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      return(1);
    }
    break;

  case 1: {
    /* Integer status: negative values signal an error */
    const int status = *(int *) returnvalue;
    if (status < 0) {
      fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
              funcname, status);
      return(1);
    }
    break;
  }

  case 2:
    /* Generic allocation: NULL means no memory was allocated */
    if (returnvalue == NULL) {
      fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
              funcname);
      return(1);
    }
    break;

  default:
    break;
  }

  return(0);
}
831