1 /* -----------------------------------------------------------------------------
2  * Programmer(s): David J. Gardner @ LLNL
3  * -----------------------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2021, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------------------
 * This is the implementation file for the CVODES nonlinear solver interface.
15  * ---------------------------------------------------------------------------*/
16 
17 #include "cvodes_impl.h"
18 #include "sundials/sundials_math.h"
19 
/* constant macros */
#define ONE RCONST(1.0)

/* private functions */

/* Nonlinear system functions handed to the SUNNonlinearSolver:
   the residual form is attached to SUNNONLINEARSOLVER_ROOTFIND solvers and
   the fixed-point form to SUNNONLINEARSOLVER_FIXEDPOINT solvers. */
static int cvNlsResidualSensStg1(N_Vector ycor, N_Vector res,
                                 void* cvode_mem);
static int cvNlsFPFunctionSensStg1(N_Vector ycor, N_Vector res,
                                   void* cvode_mem);

/* Linear solver setup/solve wrappers and the convergence test that are
   attached to the staggered-1 sensitivity nonlinear solver. */
static int cvNlsLSetupSensStg1(booleantype jbad, booleantype* jcur,
                               void* cvode_mem);
static int cvNlsLSolveSensStg1(N_Vector delta, void* cvode_mem);
static int cvNlsConvTestSensStg1(SUNNonlinearSolver NLS,
                                 N_Vector ycor, N_Vector del,
                                 realtype tol, N_Vector ewt, void* cvode_mem);
35 
36 /* -----------------------------------------------------------------------------
37  * Exported functions
38  * ---------------------------------------------------------------------------*/
39 
/*
 * CVodeSetNonlinearSolverSensStg1
 *
 * Attach a SUNNonlinearSolver to be used for the per-sensitivity corrections
 * of the CV_STAGGERED1 sensitivity method.
 *
 * The following are validated before any state in cvode_mem is modified:
 *   - cvode_mem and NLS are non-NULL,
 *   - NLS implements the gettype, solve and setsysfn operations,
 *   - sensitivities are enabled and the method is CV_STAGGERED1,
 *   - NLS is a SUNNONLINEARSOLVER_ROOTFIND or SUNNONLINEARSOLVER_FIXEDPOINT
 *     solver.
 *
 * A previously attached solver owned by CVODES is freed, unless it is the very
 * object being attached (freeing it would leave NLSstg1 dangling). A newly
 * attached solver is marked as not owned; when CVODES attaches its default
 * solver it sets ownNLSstg1 itself after this function returns.
 *
 * Returns CV_SUCCESS on success, CV_MEM_NULL if cvode_mem is NULL, and
 * CV_ILL_INPUT for any other failure.
 */
