1 /* -----------------------------------------------------------------------------
2  * Programmer(s): David J. Gardner @ LLNL
3  * -----------------------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2021, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------------------
 * This is the implementation file for the CVODES nonlinear solver interface.
15  * ---------------------------------------------------------------------------*/
16 
17 #include "cvodes_impl.h"
18 #include "sundials/sundials_math.h"
19 #include "sundials/sundials_nvector_senswrapper.h"
20 
21 /* constant macros */
22 #define ONE RCONST(1.0)
23 
24 /* private functions */
25 static int cvNlsResidual(N_Vector ycor, N_Vector res, void* cvode_mem);
26 static int cvNlsFPFunction(N_Vector ycor, N_Vector res, void* cvode_mem);
27 
28 static int cvNlsLSetup(booleantype jbad, booleantype* jcur, void* cvode_mem);
29 static int cvNlsLSolve(N_Vector delta, void* cvode_mem);
30 static int cvNlsConvTest(SUNNonlinearSolver NLS, N_Vector ycor, N_Vector del,
31                          realtype tol, N_Vector ewt, void* cvode_mem);
32 
33 /* -----------------------------------------------------------------------------
34  * Exported functions
35  * ---------------------------------------------------------------------------*/
36 
/*---------------------------------------------------------------
  CVodeSetNonlinearSolver:

  Attaches a SUNNonlinearSolver to the CVODES memory and configures
  it: installs the nonlinear system function matching the solver
  type (cvNlsResidual for ROOTFIND, cvNlsFPFunction for FIXEDPOINT),
  the CVODES convergence test, and the iteration limit NLS_MAXCOR.

  All input validation (NULL solver, missing gettype/solve/setsysfn
  operations, unsupported type) happens BEFORE the currently
  attached solver is touched, so rejected input leaves the
  integrator state unchanged.

  Ownership: when a different solver is attached, a previously
  attached solver that CVODES owns is freed and the new one is
  marked as not owned. Re-attaching the solver that is already
  attached neither frees it nor changes the ownership flag.

  Returns CV_SUCCESS, CV_MEM_NULL if cvode_mem is NULL, or
  CV_ILL_INPUT on invalid input or if configuring the solver fails.
  ---------------------------------------------------------------*/
