1 /* -----------------------------------------------------------------------------
2  * Programmer(s): David J. Gardner @ LLNL
3  * -----------------------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2021, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------------------
 * This is the implementation file for the CVODES nonlinear solver interface used by the staggered sensitivity corrector (CV_STAGGERED).
15  * ---------------------------------------------------------------------------*/
16 
17 #include "cvodes_impl.h"
18 #include "sundials/sundials_math.h"
19 #include "sundials/sundials_nvector_senswrapper.h"
20 
21 /* constant macros */
22 #define ONE RCONST(1.0)
23 
24 /* private functions */
25 static int cvNlsResidualSensStg(N_Vector ycorStg, N_Vector resStg,
26                                 void* cvode_mem);
27 static int cvNlsFPFunctionSensStg(N_Vector ycorStg, N_Vector resStg,
28                                   void* cvode_mem);
29 
30 static int cvNlsLSetupSensStg(booleantype jbad, booleantype* jcur,
31                               void* cvode_mem);
32 static int cvNlsLSolveSensStg(N_Vector deltaStg, void* cvode_mem);
33 static int cvNlsConvTestSensStg(SUNNonlinearSolver NLS,
34                                 N_Vector ycorStg, N_Vector delStg,
35                                 realtype tol, N_Vector ewtStg, void* cvode_mem);
36 
37 /* -----------------------------------------------------------------------------
38  * Exported functions
39  * ---------------------------------------------------------------------------*/
40 
/* -----------------------------------------------------------------------------
 * CVodeSetNonlinearSolverSensStg
 *
 * Attaches the SUNNonlinearSolver NLS to CVODES for solving the sensitivity
 * corrector equations with the staggered method (ism == CV_STAGGERED).
 *
 * Requirements on NLS:
 *   - it provides the gettype, solve and setsysfn operations;
 *   - its type is SUNNONLINEARSOLVER_ROOTFIND (a residual function is set) or
 *     SUNNONLINEARSOLVER_FIXEDPOINT (a fixed-point function is set).
 * Sensitivities must already be initialized.
 *
 * If a different solver is currently attached and CVODES owns it, that solver
 * is freed. The newly attached solver is marked as not owned by CVODES; when
 * this function attaches the default solver, the caller sets ownNLSstg
 * afterwards. Re-attaching the solver that is already attached leaves the
 * solver pointer and its ownership flag untouched.
 *
 * On the first successful call, the sensitivity vector wrappers zn0Stg,
 * ycorStg and ewtStg are created. Every call re-points them at znS[0], acorS
 * and ewtS.
 *
 * Returns:
 *   CV_SUCCESS   - solver attached and configured;
 *   CV_MEM_NULL  - cvode_mem is NULL;
 *   CV_ILL_INPUT - NLS is NULL or incomplete, sensitivities are off, the
 *                  method is not CV_STAGGERED, the solver type is unsupported,
 *                  or configuring the solver failed;
 *   CV_MEM_FAIL  - a vector wrapper could not be allocated.
 * ---------------------------------------------------------------------------*/
