1 /* -----------------------------------------------------------------------------
2  * Programmer(s): David J. Gardner @ LLNL
3  * -----------------------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2021, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------------------
14  * This example solves the nonlinear system
15  *
16  * 3x - cos((y-1)z) - 1/2 = 0
17  * x^2 - 81(y-0.9)^2 + sin(z) + 1.06 = 0
18  * exp(-x(y-1)) + 20z + (10 pi - 3)/3 = 0
19  *
 * using the accelerated fixed point solver in KINSOL. The nonlinear fixed
21  * point function is
22  *
 * g1(x,y,z) = 1/3 cos((y-1)z) + 1/6
24  * g2(x,y,z) = 1/9 sqrt(x^2 + sin(z) + 1.06) + 0.9
25  * g3(x,y,z) = -1/20 exp(-x(y-1)) - (10 pi - 3) / 60
26  *
27  * This system has the analytic solution x = 1/2, y = 1, z = -pi/6.
28  * ---------------------------------------------------------------------------*/
29 
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <math.h>
33 
34 #include "kinsol/kinsol.h"           /* access to KINSOL func., consts. */
35 #include "nvector/nvector_serial.h"  /* access to serial N_Vector       */
36 
37 /* precision specific formatting macros */
38 #if defined(SUNDIALS_EXTENDED_PRECISION)
39 #define GSYM "Lg"
40 #define ESYM "Le"
41 #define FSYM "Lf"
42 #else
43 #define GSYM "g"
44 #define ESYM "e"
45 #define FSYM "f"
46 #endif
47 
48 /* precision specific math function macros */
49 #if defined(SUNDIALS_DOUBLE_PRECISION)
50 #define ABS(x)  (fabs((x)))
51 #define SQRT(x) (sqrt((x)))
52 #define EXP(x)  (exp((x)))
53 #define SIN(x)  (sin((x)))
54 #define COS(x)  (cos((x)))
55 #elif defined(SUNDIALS_SINGLE_PRECISION)
56 #define ABS(x)  (fabsf((x)))
57 #define SQRT(x) (sqrtf((x)))
58 #define EXP(x)  (expf((x)))
59 #define SIN(x)  (sinf((x)))
60 #define COS(x)  (cosf((x)))
61 #elif defined(SUNDIALS_EXTENDED_PRECISION)
62 #define ABS(x)  (fabsl((x)))
63 #define SQRT(x) (sqrtl((x)))
64 #define EXP(x)  (expl((x)))
65 #define SIN(x)  (sinl((x)))
66 #define COS(x)  (cosl((x)))
67 #endif
68 
69 /* problem constants */
70 #define NEQ 3 /* number of equations */
71 
72 #define ZERO         RCONST(0.0)             /* real 0.0  */
73 #define PTONE        RCONST(0.1)             /* real 0.1  */
74 #define HALF         RCONST(0.5)             /* real 0.5  */
75 #define PTNINE       RCONST(0.9)             /* real 0.9  */
76 #define ONE          RCONST(1.0)             /* real 1.0  */
77 #define ONEPTZEROSIX RCONST(1.06)            /* real 1.06 */
78 #define ONEPTONE     RCONST(1.1)             /* real 1.1  */
79 #define THREE        RCONST(3.0)             /* real 3.0  */
80 #define FOUR         RCONST(4.0)             /* real 4.0  */
81 #define SIX          RCONST(6.0)             /* real 6.0  */
82 #define NINE         RCONST(9.0)             /* real 9.0  */
83 #define TEN          RCONST(10.0)            /* real 10.0 */
84 #define TWENTY       RCONST(20.0)            /* real 20.0 */
85 #define SIXTY        RCONST(60.0)            /* real 60.0 */
86 #define EIGHTYONE    RCONST(81.0)            /* real 81.0 */
87 #define PI           RCONST(3.1415926535898) /* real pi   */
88 
89 /* analytic solution */
90 #define XTRUE HALF
91 #define YTRUE ONE
92 #define ZTRUE -PI/SIX
93 
94 /* Nonlinear fixed point function */
95 static int FPFunction(N_Vector u, N_Vector f, void *user_data);
96 
97 /* Check function return values */
98 static int check_retval(void *returnvalue, const char *funcname, int opt);
99 
100 /* Check the system solution */
101 static int check_ans(N_Vector u, realtype tol);
102 
103 /* -----------------------------------------------------------------------------
104  * Main program
105  * ---------------------------------------------------------------------------*/
main(int argc,char * argv[])106 int main(int argc, char *argv[])
107 {
108   int       retval  = 0;
109   N_Vector  u       = NULL;
110   N_Vector  scale   = NULL;
111   realtype  tol     = 100 * SQRT(UNIT_ROUNDOFF);
112   long int  mxiter  = 10;
113   long int  maa     = 0;           /* no acceleration */
114   realtype  damping = RCONST(1.0); /* no damping      */
115   long int  nni, nfe;
116   realtype* data;
117   void*     kmem;
118 
119   /* Check if a acceleration/dampling values were provided */
120   if (argc > 1) maa     = (long int) atoi(argv[1]);
121   if (argc > 2) damping = (realtype) atof(argv[2]);
122 
123   /* -------------------------
124    * Print problem description
125    * ------------------------- */
126 
127   printf("Solve the nonlinear system:\n");
128   printf("    3x - cos((y-1)z) - 1/2 = 0\n");
129   printf("    x^2 - 81(y-0.9)^2 + sin(z) + 1.06 = 0\n");
130   printf("    exp(-x(y-1)) + 20z + (10 pi - 3)/3 = 0\n");
131   printf("Analytic solution:\n");
132   printf("    x = %"GSYM"\n", XTRUE);
133   printf("    y = %"GSYM"\n", YTRUE);
134   printf("    z = %"GSYM"\n", ZTRUE);
135   printf("Solution method: Anderson accelerated fixed point iteration.\n");
136   printf("    tolerance = %"GSYM"\n", tol);
137   printf("    max iters = %ld\n", mxiter);
138   printf("    accel vec = %ld\n", maa);
139   printf("    damping   = %"GSYM"\n", damping);
140 
141   /* --------------------------------------
142    * Create vectors for solution and scales
143    * -------------------------------------- */
144 
145   u = N_VNew_Serial(NEQ);
146   if (check_retval((void *)u, "N_VNew_Serial", 0)) return(1);
147 
148   scale = N_VClone(u);
149   if (check_retval((void *)scale, "N_VClone", 0)) return(1);
150 
151   /* -----------------------------------------
152    * Initialize and allocate memory for KINSOL
153    * ----------------------------------------- */
154 
155   kmem = KINCreate();
156   if (check_retval((void *)kmem, "KINCreate", 0)) return(1);
157 
158   /* Set number of prior residuals used in Anderson acceleration */
159   retval = KINSetMAA(kmem, maa);
160 
161   retval = KINInit(kmem, FPFunction, u);
162   if (check_retval(&retval, "KINInit", 1)) return(1);
163 
164   /* -------------------
165    * Set optional inputs
166    * ------------------- */
167 
168   /* Specify stopping tolerance based on residual */
169   retval = KINSetFuncNormTol(kmem, tol);
170   if (check_retval(&retval, "KINSetFuncNormTol", 1)) return(1);
171 
172   /* Set maximum number of iterations */
173   retval = KINSetNumMaxIters(kmem, mxiter);
174   if (check_retval(&retval, "KINSetNumMaxItersFuncNormTol", 1)) return(1);
175 
176   /* Set Anderson acceleration damping parameter */
177   retval = KINSetDampingAA(kmem, damping);
178   if (check_retval(&retval, "KINSetDampingAA", 1)) return(1);
179 
180   /* -------------
181    * Initial guess
182    * ------------- */
183 
184   /* Get vector data array */
185   data = N_VGetArrayPointer(u);
186   if (check_retval((void *)data, "N_VGetArrayPointer", 0)) return(1);
187 
188   data[0] =  PTONE;
189   data[1] =  PTONE;
190   data[2] = -PTONE;
191 
192   /* ----------------------------
193    * Call KINSol to solve problem
194    * ---------------------------- */
195 
196   /* No scaling used */
197   N_VConst(ONE, scale);
198 
199   /* Call main solver */
200   retval = KINSol(kmem,         /* KINSol memory block */
201                   u,            /* initial guess on input; solution vector */
202                   KIN_FP,       /* global strategy choice */
203                   scale,        /* scaling vector, for the variable cc */
204                   scale);       /* scaling vector for function values fval */
205   if (check_retval(&retval, "KINSol", 1)) return(1);
206 
207   /* ------------------------------------
208    * Get solver statistics
209    * ------------------------------------ */
210 
211   /* get solver stats */
212   retval = KINGetNumNonlinSolvIters(kmem, &nni);
213   check_retval(&retval, "KINGetNumNonlinSolvIters", 1);
214 
215   retval = KINGetNumFuncEvals(kmem, &nfe);
216   check_retval(&retval, "KINGetNumFuncEvals", 1);
217 
218   printf("\nFinal Statistics:\n");
219   printf("Number of nonlinear iterations: %6ld\n", nni);
220   printf("Number of function evaluations: %6ld\n", nfe);
221 
222   /* ------------------------------------
223    * Print solution and check error
224    * ------------------------------------ */
225 
226   /* check solution */
227   retval = check_ans(u, tol);
228 
229   /* -----------
230    * Free memory
231    * ----------- */
232 
233   N_VDestroy(u);
234   N_VDestroy(scale);
235   KINFree(&kmem);
236 
237   return(retval);
238 }
239 
240 /* -----------------------------------------------------------------------------
241  * Nonlinear system
242  *
243  * 3x - cos((y-1)z) - 1/2 = 0
244  * x^2 - 81(y-0.9)^2 + sin(z) + 1.06 = 0
245  * exp(-x(y-1)) + 20z + (10 pi - 3)/3 = 0
246  *
247  * Nonlinear fixed point function
248  *
249  * g1(x,y,z) = 1/3 cos((y-1)z) + 1/6
250  * g2(x,y,z) = 1/9 sqrt(x^2 + sin(z) + 1.06) + 0.9
251  * g3(x,y,z) = -1/20 exp(-x(y-1)) - (10 pi - 3) / 60
252  *
253  * ---------------------------------------------------------------------------*/
FPFunction(N_Vector u,N_Vector g,void * user_data)254 int FPFunction(N_Vector u, N_Vector g, void* user_data)
255 {
256   realtype* udata = NULL;
257   realtype* gdata = NULL;
258   realtype  x, y, z;
259 
260   /* Get vector data arrays */
261   udata = N_VGetArrayPointer(u);
262   if (check_retval((void*)udata, "N_VGetArrayPointer", 0)) return(-1);
263 
264   gdata = N_VGetArrayPointer(g);
265   if (check_retval((void*)gdata, "N_VGetArrayPointer", 0)) return(-1);
266 
267   x = udata[0];
268   y = udata[1];
269   z = udata[2];
270 
271   gdata[0] = (ONE/THREE) * COS((y-ONE)*z) + (ONE/SIX);
272   gdata[1] = (ONE/NINE) * SQRT(x*x + SIN(z) + ONEPTZEROSIX) + PTNINE;
273   gdata[2] = -(ONE/TWENTY) * EXP(-x*(y-ONE)) - (TEN * PI - THREE) / SIXTY;
274 
275   return(0);
276 }
277 
278 /* -----------------------------------------------------------------------------
 * Check the solution of the nonlinear system, print PASS or FAIL, and return
 * 0 on PASS or 1 on FAIL
280  * ---------------------------------------------------------------------------*/
check_ans(N_Vector u,realtype tol)281 static int check_ans(N_Vector u, realtype tol)
282 {
283   realtype* data = NULL;
284   realtype  ex, ey, ez;
285 
286   /* Get vector data array */
287   data = N_VGetArrayPointer(u);
288   if (check_retval((void *)data, "N_VGetArrayPointer", 0)) return(1);
289 
290   /* print the solution */
291   printf("Computed solution:\n");
292   printf("    x = %"GSYM"\n", data[0]);
293   printf("    y = %"GSYM"\n", data[1]);
294   printf("    z = %"GSYM"\n", data[2]);
295 
296   /* solution error */
297   ex = ABS(data[0] - XTRUE);
298   ey = ABS(data[1] - YTRUE);
299   ez = ABS(data[2] - ZTRUE);
300 
301   /* print the solution error */
302   printf("Solution error:\n");
303   printf("    ex = %"GSYM"\n", ex);
304   printf("    ey = %"GSYM"\n", ey);
305   printf("    ez = %"GSYM"\n", ez);
306 
307   tol *= TEN;
308   if (ex > tol || ey > tol || ez > tol) {
309     printf("FAIL\n");
310     return(1);
311   }
312 
313   printf("PASS\n");
314   return(0);
315 }
316 
317 /* -----------------------------------------------------------------------------
318  * Check function return value
319  *   opt == 0 check if returned NULL pointer
320  *   opt == 1 check if returned a non-zero value
321  * ---------------------------------------------------------------------------*/
check_retval(void * returnvalue,const char * funcname,int opt)322 static int check_retval(void *returnvalue, const char *funcname, int opt)
323 {
324   int *errflag;
325 
326   /* Check if the function returned a NULL pointer -- no memory allocated */
327   if (opt == 0) {
328     if (returnvalue == NULL) {
329       fprintf(stderr, "\nERROR: %s() failed -- returned NULL\n\n", funcname);
330       return(1);
331     } else {
332       return(0);
333     }
334   }
335 
336   /* Check if the function returned an non-zero value -- internal failure */
337   if (opt == 1) {
338     errflag = (int *) returnvalue;
339     if (*errflag != 0) {
340       fprintf(stderr, "\nERROR: %s() failed -- returned %d\n\n", funcname, *errflag);
341       return(1);
342     } else {
343       return(0);
344     }
345   }
346 
347   /* if we make it here then opt was not 0 or 1 */
348   fprintf(stderr, "\nERROR: check_retval failed -- Invalid opt value\n\n");
349   return(1);
350 }
351