int CVodeSetNonlinearSolverSensStg1(void *cvode_mem, SUNNonlinearSolver NLS)
{
  CVodeMem cv_mem;
  int retval;

  /* Return immediately if CVode memory is NULL */
  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* Return immediately if NLS memory is NULL. cv_mem is valid here, so pass
     it so that the user's error handler / error file is honored. */
  if (NLS == NULL) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   "NLS must be non-NULL");
    return (CV_ILL_INPUT);
  }

  /* check for required nonlinear solver functions */
  if ( NLS->ops->gettype    == NULL ||
       NLS->ops->solve      == NULL ||
       NLS->ops->setsysfn   == NULL ) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   "NLS does not support required operations");
    return(CV_ILL_INPUT);
  }

  /* check that sensitivities were initialized */
  if (!(cv_mem->cv_sensi)) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   MSGCV_NO_SENSI);
    return(CV_ILL_INPUT);
  }

  /* check that staggered corrector was selected */
  if (cv_mem->cv_ism != CV_STAGGERED1) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   "Sensitivity solution method is not CV_STAGGERED1");
    return(CV_ILL_INPUT);
  }

  /* Validate the solver type BEFORE releasing the currently attached solver,
     so an unsupported NLS does not leave CVODES without a usable solver. */
  if ((SUNNonlinSolGetType(NLS) != SUNNONLINEARSOLVER_ROOTFIND) &&
      (SUNNonlinSolGetType(NLS) != SUNNONLINEARSOLVER_FIXEDPOINT)) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   "Invalid nonlinear solver type");
    return(CV_ILL_INPUT);
  }

  /* Replace the attached solver. Re-attaching the same object must not free
     it (that would be a use-after-free); its ownership flag is then kept. */
  if (cv_mem->NLSstg1 != NLS) {
    /* free any existing nonlinear solver owned by CVODES */
    if ((cv_mem->NLSstg1 != NULL) && (cv_mem->ownNLSstg1))
      (void) SUNNonlinSolFree(cv_mem->NLSstg1);

    /* set SUNNonlinearSolver pointer */
    cv_mem->NLSstg1 = NLS;

    /* Set NLS ownership flag. If this function was called to attach the
       default NLS, CVODE will set the flag to SUNTRUE after this function
       returns. */
    cv_mem->ownNLSstg1 = SUNFALSE;
  }

  /* set the nonlinear system function matching the (validated) solver type */
  if (SUNNonlinSolGetType(NLS) == SUNNONLINEARSOLVER_ROOTFIND)
    retval = SUNNonlinSolSetSysFn(cv_mem->NLSstg1, cvNlsResidualSensStg1);
  else
    retval = SUNNonlinSolSetSysFn(cv_mem->NLSstg1, cvNlsFPFunctionSensStg1);

  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   "Setting nonlinear system function failed");
    return(CV_ILL_INPUT);
  }

  /* set convergence test function */
  retval = SUNNonlinSolSetConvTestFn(cv_mem->NLSstg1, cvNlsConvTestSensStg1,
                                     cvode_mem);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   "Setting convergence test function failed");
    return(CV_ILL_INPUT);
  }

  /* set max allowed nonlinear iterations */
  retval = SUNNonlinSolSetMaxIters(cv_mem->NLSstg1, NLS_MAXCOR);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg1",
                   "Setting maximum number of nonlinear iterations failed");
    return(CV_ILL_INPUT);
  }

  /* Reset the acnrmScur flag to SUNFALSE (always false for stg1) */
  cv_mem->cv_acnrmScur = SUNFALSE;

  return(CV_SUCCESS);
}
141 
142 
143 /* -----------------------------------------------------------------------------
144  * Private functions
145  * ---------------------------------------------------------------------------*/
146 
147 
/*
 * cvNlsInitSensStg1
 *
 * Finish configuring and initialize the staggered-1 sensitivity nonlinear
 * solver. The linear solver setup/solve wrappers are attached only when the
 * corresponding cv_lsetup / cv_lsolve function is present; otherwise NULL is
 * attached. Then SUNNonlinSolInitialize is called and the previous-iteration
 * counter nnip is reset.
 *
 * Returns CV_SUCCESS, or CV_NLS_INIT_FAIL if any step fails. The same code is
 * reported to cvProcessError so the error handler sees the value returned.
 */
int cvNlsInitSensStg1(CVodeMem cvode_mem)
{
  int retval;

  /* set the linear solver setup wrapper function */
  if (cvode_mem->cv_lsetup)
    retval = SUNNonlinSolSetLSetupFn(cvode_mem->NLSstg1, cvNlsLSetupSensStg1);
  else
    retval = SUNNonlinSolSetLSetupFn(cvode_mem->NLSstg1, NULL);

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInitSensStg1",
                   "Setting the linear solver setup function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* set the linear solver solve wrapper function */
  if (cvode_mem->cv_lsolve)
    retval = SUNNonlinSolSetLSolveFn(cvode_mem->NLSstg1, cvNlsLSolveSensStg1);
  else
    retval = SUNNonlinSolSetLSolveFn(cvode_mem->NLSstg1, NULL);

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInitSensStg1",
                   "Setting linear solver solve function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* initialize nonlinear solver */
  retval = SUNNonlinSolInitialize(cvode_mem->NLSstg1);

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInitSensStg1",
                   MSGCV_NLS_INIT_FAIL);
    return(CV_NLS_INIT_FAIL);
  }

  /* reset previous iteration count for updating nniS1 */
  cvode_mem->nnip = 0;

  return(CV_SUCCESS);
}
190 
191 
/*
 * cvNlsLSetupSensStg1
 *
 * Linear solver setup wrapper attached to the staggered-1 sensitivity
 * nonlinear solver. It forwards to cv_lsetup, bumps both setup counters
 * (whether or not the setup succeeded), reports the Jacobian status through
 * *jcur, and resets the gamma/convergence-rate bookkeeping.
 *
 * Returns CV_SUCCESS, CV_LSETUP_FAIL (unrecoverable setup failure),
 * SUN_NLS_CONV_RECVR (recoverable setup failure), or CV_MEM_NULL.
 */