int CVodeSetNonlinearSolver(void *cvode_mem, SUNNonlinearSolver NLS)
{
  CVodeMem cv_mem;
  int retval;

  /* Return immediately if CVode memory is NULL */
  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES", "CVodeSetNonlinearSolver", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* Return immediately if NLS memory is NULL. cv_mem is valid here, so the
     error is routed through it to reach any user-supplied error handler. */
  if (NLS == NULL) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES", "CVodeSetNonlinearSolver",
                   "NLS must be non-NULL");
    return(CV_ILL_INPUT);
  }

  /* check for required nonlinear solver functions */
  if ( NLS->ops->gettype  == NULL ||
       NLS->ops->solve    == NULL ||
       NLS->ops->setsysfn == NULL ) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES", "CVodeSetNonlinearSolver",
                   "NLS does not support required operations");
    return(CV_ILL_INPUT);
  }

  /* Reject unsupported solver types before releasing the current solver,
     so invalid input cannot leave the integrator without a usable NLS. */
  if ( (SUNNonlinSolGetType(NLS) != SUNNONLINEARSOLVER_ROOTFIND) &&
       (SUNNonlinSolGetType(NLS) != SUNNONLINEARSOLVER_FIXEDPOINT) ) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES", "CVodeSetNonlinearSolver",
                   "Invalid nonlinear solver type");
    return(CV_ILL_INPUT);
  }

  /* Swap in the new solver. The old one is freed only when it is a
     different object: freeing the solver being re-attached would leave
     cv_mem->NLS dangling. The free is best-effort (its status is not
     actionable here because the old solver is being discarded). */
  if (cv_mem->NLS != NLS) {
    if ((cv_mem->NLS != NULL) && (cv_mem->ownNLS)) {
      (void) SUNNonlinSolFree(cv_mem->NLS);
    }

    /* set SUNNonlinearSolver pointer */
    cv_mem->NLS = NLS;

    /* Set NLS ownership flag. If this function was called to attach the
       default NLS, CVODE will set the flag to SUNTRUE after this function
       returns. */
    cv_mem->ownNLS = SUNFALSE;
  }

  /* set the nonlinear system function (type was validated above) */
  if (SUNNonlinSolGetType(NLS) == SUNNONLINEARSOLVER_ROOTFIND) {
    retval = SUNNonlinSolSetSysFn(cv_mem->NLS, cvNlsResidual);
  } else {
    retval = SUNNonlinSolSetSysFn(cv_mem->NLS, cvNlsFPFunction);
  }

  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES", "CVodeSetNonlinearSolver",
                   "Setting nonlinear system function failed");
    return(CV_ILL_INPUT);
  }

  /* set convergence test function */
  retval = SUNNonlinSolSetConvTestFn(cv_mem->NLS, cvNlsConvTest, cvode_mem);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES", "CVodeSetNonlinearSolver",
                   "Setting convergence test function failed");
    return(CV_ILL_INPUT);
  }

  /* set max allowed nonlinear iterations */
  retval = SUNNonlinSolSetMaxIters(cv_mem->NLS, NLS_MAXCOR);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES", "CVodeSetNonlinearSolver",
                   "Setting maximum number of nonlinear iterations failed");
    return(CV_ILL_INPUT);
  }

  /* Reset the acnrmcur flag to SUNFALSE */
  cv_mem->cv_acnrmcur = SUNFALSE;

  return(CV_SUCCESS);
}
114 
115 
116 /*---------------------------------------------------------------
117   CVodeGetNonlinearSystemData:
118 
119   This routine provides access to the relevant data needed to
120   compute the nonlinear system function.
121   ---------------------------------------------------------------*/
CVodeGetNonlinearSystemData(void * cvode_mem,realtype * tcur,N_Vector * ypred,N_Vector * yn,N_Vector * fn,realtype * gamma,realtype * rl1,N_Vector * zn1,void ** user_data)122 int CVodeGetNonlinearSystemData(void *cvode_mem, realtype *tcur,
123                                 N_Vector *ypred, N_Vector *yn,
124                                 N_Vector *fn, realtype *gamma,
125                                 realtype *rl1, N_Vector *zn1,
126                                 void **user_data)
127 {
128   CVodeMem cv_mem;
129 
130   if (cvode_mem == NULL) {
131     cvProcessError(NULL, CV_MEM_NULL, "CVODES", "CVodeGetNonlinearSystemData",
132                    MSGCV_NO_MEM);
133     return(CV_MEM_NULL);
134   }
135 
136   cv_mem = (CVodeMem) cvode_mem;
137 
138   *tcur      = cv_mem->cv_tn;
139   *ypred     = cv_mem->cv_zn[0];
140   *yn        = cv_mem->cv_y;
141   *fn        = cv_mem->cv_ftemp;
142   *gamma     = cv_mem->cv_gamma;
143   *rl1       = cv_mem->cv_rl1;
144   *zn1       = cv_mem->cv_zn[1];
145   *user_data = cv_mem->cv_user_data;
146 
147   return(CV_SUCCESS);
148 }
149 
150 
151 /* -----------------------------------------------------------------------------
152  * Private functions
153  * ---------------------------------------------------------------------------*/
154 
155 
/*---------------------------------------------------------------
  cvNlsInit:

  Connects the attached nonlinear solver to the current linear
  solver and initializes it. The wrappers cvNlsLSetup and
  cvNlsLSolve are installed only when the integrator has lsetup and
  lsolve functions, respectively; otherwise NULL is passed so the
  nonlinear solver runs without that callback.

  Returns CV_SUCCESS, or CV_NLS_INIT_FAIL if any step fails. The
  same code is reported to the error handler, so the handler and
  the caller see a consistent error.
  ---------------------------------------------------------------*/
int cvNlsInit(CVodeMem cvode_mem)
{
  int retval;

  /* set the linear solver setup wrapper function */
  if (cvode_mem->cv_lsetup) {
    retval = SUNNonlinSolSetLSetupFn(cvode_mem->NLS, cvNlsLSetup);
  } else {
    retval = SUNNonlinSolSetLSetupFn(cvode_mem->NLS, NULL);
  }

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInit",
                   "Setting the linear solver setup function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* set the linear solver solve wrapper function */
  if (cvode_mem->cv_lsolve) {
    retval = SUNNonlinSolSetLSolveFn(cvode_mem->NLS, cvNlsLSolve);
  } else {
    retval = SUNNonlinSolSetLSolveFn(cvode_mem->NLS, NULL);
  }

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInit",
                   "Setting linear solver solve function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* initialize nonlinear solver */
  retval = SUNNonlinSolInitialize(cvode_mem->NLS);

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInit",
                   MSGCV_NLS_INIT_FAIL);
    return(CV_NLS_INIT_FAIL);
  }

  return(CV_SUCCESS);
}
195 
196 
/*---------------------------------------------------------------
  cvNlsLSetup:

  Linear-solver setup callback handed to the SUNNonlinearSolver.
  If the nonlinear solver flagged the Jacobian as bad (jbad),
  convfail is set to CV_FAIL_BAD_J before the integrator's lsetup
  is called. Whatever lsetup returns, the setup counter is
  incremented, the Jacobian-current flag is reported through jcur,
  and the quantities tied to the last setup (forceSetup, gamrat,
  gammap, crate, crateS, nstlp) are reset.

  Returns CV_SUCCESS, SUN_NLS_CONV_RECVR if lsetup failed
  recoverably (> 0), CV_LSETUP_FAIL if it failed unrecoverably
  (< 0), or CV_MEM_NULL if cvode_mem is NULL.
  ---------------------------------------------------------------*/