int CVodeSetNonlinearSolverSensStg(void *cvode_mem, SUNNonlinearSolver NLS)
{
  CVodeMem cv_mem;
  int retval, is;

  /* Return immediately if CVode memory is NULL */
  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "CVodeSetNonlinearSolverSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* Return immediately if NLS memory is NULL. cv_mem is valid here, so report
     through it (and thus the user's error handler) like every other error. */
  if (NLS == NULL) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "NLS must be non-NULL");
    return(CV_ILL_INPUT);
  }

  /* check for required nonlinear solver functions */
  if ( NLS->ops->gettype  == NULL ||
       NLS->ops->solve    == NULL ||
       NLS->ops->setsysfn == NULL ) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "NLS does not support required operations");
    return(CV_ILL_INPUT);
  }

  /* check that sensitivities were initialized */
  if (!(cv_mem->cv_sensi)) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   MSGCV_NO_SENSI);
    return(CV_ILL_INPUT);
  }

  /* check that staggered corrector was selected */
  if (cv_mem->cv_ism != CV_STAGGERED) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Sensitivity solution method is not CV_STAGGERED");
    return(CV_ILL_INPUT);
  }

  /* Replace the attached solver only if NLS is a different object. Freeing an
     owned solver and then attaching that same pointer would leave CVODES using
     freed memory. */
  if (cv_mem->NLSstg != NLS) {

    /* free any existing nonlinear solver owned by CVODES; a failure to free
       the old solver is not treated as fatal, so its return is discarded */
    if ((cv_mem->NLSstg != NULL) && (cv_mem->ownNLSstg))
      (void) SUNNonlinSolFree(cv_mem->NLSstg);

    /* set SUNNonlinearSolver pointer */
    cv_mem->NLSstg = NLS;

    /* Set NLS ownership flag. If this function was called to attach the
       default NLS, CVODE will set the flag to SUNTRUE after this function
       returns. */
    cv_mem->ownNLSstg = SUNFALSE;
  }

  /* set the nonlinear system function */
  if (SUNNonlinSolGetType(NLS) == SUNNONLINEARSOLVER_ROOTFIND) {
    retval = SUNNonlinSolSetSysFn(cv_mem->NLSstg, cvNlsResidualSensStg);
  } else if (SUNNonlinSolGetType(NLS) == SUNNONLINEARSOLVER_FIXEDPOINT) {
    retval = SUNNonlinSolSetSysFn(cv_mem->NLSstg, cvNlsFPFunctionSensStg);
  } else {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Invalid nonlinear solver type");
    return(CV_ILL_INPUT);
  }

  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Setting nonlinear system function failed");
    return(CV_ILL_INPUT);
  }

  /* set convergence test function */
  retval = SUNNonlinSolSetConvTestFn(cv_mem->NLSstg, cvNlsConvTestSensStg,
                                     cvode_mem);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Setting convergence test function failed");
    return(CV_ILL_INPUT);
  }

  /* set max allowed nonlinear iterations */
  retval = SUNNonlinSolSetMaxIters(cv_mem->NLSstg, NLS_MAXCOR);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Setting maximum number of nonlinear iterations failed");
    return(CV_ILL_INPUT);
  }

  /* create vector wrappers if necessary; on partial failure destroy what was
     created and clear the pointers so cv_mem holds no dangling vectors */
  if (cv_mem->stgMallocDone == SUNFALSE) {

    cv_mem->zn0Stg = N_VNewEmpty_SensWrapper(cv_mem->cv_Ns);
    if (cv_mem->zn0Stg == NULL) {
      cvProcessError(cv_mem, CV_MEM_FAIL, "CVODES",
                     "CVodeSetNonlinearSolverSensStg", MSGCV_MEM_FAIL);
      return(CV_MEM_FAIL);
    }

    cv_mem->ycorStg = N_VNewEmpty_SensWrapper(cv_mem->cv_Ns);
    if (cv_mem->ycorStg == NULL) {
      N_VDestroy(cv_mem->zn0Stg);
      cv_mem->zn0Stg = NULL;
      cvProcessError(cv_mem, CV_MEM_FAIL, "CVODES",
                     "CVodeSetNonlinearSolverSensStg", MSGCV_MEM_FAIL);
      return(CV_MEM_FAIL);
    }

    cv_mem->ewtStg = N_VNewEmpty_SensWrapper(cv_mem->cv_Ns);
    if (cv_mem->ewtStg == NULL) {
      N_VDestroy(cv_mem->zn0Stg);
      N_VDestroy(cv_mem->ycorStg);
      cv_mem->zn0Stg  = NULL;
      cv_mem->ycorStg = NULL;
      cvProcessError(cv_mem, CV_MEM_FAIL, "CVODES",
                     "CVodeSetNonlinearSolverSensStg", MSGCV_MEM_FAIL);
      return(CV_MEM_FAIL);
    }

    cv_mem->stgMallocDone = SUNTRUE;
  }

  /* attach vectors to vector wrappers */
  for (is = 0; is < cv_mem->cv_Ns; is++) {
    NV_VEC_SW(cv_mem->zn0Stg,  is) = cv_mem->cv_znS[0][is];
    NV_VEC_SW(cv_mem->ycorStg, is) = cv_mem->cv_acorS[is];
    NV_VEC_SW(cv_mem->ewtStg,  is) = cv_mem->cv_ewtS[is];
  }

  /* Reset the acnrmScur flag to SUNFALSE */
  cv_mem->cv_acnrmScur = SUNFALSE;

  return(CV_SUCCESS);
}
179 
180 
181 /* -----------------------------------------------------------------------------
182  * Private functions
183  * ---------------------------------------------------------------------------*/
184 
185 
/* -----------------------------------------------------------------------------
 * cvNlsInitSensStg
 *
 * Prepares the staggered-sensitivity nonlinear solver (NLSstg) for use:
 *   1. registers cvNlsLSetupSensStg as its linear-setup callback when a linear
 *      solver setup routine (cv_lsetup) is attached, otherwise clears it;
 *   2. does the same for cvNlsLSolveSensStg / cv_lsolve;
 *   3. calls SUNNonlinSolInitialize.
 * Any failure is reported through cvProcessError and yields CV_NLS_INIT_FAIL;
 * otherwise CV_SUCCESS is returned.
 * ---------------------------------------------------------------------------*/