static int cvNlsLSetupSensStg1(booleantype jbad, booleantype* jcur,
                               void* cvode_mem)
{
  CVodeMem mem = (CVodeMem) cvode_mem;
  int lsflag;

  if (mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsLSetupSensStg1", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }

  /* a Jacobian flagged as bad by the nonlinear solver overrides convfail */
  if (jbad) mem->convfail = CV_FAIL_BAD_J;

  lsflag = mem->cv_lsetup(mem, mem->convfail, mem->cv_y, mem->cv_ftemp,
                          &(mem->cv_jcur), mem->cv_vtemp1, mem->cv_vtemp2,
                          mem->cv_vtemp3);

  /* counters are updated regardless of the setup outcome */
  mem->cv_nsetups  += 1;
  mem->cv_nsetupsS += 1;

  /* hand the (possibly refreshed) Jacobian status back to the solver */
  *jcur = mem->cv_jcur;

  /* record the state at which this setup was performed */
  mem->cv_gamrat = ONE;
  mem->cv_crate  = ONE;
  mem->cv_crateS = ONE;
  mem->cv_gammap = mem->cv_gamma;
  mem->cv_nstlp  = mem->cv_nst;

  if (lsflag == 0) return(CV_SUCCESS);
  return((lsflag < 0) ? CV_LSETUP_FAIL : SUN_NLS_CONV_RECVR);
}
231 
232 
/*
 * cvNlsLSolveSensStg1
 *
 * Linear solver solve wrapper attached to the staggered-1 sensitivity
 * nonlinear solver. Solves in place on delta, using the error weights of the
 * sensitivity selected by sens_solve_idx.
 *
 * Returns CV_SUCCESS, CV_LSOLVE_FAIL (unrecoverable), SUN_NLS_CONV_RECVR
 * (recoverable), or CV_MEM_NULL.
 */
static int cvNlsLSolveSensStg1(N_Vector delta, void* cvode_mem)
{
  CVodeMem mem = (CVodeMem) cvode_mem;
  int flag;

  if (mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsLSolveSensStg1", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }

  /* weights come from the sensitivity currently being corrected */
  flag = mem->cv_lsolve(mem, delta, mem->cv_ewtS[mem->sens_solve_idx],
                        mem->cv_y, mem->cv_ftemp);

  if (flag < 0) return(CV_LSOLVE_FAIL);
  return((flag > 0) ? SUN_NLS_CONV_RECVR : CV_SUCCESS);
}
257 
258 
/*
 * cvNlsConvTestSensStg1
 *
 * Convergence test for the staggered-1 sensitivity nonlinear iteration.
 * It measures the weighted RMS norm of the latest correction delta. From the
 * second iteration on, it updates the sensitivity convergence-rate estimate
 * crateS. It then compares the rate-scaled norm against tol.
 *
 * Returns CV_SUCCESS when converged, SUN_NLS_CONV_RECVR when the iteration
 * appears to diverge, SUN_NLS_CONTINUE otherwise, and CV_MEM_NULL when
 * cvode_mem is NULL or the current iteration count cannot be retrieved.
 * ycor is not used.
 */
static int cvNlsConvTestSensStg1(SUNNonlinearSolver NLS,
                                 N_Vector ycor, N_Vector delta,
                                 realtype tol, N_Vector ewt, void* cvode_mem)
{
  CVodeMem mem = (CVodeMem) cvode_mem;
  int iter;
  realtype delnrm, dcon;

  if (mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsConvTestSensStg1", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }

  /* weighted RMS norm of the current sensitivity correction */
  delnrm = N_VWrmsNorm(delta, ewt);

  /* NOTE(review): a failure here is reported as CV_MEM_NULL */
  if (SUNNonlinSolGetCurIter(NLS, &iter) != CV_SUCCESS) return(CV_MEM_NULL);

  /* after the first iteration, refresh the rate estimate crateS from the
     ratio of successive correction norms */
  if (iter > 0)
    mem->cv_crateS = SUNMAX(CRDOWN * mem->cv_crateS, delnrm/mem->cv_delp);

  dcon = delnrm * SUNMIN(ONE, mem->cv_crateS) / tol;
  if (dcon <= ONE) return(CV_SUCCESS);

  /* growing corrections beyond RDIV signal divergence */
  if ((iter >= 1) && (delnrm > RDIV*mem->cv_delp)) return(SUN_NLS_CONV_RECVR);

  /* remember this norm for the next rate estimate and keep iterating */
  mem->cv_delp = delnrm;
  return(SUN_NLS_CONTINUE);
}
302 
303 
/*
 * cvNlsResidualSensStg1
 *
 * Nonlinear residual for the sensitivity selected by sens_solve_idx (used
 * with rootfinding solvers). With idx = sens_solve_idx it:
 *   1. sets yS[idx] = znS[0][idx] + ycor,
 *   2. evaluates the sensitivity right-hand side into ftempS[idx],
 *   3. forms res = rl1 * znS[1][idx] + ycor - gamma * ftempS[idx].
 *
 * Returns CV_SUCCESS, CV_SRHSFUNC_FAIL (unrecoverable RHS failure),
 * SRHSFUNC_RECVR (recoverable RHS failure), or CV_MEM_NULL.
 */
static int cvNlsResidualSensStg1(N_Vector ycor, N_Vector res, void* cvode_mem)
{
  CVodeMem mem = (CVodeMem) cvode_mem;
  int idx, flag;

  if (mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsResidualSensStg1", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }

  idx = mem->sens_solve_idx;

  /* yS[idx] = znS[0][idx] + ycor */
  N_VLinearSum(ONE, mem->cv_znS[0][idx], ONE, ycor, mem->cv_yS[idx]);

  flag = cvSensRhs1Wrapper(mem, mem->cv_tn, mem->cv_y, mem->cv_ftemp,
                           idx, mem->cv_yS[idx], mem->cv_ftempS[idx],
                           mem->cv_vtemp1, mem->cv_vtemp2);
  if (flag < 0) return(CV_SRHSFUNC_FAIL);
  if (flag > 0) return(SRHSFUNC_RECVR);

  /* sensitivity residual: res = rl1*znS[1][idx] + ycor - gamma*ftempS[idx] */
  N_VLinearSum(mem->cv_rl1, mem->cv_znS[1][idx], ONE, ycor, res);
  N_VLinearSum(-mem->cv_gamma, mem->cv_ftempS[idx], ONE, res, res);

  return(CV_SUCCESS);
}
337 
338 
/*
 * cvNlsFPFunctionSensStg1
 *
 * Fixed-point function for the sensitivity selected by sens_solve_idx (used
 * with fixed-point solvers). With idx = sens_solve_idx it:
 *   1. sets yS[idx] = znS[0][idx] + ycor,
 *   2. evaluates the sensitivity right-hand side directly into res,
 *   3. overwrites res with rl1 * (h * res - znS[1][idx]).
 *
 * Returns CV_SUCCESS, CV_SRHSFUNC_FAIL (unrecoverable RHS failure),
 * SRHSFUNC_RECVR (recoverable RHS failure), or CV_MEM_NULL.
 */
static int cvNlsFPFunctionSensStg1(N_Vector ycor, N_Vector res, void* cvode_mem)
{
  CVodeMem mem = (CVodeMem) cvode_mem;
  int idx, flag;

  if (mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsFPFunctionSensStg1", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }

  idx = mem->sens_solve_idx;

  /* yS[idx] = znS[0][idx] + ycor */
  N_VLinearSum(ONE, mem->cv_znS[0][idx], ONE, ycor, mem->cv_yS[idx]);

  /* the sensitivity RHS is written straight into res */
  flag = cvSensRhs1Wrapper(mem, mem->cv_tn, mem->cv_y, mem->cv_ftemp,
                           idx, mem->cv_yS[idx], res,
                           mem->cv_vtemp1, mem->cv_vtemp2);
  if (flag < 0) return(CV_SRHSFUNC_FAIL);
  if (flag > 0) return(SRHSFUNC_RECVR);

  /* res = rl1 * (h * res - znS[1][idx]), computed in two passes */
  N_VLinearSum(mem->cv_h, res, -ONE, mem->cv_znS[1][idx], res);
  N_VScale(mem->cv_rl1, res, res);

  return(CV_SUCCESS);
}
372