static int cvNlsLSetup(booleantype jbad, booleantype* jcur, void* cvode_mem)
{
  CVodeMem cv_mem;
  int      retval;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES", "cvNlsLSetup", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* if the nonlinear solver marked the Jacobian as bad update convfail */
  if (jbad) {
    cv_mem->convfail = CV_FAIL_BAD_J;
  }

  /* setup the linear solver */
  retval = cv_mem->cv_lsetup(cv_mem, cv_mem->convfail, cv_mem->cv_y,
                             cv_mem->cv_ftemp, &(cv_mem->cv_jcur),
                             cv_mem->cv_vtemp1, cv_mem->cv_vtemp2,
                             cv_mem->cv_vtemp3);
  cv_mem->cv_nsetups++;

  /* update Jacobian status */
  *jcur = cv_mem->cv_jcur;

  /* a setup was just performed: reset the quantities measured since the
     last setup (done even when lsetup reported a failure) */
  cv_mem->cv_forceSetup = SUNFALSE;
  cv_mem->cv_gamrat     = ONE;
  cv_mem->cv_gammap     = cv_mem->cv_gamma;
  cv_mem->cv_crate      = ONE;
  cv_mem->cv_crateS     = ONE;
  cv_mem->cv_nstlp      = cv_mem->cv_nst;

  if (retval < 0) return(CV_LSETUP_FAIL);
  if (retval > 0) return(SUN_NLS_CONV_RECVR);

  return(CV_SUCCESS);
}
233 
234 
/*---------------------------------------------------------------
  cvNlsLSolve:

  Linear-solve callback handed to the SUNNonlinearSolver. Calls
  the integrator's lsolve on delta in place, using cv_ewt, cv_y
  and cv_ftemp.

  Returns CV_SUCCESS, SUN_NLS_CONV_RECVR if lsolve failed
  recoverably (> 0), CV_LSOLVE_FAIL if it failed unrecoverably
  (< 0), or CV_MEM_NULL if cvode_mem is NULL.
  ---------------------------------------------------------------*/
static int cvNlsLSolve(N_Vector delta, void* cvode_mem)
{
  CVodeMem cv_mem;
  int      retval;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES", "cvNlsLSolve", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  retval = cv_mem->cv_lsolve(cv_mem, delta, cv_mem->cv_ewt, cv_mem->cv_y,
                             cv_mem->cv_ftemp);

  if (retval < 0) return(CV_LSOLVE_FAIL);
  if (retval > 0) return(SUN_NLS_CONV_RECVR);

  return(CV_SUCCESS);
}
253 
254 
/*---------------------------------------------------------------
  cvNlsConvTest:

  Convergence test callback handed to the SUNNonlinearSolver.

  del is the weighted RMS norm of the latest correction delta. From
  the second iteration on (m > 0), the convergence-rate estimate is
  updated as crate = max(CRDOWN*crate, del/delp). The iteration has
  converged when del * min(1, crate) / tol <= 1. In that case acnrm
  is set (del on the first iteration, otherwise the norm of the
  accumulated correction ycor) and acnrmcur is marked current.

  Returns CV_SUCCESS on convergence, SUN_NLS_CONV_RECVR if the
  iteration looks divergent (m >= 1 and del > RDIV*delp),
  SUN_NLS_CONTINUE otherwise (after saving del in delp), and
  CV_MEM_NULL if cvode_mem is NULL or the iteration count cannot be
  queried.
  ---------------------------------------------------------------*/