int cvNlsInitSensStg(CVodeMem cvode_mem)
{
  int flag;

  /* hand the NLS a setup wrapper only if CVODES has a linear solver setup */
  flag = SUNNonlinSolSetLSetupFn(cvode_mem->NLSstg,
                                 (cvode_mem->cv_lsetup != NULL) ?
                                 cvNlsLSetupSensStg : NULL);
  if (flag != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_ILL_INPUT, "CVODES", "cvNlsInitSensStg",
                   "Setting the linear solver setup function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* likewise for the linear solve wrapper */
  flag = SUNNonlinSolSetLSolveFn(cvode_mem->NLSstg,
                                 (cvode_mem->cv_lsolve != NULL) ?
                                 cvNlsLSolveSensStg : NULL);
  if (flag != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_ILL_INPUT, "CVODES", "cvNlsInitSensStg",
                   "Setting linear solver solve function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* finally let the nonlinear solver initialize itself */
  flag = SUNNonlinSolInitialize(cvode_mem->NLSstg);
  if (flag != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_ILL_INPUT, "CVODES", "cvNlsInitSensStg",
                   MSGCV_NLS_INIT_FAIL);
    return(CV_NLS_INIT_FAIL);
  }

  return(CV_SUCCESS);
}
225 
226 
/* -----------------------------------------------------------------------------
 * cvNlsLSetupSensStg
 *
 * Linear-setup callback for the staggered-sensitivity nonlinear solver.
 *
 *   jbad      - set by the nonlinear solver when it considers the Jacobian
 *               bad; forces convfail = CV_FAIL_BAD_J before the setup call.
 *   jcur      - output: receives cv_jcur as updated by cv_lsetup.
 *   cvode_mem - the CVODES memory block.
 *
 * Calls cv_lsetup with (convfail, y, ftemp, &jcur, vtemp1..3), increments both
 * nsetups and nsetupsS, and resets gamrat, gammap, crate, crateS and nstlp.
 *
 * Returns CV_MEM_NULL if cvode_mem is NULL, CV_LSETUP_FAIL if cv_lsetup
 * returned a negative value, SUN_NLS_CONV_RECVR if it returned a positive
 * value, and CV_SUCCESS otherwise.
 * ---------------------------------------------------------------------------*/
static int cvNlsLSetupSensStg(booleantype jbad, booleantype* jcur,
                              void* cvode_mem)
{
  CVodeMem cv_mem;
  int ls_flag;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsLSetupSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* a Jacobian flagged as bad by the NLS overrides the stored failure mode */
  if (jbad) cv_mem->convfail = CV_FAIL_BAD_J;

  ls_flag = cv_mem->cv_lsetup(cv_mem, cv_mem->convfail,
                              cv_mem->cv_y, cv_mem->cv_ftemp,
                              &(cv_mem->cv_jcur),
                              cv_mem->cv_vtemp1, cv_mem->cv_vtemp2,
                              cv_mem->cv_vtemp3);

  /* the attempt counts toward both total and sensitivity setup statistics,
     whether or not it succeeded */
  cv_mem->cv_nsetups++;
  cv_mem->cv_nsetupsS++;

  /* report the (possibly refreshed) Jacobian status back to the NLS */
  *jcur = cv_mem->cv_jcur;

  /* record the state at this setup */
  cv_mem->cv_gamrat = ONE;
  cv_mem->cv_gammap = cv_mem->cv_gamma;
  cv_mem->cv_crate  = ONE;
  cv_mem->cv_crateS = ONE;
  cv_mem->cv_nstlp  = cv_mem->cv_nst;

  if (ls_flag == 0) return(CV_SUCCESS);
  return (ls_flag < 0) ? CV_LSETUP_FAIL : SUN_NLS_CONV_RECVR;
}
266 
267 
/* -----------------------------------------------------------------------------
 * cvNlsLSolveSensStg
 *
 * Linear-solve callback for the staggered-sensitivity nonlinear solver.
 * deltaStg is a sensitivity vector wrapper. Each wrapped vector deltaS[is] is
 * passed in place to cv_lsolve, together with the matching weight vector
 * ewtS[is] and the current y / ftemp. Processing stops at the first nonzero
 * return.
 *
 * Returns CV_MEM_NULL if cvode_mem is NULL, CV_LSOLVE_FAIL on a negative
 * cv_lsolve return, SUN_NLS_CONV_RECVR on a positive one, and CV_SUCCESS when
 * every sensitivity system was solved.
 * ---------------------------------------------------------------------------*/
static int cvNlsLSolveSensStg(N_Vector deltaStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *deltaS;
  int is, ls_flag;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsLSolveSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* unwrap the per-sensitivity right-hand sides / solutions */
  deltaS = NV_VECS_SW(deltaStg);

  for (is = 0; is < cv_mem->cv_Ns; is++) {
    ls_flag = cv_mem->cv_lsolve(cv_mem, deltaS[is], cv_mem->cv_ewtS[is],
                                cv_mem->cv_y, cv_mem->cv_ftemp);
    if (ls_flag != 0)
      return (ls_flag < 0) ? CV_LSOLVE_FAIL : SUN_NLS_CONV_RECVR;
  }

  return(CV_SUCCESS);
}
295 
296 
/* -----------------------------------------------------------------------------
 * cvNlsConvTestSensStg
 *
 * Convergence test for the staggered-sensitivity nonlinear iteration.
 *
 * Del is the weighted norm (cvSensNorm) of the sensitivity deltas only. From
 * the second iteration on (m > 0), the rate estimate is updated as
 *   crateS = max(CRDOWN * crateS, Del / delp),
 * and the test quantity is dcon = Del * min(1, crateS) / tol.
 *
 * All sensitivity components enter this test regardless of errconS. acnrmS,
 * which is used later by the error test, is stored only when errconS is set.
 *
 * Returns:
 *   CV_SUCCESS         - converged (dcon <= 1);
 *   SUN_NLS_CONV_RECVR - apparent divergence (m >= 1 and Del > RDIV * delp);
 *   SUN_NLS_CONTINUE   - not yet converged (delp is updated to Del);
 *   CV_MEM_NULL        - cvode_mem is NULL or the iteration count could not
 *                        be queried.
 * ---------------------------------------------------------------------------*/
static int cvNlsConvTestSensStg(SUNNonlinearSolver NLS,
                                N_Vector ycorStg, N_Vector deltaStg,
                                realtype tol, N_Vector ewtStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *ycorS, *deltaS, *ewtS;
  realtype delnrm, dcon;
  int iter, flag;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsConvTestSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* unwrap the current corrections, the latest deltas and the error weights */
  ycorS  = NV_VECS_SW(ycorStg);
  deltaS = NV_VECS_SW(deltaStg);
  ewtS   = NV_VECS_SW(ewtStg);

  /* weighted norm of the sensitivity deltas */
  delnrm = cvSensNorm(cv_mem, deltaS, ewtS);

  flag = SUNNonlinSolGetCurIter(NLS, &iter);
  if (flag != CV_SUCCESS) return(CV_MEM_NULL);

  /* a rate estimate needs a previous delta norm, hence only after iter 0 */
  if (iter > 0)
    cv_mem->cv_crateS = SUNMAX(CRDOWN * cv_mem->cv_crateS,
                               delnrm/cv_mem->cv_delp);

  dcon = delnrm * SUNMIN(ONE, cv_mem->cv_crateS) / tol;

  if (dcon > ONE) {
    /* growing corrections: let the integrator recover (e.g. reduce h) */
    if ((iter >= 1) && (delnrm > RDIV*cv_mem->cv_delp))
      return(SUN_NLS_CONV_RECVR);

    /* remember this norm for the next rate estimate and keep iterating */
    cv_mem->cv_delp = delnrm;
    return(SUN_NLS_CONTINUE);
  }

  /* converged: on the first iteration the correction norm equals delnrm */
  if (cv_mem->cv_errconS) {
    cv_mem->cv_acnrmS = (iter == 0) ? delnrm : cvSensNorm(cv_mem, ycorS, ewtS);
    cv_mem->cv_acnrmScur = SUNTRUE;
  }
  return(CV_SUCCESS);
}
361 
362 
/* -----------------------------------------------------------------------------
 * cvNlsResidualSensStg
 *
 * Nonlinear residual for rootfinding-type solvers (staggered sensitivities).
 * For each sensitivity i:
 *   yS_i  = znS[0]_i + ycor_i
 *   fS_i  = sensitivity right-hand side at (tn, y, yS), stored in ftempS
 *   res_i = rl1 * znS[1]_i + ycor_i - gamma * fS_i
 *
 * Returns CV_MEM_NULL if cvode_mem is NULL, CV_VECTOROP_ERR if a fused vector
 * operation fails, CV_SRHSFUNC_FAIL / SRHSFUNC_RECVR for an unrecoverable /
 * recoverable sensitivity RHS failure, and CV_SUCCESS otherwise.
 * ---------------------------------------------------------------------------*/
static int cvNlsResidualSensStg(N_Vector ycorStg, N_Vector resStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *ycorS, *resS;
  N_Vector* vecs[3];
  realtype coefs[3];
  int flag;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsResidualSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* unwrap the per-sensitivity corrections and residual outputs */
  ycorS = NV_VECS_SW(ycorStg);
  resS  = NV_VECS_SW(resStg);

  /* yS = znS[0] + ycorS */
  flag = N_VLinearSumVectorArray(cv_mem->cv_Ns,
                                 ONE, cv_mem->cv_znS[0],
                                 ONE, ycorS, cv_mem->cv_yS);
  if (flag != CV_SUCCESS) return(CV_VECTOROP_ERR);

  /* ftempS = sensitivity RHS evaluated at the corrected sensitivities */
  flag = cvSensRhsWrapper(cv_mem, cv_mem->cv_tn,
                          cv_mem->cv_y, cv_mem->cv_ftemp,
                          cv_mem->cv_yS, cv_mem->cv_ftempS,
                          cv_mem->cv_vtemp1, cv_mem->cv_vtemp2);
  if (flag < 0) return(CV_SRHSFUNC_FAIL);
  if (flag > 0) return(SRHSFUNC_RECVR);

  /* resS = rl1 * znS[1] + ycorS - gamma * ftempS */
  coefs[0] = cv_mem->cv_rl1;
  coefs[1] = ONE;
  coefs[2] = -cv_mem->cv_gamma;
  vecs[0]  = cv_mem->cv_znS[1];
  vecs[1]  = ycorS;
  vecs[2]  = cv_mem->cv_ftempS;

  flag = N_VLinearCombinationVectorArray(cv_mem->cv_Ns, 3, coefs, vecs, resS);
  return (flag == CV_SUCCESS) ? CV_SUCCESS : CV_VECTOROP_ERR;
}
408 
409 
/* -----------------------------------------------------------------------------
 * cvNlsFPFunctionSensStg
 *
 * Fixed-point map for fixed-point-type solvers (staggered sensitivities).
 * For each sensitivity i:
 *   yS_i  = znS[0]_i + ycor_i
 *   res_i = sensitivity right-hand side at (tn, y, yS)
 *   res_i = rl1 * (h * res_i - znS[1]_i)
 * The last update is done in two steps (linear sum, then scale), in place.
 *
 * Returns CV_MEM_NULL if cvode_mem is NULL, CV_VECTOROP_ERR if the fused
 * linear sum fails, CV_SRHSFUNC_FAIL / SRHSFUNC_RECVR for an unrecoverable /
 * recoverable sensitivity RHS failure, and CV_SUCCESS otherwise.
 * ---------------------------------------------------------------------------*/
static int cvNlsFPFunctionSensStg(N_Vector ycorStg, N_Vector resStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *ycorS, *resS;
  N_Vector g;
  int flag, is;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsFPFunctionSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* unwrap the per-sensitivity corrections and map outputs */
  ycorS = NV_VECS_SW(ycorStg);
  resS  = NV_VECS_SW(resStg);

  /* yS = znS[0] + ycorS */
  flag = N_VLinearSumVectorArray(cv_mem->cv_Ns,
                                 ONE, cv_mem->cv_znS[0],
                                 ONE, ycorS, cv_mem->cv_yS);
  if (flag != CV_SUCCESS) return(CV_VECTOROP_ERR);

  /* resS = sensitivity RHS at the corrected sensitivities */
  flag = cvSensRhsWrapper(cv_mem, cv_mem->cv_tn,
                          cv_mem->cv_y, cv_mem->cv_ftemp,
                          cv_mem->cv_yS, resS,
                          cv_mem->cv_vtemp1, cv_mem->cv_vtemp2);
  if (flag < 0) return(CV_SRHSFUNC_FAIL);
  if (flag > 0) return(SRHSFUNC_RECVR);

  /* g = rl1 * (h * g - znS[1]) for every sensitivity, updated in place */
  for (is = 0; is < cv_mem->cv_Ns; is++) {
    g = resS[is];
    N_VLinearSum(cv_mem->cv_h, g, -ONE, cv_mem->cv_znS[1][is], g);
    N_VScale(cv_mem->cv_rl1, g, g);
  }

  return(CV_SUCCESS);
}
450