static int cvNlsConvTest(SUNNonlinearSolver NLS, N_Vector ycor, N_Vector delta,
                         realtype tol, N_Vector ewt, void* cvode_mem)
{
  CVodeMem cv_mem;
  int m, retval;
  realtype del;
  realtype dcon;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES", "cvNlsConvTest", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* compute the norm of the correction */
  del = N_VWrmsNorm(delta, ewt);

  /* get the current nonlinear solver iteration count
     NOTE(review): a query failure is reported as CV_MEM_NULL, kept as-is
     for compatibility; any negative value aborts the nonlinear solve. */
  retval = SUNNonlinSolGetCurIter(NLS, &m);
  if (retval != CV_SUCCESS) return(CV_MEM_NULL);

  /* Test for convergence. If m > 0, an estimate of the convergence
     rate constant is stored in crate, and used in the test. delp is
     the previous iteration's del, saved below, so it is set when m > 0. */
  if (m > 0) {
    cv_mem->cv_crate = SUNMAX(CRDOWN * cv_mem->cv_crate, del/cv_mem->cv_delp);
  }
  dcon = del * SUNMIN(ONE, cv_mem->cv_crate) / tol;

  if (dcon <= ONE) {
    /* NOTE(review): acnrm uses cv_ewt while del uses the ewt argument;
       presumably the same vector for this solver -- confirm. */
    cv_mem->cv_acnrm = (m == 0) ? del : N_VWrmsNorm(ycor, cv_mem->cv_ewt);
    cv_mem->cv_acnrmcur = SUNTRUE;
    return(CV_SUCCESS); /* Nonlinear system was solved successfully */
  }

  /* check if the iteration seems to be diverging */
  if ((m >= 1) && (del > RDIV*cv_mem->cv_delp)) return(SUN_NLS_CONV_RECVR);

  /* Save norm of correction and loop again */
  cv_mem->cv_delp = del;

  /* Not yet converged */
  return(SUN_NLS_CONTINUE);
}
298 
299 
/*---------------------------------------------------------------
  cvNlsResidual:

  Nonlinear system function for root-finding solvers. Given the
  correction ycor, it sets y = zn[0] + ycor (stored in cv_y),
  evaluates f(tn, y) into cv_ftemp (incrementing nfe), and forms

    res = rl1*zn[1] + ycor - gamma*f(tn, y)

  Returns CV_SUCCESS, RHSFUNC_RECVR if f failed recoverably (> 0),
  CV_RHSFUNC_FAIL if f failed unrecoverably (< 0), or CV_MEM_NULL
  if cvode_mem is NULL. res is not written when f fails.
  ---------------------------------------------------------------*/
static int cvNlsResidual(N_Vector ycor, N_Vector res, void* cvode_mem)
{
  CVodeMem cv_mem;
  int retval;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsResidual", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* update the state based on the current correction */
  N_VLinearSum(ONE, cv_mem->cv_zn[0], ONE, ycor, cv_mem->cv_y);

  /* evaluate the rhs function */
  retval = cv_mem->cv_f(cv_mem->cv_tn, cv_mem->cv_y, cv_mem->cv_ftemp,
                        cv_mem->cv_user_data);
  cv_mem->cv_nfe++;
  if (retval < 0) return(CV_RHSFUNC_FAIL);
  if (retval > 0) return(RHSFUNC_RECVR);

  /* compute the residual */
  N_VLinearSum(cv_mem->cv_rl1, cv_mem->cv_zn[1], ONE, ycor, res);
  N_VLinearSum(-cv_mem->cv_gamma, cv_mem->cv_ftemp, ONE, res, res);

  return(CV_SUCCESS);
}
328 
329 
/*---------------------------------------------------------------
  cvNlsFPFunction:

  Nonlinear system function for fixed-point solvers. Given the
  correction ycor, it sets y = zn[0] + ycor (stored in cv_y),
  evaluates f(tn, y) directly into res (incrementing nfe), and then
  overwrites res in place with

    res = rl1 * (h*f(tn, y) - zn[1])

  Returns CV_SUCCESS, RHSFUNC_RECVR if f failed recoverably (> 0),
  CV_RHSFUNC_FAIL if f failed unrecoverably (< 0), or CV_MEM_NULL
  if cvode_mem is NULL.
  ---------------------------------------------------------------*/
static int cvNlsFPFunction(N_Vector ycor, N_Vector res, void* cvode_mem)
{
  CVodeMem cv_mem;
  int retval;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES", "cvNlsFPFunction", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* update the state based on the current correction */
  N_VLinearSum(ONE, cv_mem->cv_zn[0], ONE, ycor, cv_mem->cv_y);

  /* evaluate the rhs function directly into the output vector */
  retval = cv_mem->cv_f(cv_mem->cv_tn, cv_mem->cv_y, res,
                        cv_mem->cv_user_data);
  cv_mem->cv_nfe++;
  if (retval < 0) return(CV_RHSFUNC_FAIL);
  if (retval > 0) return(RHSFUNC_RECVR);

  /* res = rl1 * (h*f - zn[1]), computed in place */
  N_VLinearSum(cv_mem->cv_h, res, -ONE, cv_mem->cv_zn[1], res);
  N_VScale(cv_mem->cv_rl1, res, res);

  return(CV_SUCCESS);
}
356