1 /* -----------------------------------------------------------------------------
2 * Programmer(s): David J. Gardner @ LLNL
3 * -----------------------------------------------------------------------------
4 * SUNDIALS Copyright Start
5 * Copyright (c) 2002-2021, Lawrence Livermore National Security
6 * and Southern Methodist University.
7 * All rights reserved.
8 *
9 * See the top-level LICENSE and NOTICE files for details.
10 *
11 * SPDX-License-Identifier: BSD-3-Clause
12 * SUNDIALS Copyright End
13 * -----------------------------------------------------------------------------
 * This is the implementation file for the CVODES staggered sensitivity
 * nonlinear solver interface.
15 * ---------------------------------------------------------------------------*/
16
17 #include "cvodes_impl.h"
18 #include "sundials/sundials_math.h"
19 #include "sundials/sundials_nvector_senswrapper.h"
20
21 /* constant macros */
22 #define ONE RCONST(1.0)
23
24 /* private functions */
25 static int cvNlsResidualSensStg(N_Vector ycorStg, N_Vector resStg,
26 void* cvode_mem);
27 static int cvNlsFPFunctionSensStg(N_Vector ycorStg, N_Vector resStg,
28 void* cvode_mem);
29
30 static int cvNlsLSetupSensStg(booleantype jbad, booleantype* jcur,
31 void* cvode_mem);
32 static int cvNlsLSolveSensStg(N_Vector deltaStg, void* cvode_mem);
33 static int cvNlsConvTestSensStg(SUNNonlinearSolver NLS,
34 N_Vector ycorStg, N_Vector delStg,
35 realtype tol, N_Vector ewtStg, void* cvode_mem);
36
37 /* -----------------------------------------------------------------------------
38 * Exported functions
39 * ---------------------------------------------------------------------------*/
40
/* -----------------------------------------------------------------------------
 * CVodeSetNonlinearSolverSensStg
 *
 * Attaches NLS as the nonlinear solver used by the staggered sensitivity
 * corrector (cv_mem->NLSstg).
 *
 * The function checks that:
 *   - cvode_mem and NLS are non-NULL;
 *   - NLS implements gettype, solve and setsysfn;
 *   - sensitivities are enabled;
 *   - the sensitivity method is CV_STAGGERED;
 *   - NLS is a root-finding or fixed-point solver.
 *
 * On success it does the following:
 *   - frees any previously attached solver that CVODES owned;
 *   - installs the matching system function, the convergence test and
 *     NLS_MAXCOR as the iteration limit;
 *   - allocates the sensitivity vector wrappers (zn0Stg, ycorStg, ewtStg) on
 *     first use and points them at znS[0], acorS and ewtS.
 *
 * The caller keeps ownership of NLS (ownNLSstg is set to SUNFALSE).
 *
 * Returns CV_SUCCESS, CV_MEM_NULL, CV_ILL_INPUT or CV_MEM_FAIL.
 * ---------------------------------------------------------------------------*/
int CVodeSetNonlinearSolverSensStg(void *cvode_mem, SUNNonlinearSolver NLS)
{
  CVodeMem cv_mem;
  int retval, is;

  /* Return immediately if CVode memory is NULL */
  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "CVodeSetNonlinearSolverSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* Return immediately if NLS memory is NULL. cv_mem is valid here, so report
     through it so that a user-installed error handler is honored. */
  if (NLS == NULL) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "NLS must be non-NULL");
    return(CV_ILL_INPUT);
  }

  /* check for required nonlinear solver functions */
  if ( NLS->ops->gettype  == NULL ||
       NLS->ops->solve    == NULL ||
       NLS->ops->setsysfn == NULL ) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "NLS does not support required operations");
    return(CV_ILL_INPUT);
  }

  /* check that sensitivities were initialized */
  if (!(cv_mem->cv_sensi)) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   MSGCV_NO_SENSI);
    return(CV_ILL_INPUT);
  }

  /* check that staggered corrector was selected */
  if (cv_mem->cv_ism != CV_STAGGERED) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Sensitivity solution method is not CV_STAGGERED");
    return(CV_ILL_INPUT);
  }

  /* Validate the solver type BEFORE touching the currently attached solver.
     An unsupported NLS is then rejected without freeing a working,
     CVODES-owned solver and without leaving NLSstg pointing at the rejected
     object. */
  if ((SUNNonlinSolGetType(NLS) != SUNNONLINEARSOLVER_ROOTFIND) &&
      (SUNNonlinSolGetType(NLS) != SUNNONLINEARSOLVER_FIXEDPOINT)) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Invalid nonlinear solver type");
    return(CV_ILL_INPUT);
  }

  /* free any existing nonlinear solver (only when CVODES owns it); the
     return value is deliberately ignored */
  if ((cv_mem->NLSstg != NULL) && (cv_mem->ownNLSstg))
    (void) SUNNonlinSolFree(cv_mem->NLSstg);

  /* set SUNNonlinearSolver pointer */
  cv_mem->NLSstg = NLS;

  /* Set NLS ownership flag. If this function was called to attach the default
     NLS, CVODE will set the flag to SUNTRUE after this function returns. */
  cv_mem->ownNLSstg = SUNFALSE;

  /* set the nonlinear system function matching the (validated) solver type */
  if (SUNNonlinSolGetType(NLS) == SUNNONLINEARSOLVER_ROOTFIND)
    retval = SUNNonlinSolSetSysFn(cv_mem->NLSstg, cvNlsResidualSensStg);
  else
    retval = SUNNonlinSolSetSysFn(cv_mem->NLSstg, cvNlsFPFunctionSensStg);

  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Setting nonlinear system function failed");
    return(CV_ILL_INPUT);
  }

  /* set convergence test function */
  retval = SUNNonlinSolSetConvTestFn(cv_mem->NLSstg, cvNlsConvTestSensStg,
                                     cvode_mem);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Setting convergence test function failed");
    return(CV_ILL_INPUT);
  }

  /* set max allowed nonlinear iterations */
  retval = SUNNonlinSolSetMaxIters(cv_mem->NLSstg, NLS_MAXCOR);
  if (retval != CV_SUCCESS) {
    cvProcessError(cv_mem, CV_ILL_INPUT, "CVODES",
                   "CVodeSetNonlinearSolverSensStg",
                   "Setting maximum number of nonlinear iterations failed");
    return(CV_ILL_INPUT);
  }

  /* create vector wrappers if necessary; on partial failure, destroy what was
     created and reset the pointers so cv_mem never holds dangling wrappers */
  if (cv_mem->stgMallocDone == SUNFALSE) {

    cv_mem->zn0Stg = N_VNewEmpty_SensWrapper(cv_mem->cv_Ns);
    if (cv_mem->zn0Stg == NULL) {
      cvProcessError(cv_mem, CV_MEM_FAIL, "CVODES",
                     "CVodeSetNonlinearSolverSensStg", MSGCV_MEM_FAIL);
      return(CV_MEM_FAIL);
    }

    cv_mem->ycorStg = N_VNewEmpty_SensWrapper(cv_mem->cv_Ns);
    if (cv_mem->ycorStg == NULL) {
      N_VDestroy(cv_mem->zn0Stg);
      cv_mem->zn0Stg = NULL;
      cvProcessError(cv_mem, CV_MEM_FAIL, "CVODES",
                     "CVodeSetNonlinearSolverSensStg", MSGCV_MEM_FAIL);
      return(CV_MEM_FAIL);
    }

    cv_mem->ewtStg = N_VNewEmpty_SensWrapper(cv_mem->cv_Ns);
    if (cv_mem->ewtStg == NULL) {
      N_VDestroy(cv_mem->zn0Stg);
      cv_mem->zn0Stg = NULL;
      N_VDestroy(cv_mem->ycorStg);
      cv_mem->ycorStg = NULL;
      cvProcessError(cv_mem, CV_MEM_FAIL, "CVODES",
                     "CVodeSetNonlinearSolverSensStg", MSGCV_MEM_FAIL);
      return(CV_MEM_FAIL);
    }

    cv_mem->stgMallocDone = SUNTRUE;
  }

  /* attach vectors to vector wrappers */
  for (is = 0; is < cv_mem->cv_Ns; is++) {
    NV_VEC_SW(cv_mem->zn0Stg,  is) = cv_mem->cv_znS[0][is];
    NV_VEC_SW(cv_mem->ycorStg, is) = cv_mem->cv_acorS[is];
    NV_VEC_SW(cv_mem->ewtStg,  is) = cv_mem->cv_ewtS[is];
  }

  /* Reset the acnrmScur flag to SUNFALSE */
  cv_mem->cv_acnrmScur = SUNFALSE;

  return(CV_SUCCESS);
}
179
180
181 /* -----------------------------------------------------------------------------
182 * Private functions
183 * ---------------------------------------------------------------------------*/
184
185
/* -----------------------------------------------------------------------------
 * cvNlsInitSensStg
 *
 * Prepares cvode_mem->NLSstg for use:
 *   - installs cvNlsLSetupSensStg / cvNlsLSolveSensStg as the linear
 *     setup/solve callbacks, or NULL when the integrator has no cv_lsetup /
 *     cv_lsolve;
 *   - calls SUNNonlinSolInitialize.
 *
 * Every failure is reported to the error handler with the same code that is
 * returned (CV_NLS_INIT_FAIL), so handler and caller see a consistent flag.
 *
 * Returns CV_SUCCESS or CV_NLS_INIT_FAIL.
 * ---------------------------------------------------------------------------*/
int cvNlsInitSensStg(CVodeMem cvode_mem)
{
  int retval;

  /* set the linear solver setup wrapper function */
  if (cvode_mem->cv_lsetup)
    retval = SUNNonlinSolSetLSetupFn(cvode_mem->NLSstg, cvNlsLSetupSensStg);
  else
    retval = SUNNonlinSolSetLSetupFn(cvode_mem->NLSstg, NULL);

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInitSensStg",
                   "Setting the linear solver setup function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* set the linear solver solve wrapper function */
  if (cvode_mem->cv_lsolve)
    retval = SUNNonlinSolSetLSolveFn(cvode_mem->NLSstg, cvNlsLSolveSensStg);
  else
    retval = SUNNonlinSolSetLSolveFn(cvode_mem->NLSstg, NULL);

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInitSensStg",
                   "Setting linear solver solve function failed");
    return(CV_NLS_INIT_FAIL);
  }

  /* initialize nonlinear solver */
  retval = SUNNonlinSolInitialize(cvode_mem->NLSstg);

  if (retval != CV_SUCCESS) {
    cvProcessError(cvode_mem, CV_NLS_INIT_FAIL, "CVODES", "cvNlsInitSensStg",
                   MSGCV_NLS_INIT_FAIL);
    return(CV_NLS_INIT_FAIL);
  }

  return(CV_SUCCESS);
}
225
226
/* -----------------------------------------------------------------------------
 * cvNlsLSetupSensStg
 *
 * Linear-setup callback given to the staggered sensitivity nonlinear solver.
 *
 * If jbad is set, convfail is forced to CV_FAIL_BAD_J before cv_lsetup is
 * called. After the call, the function:
 *   - increments both cv_nsetups and cv_nsetupsS;
 *   - copies cv_jcur into *jcur;
 *   - records gammap/nstlp;
 *   - resets gamrat, crate and crateS to ONE.
 *
 * Returns CV_SUCCESS, CV_LSETUP_FAIL (cv_lsetup < 0) or SUN_NLS_CONV_RECVR
 * (cv_lsetup > 0); CV_MEM_NULL if cvode_mem is NULL.
 * ---------------------------------------------------------------------------*/
static int cvNlsLSetupSensStg(booleantype jbad, booleantype* jcur,
                              void* cvode_mem)
{
  CVodeMem cv_mem;
  int ier;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsLSetupSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* a Jacobian flagged as bad by the nonlinear solver overrides convfail */
  if (jbad) cv_mem->convfail = CV_FAIL_BAD_J;

  ier = cv_mem->cv_lsetup(cv_mem, cv_mem->convfail, cv_mem->cv_y,
                          cv_mem->cv_ftemp, &(cv_mem->cv_jcur),
                          cv_mem->cv_vtemp1, cv_mem->cv_vtemp2,
                          cv_mem->cv_vtemp3);

  /* the setup is counted in both the total and the sensitivity counters,
     whether or not it succeeded */
  cv_mem->cv_nsetupsS++;
  cv_mem->cv_nsetups++;

  /* tell the nonlinear solver whether the Jacobian is current */
  *jcur = cv_mem->cv_jcur;

  /* remember where/with which gamma this setup happened; reset rate data */
  cv_mem->cv_nstlp  = cv_mem->cv_nst;
  cv_mem->cv_gammap = cv_mem->cv_gamma;
  cv_mem->cv_gamrat = ONE;
  cv_mem->cv_crateS = ONE;
  cv_mem->cv_crate  = ONE;

  if (ier == 0) return(CV_SUCCESS);
  return (ier < 0) ? CV_LSETUP_FAIL : SUN_NLS_CONV_RECVR;
}
266
267
/* -----------------------------------------------------------------------------
 * cvNlsLSolveSensStg
 *
 * Linear-solve callback given to the staggered sensitivity nonlinear solver.
 * Solves one linear system per sensitivity, in place on the vectors held by
 * the deltaStg wrapper, using cv_ewtS[is] as the weight vector.
 *
 * Stops at the first failing solve and returns CV_LSOLVE_FAIL (< 0) or
 * SUN_NLS_CONV_RECVR (> 0); CV_SUCCESS if all solves succeed; CV_MEM_NULL if
 * cvode_mem is NULL.
 * ---------------------------------------------------------------------------*/
static int cvNlsLSolveSensStg(N_Vector deltaStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *dS;
  int k, ier;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsLSolveSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* per-sensitivity right-hand sides / solutions inside the wrapper */
  dS = NV_VECS_SW(deltaStg);

  for (k = 0; k < cv_mem->cv_Ns; k++) {
    ier = cv_mem->cv_lsolve(cv_mem, dS[k], cv_mem->cv_ewtS[k],
                            cv_mem->cv_y, cv_mem->cv_ftemp);
    if (ier != 0)
      return (ier < 0) ? CV_LSOLVE_FAIL : SUN_NLS_CONV_RECVR;
  }

  return(CV_SUCCESS);
}
295
296
/* -----------------------------------------------------------------------------
 * cvNlsConvTestSensStg
 *
 * Convergence test for the staggered sensitivity nonlinear solver.
 *
 * Del is the sensitivity norm (cvSensNorm) of the deltas weighted by ewtS.
 * For iterations m > 0, the rate estimate crateS is updated as
 *   crateS = max(CRDOWN * crateS, Del / delp),
 * and the test quantity is
 *   dcon = Del * min(1, crateS) / tol.
 *
 * Outcomes:
 *   - dcon <= 1: converged. If errconS is set, acnrmS (Del when m == 0,
 *     otherwise the norm of the accumulated corrections) is stored and
 *     acnrmScur is set. Returns CV_SUCCESS.
 *   - m >= 1 and Del > RDIV * delp: diverging. Returns SUN_NLS_CONV_RECVR.
 *   - otherwise: delp = Del. Returns SUN_NLS_CONTINUE.
 *
 * Note that all sensitivities always take part in the convergence test; the
 * acnrmS value is only recorded when errconS is set because it is used by the
 * error test.
 *
 * Returns CV_MEM_NULL if cvode_mem is NULL or the iteration count cannot be
 * queried.
 * ---------------------------------------------------------------------------*/
static int cvNlsConvTestSensStg(SUNNonlinearSolver NLS,
                                N_Vector ycorStg, N_Vector deltaStg,
                                realtype tol, N_Vector ewtStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *corr, *dlt, *wts;
  realtype Del, dcon;
  int m;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsConvTestSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* sensitivity components of the wrapped corrections, deltas and weights */
  corr = NV_VECS_SW(ycorStg);
  dlt  = NV_VECS_SW(deltaStg);
  wts  = NV_VECS_SW(ewtStg);

  /* weighted norm of the sensitivity deltas */
  Del = cvSensNorm(cv_mem, dlt, wts);

  if (SUNNonlinSolGetCurIter(NLS, &m) != CV_SUCCESS) return(CV_MEM_NULL);

  /* refine the convergence-rate estimate once a previous delta exists */
  if (m > 0)
    cv_mem->cv_crateS = SUNMAX(CRDOWN * cv_mem->cv_crateS, Del/cv_mem->cv_delp);

  dcon = Del * SUNMIN(ONE, cv_mem->cv_crateS) / tol;

  if (dcon > ONE) {
    /* not converged: give up if diverging, else remember Del and continue */
    if ((m >= 1) && (Del > RDIV*cv_mem->cv_delp)) return(SUN_NLS_CONV_RECVR);
    cv_mem->cv_delp = Del;
    return(SUN_NLS_CONTINUE);
  }

  /* converged: record the correction norm needed by the error test */
  if (cv_mem->cv_errconS) {
    if (m == 0)
      cv_mem->cv_acnrmS = Del;
    else
      cv_mem->cv_acnrmS = cvSensNorm(cv_mem, corr, wts);
    cv_mem->cv_acnrmScur = SUNTRUE;
  }

  return(CV_SUCCESS);
}
361
362
/* -----------------------------------------------------------------------------
 * cvNlsResidualSensStg
 *
 * Root-finding system function for the staggered sensitivity corrector.
 *
 * Given the sensitivity corrections ycorS (inside ycorStg), it:
 *   1. forms yS = znS[0] + ycorS;
 *   2. evaluates the sensitivity right-hand side into ftempS;
 *   3. writes into resStg
 *        resS = rl1 * znS[1] + ycorS - gamma * ftempS.
 *
 * Returns CV_SUCCESS, CV_VECTOROP_ERR (vector-array operation failure),
 * CV_SRHSFUNC_FAIL (rhs < 0), SRHSFUNC_RECVR (rhs > 0), or CV_MEM_NULL if
 * cvode_mem is NULL.
 * ---------------------------------------------------------------------------*/
static int cvNlsResidualSensStg(N_Vector ycorStg, N_Vector resStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *corr, *res;
  N_Vector* vecs[3];
  realtype coeffs[3];
  int ier;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsResidualSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* unwrap the per-sensitivity vectors */
  corr = NV_VECS_SW(ycorStg);
  res  = NV_VECS_SW(resStg);

  /* current sensitivity iterate: yS = znS[0] + ycorS */
  ier = N_VLinearSumVectorArray(cv_mem->cv_Ns,
                                ONE, cv_mem->cv_znS[0],
                                ONE, corr, cv_mem->cv_yS);
  if (ier != CV_SUCCESS) return(CV_VECTOROP_ERR);

  /* sensitivity right-hand side at (tn, y, yS) into ftempS */
  ier = cvSensRhsWrapper(cv_mem, cv_mem->cv_tn,
                         cv_mem->cv_y, cv_mem->cv_ftemp,
                         cv_mem->cv_yS, cv_mem->cv_ftempS,
                         cv_mem->cv_vtemp1, cv_mem->cv_vtemp2);
  if (ier != 0) return (ier < 0) ? CV_SRHSFUNC_FAIL : SRHSFUNC_RECVR;

  /* sensitivity residual: rl1*znS[1] + ycorS - gamma*ftempS */
  coeffs[0] = cv_mem->cv_rl1;
  vecs[0]   = cv_mem->cv_znS[1];
  coeffs[1] = ONE;
  vecs[1]   = corr;
  coeffs[2] = -cv_mem->cv_gamma;
  vecs[2]   = cv_mem->cv_ftempS;

  ier = N_VLinearCombinationVectorArray(cv_mem->cv_Ns, 3, coeffs, vecs, res);

  return (ier != CV_SUCCESS) ? CV_VECTOROP_ERR : CV_SUCCESS;
}
408
409
/* -----------------------------------------------------------------------------
 * cvNlsFPFunctionSensStg
 *
 * Fixed-point system function for the staggered sensitivity corrector.
 *
 * Given the sensitivity corrections ycorS (inside ycorStg), it:
 *   1. forms yS = znS[0] + ycorS;
 *   2. evaluates the sensitivity right-hand side directly into resS;
 *   3. overwrites each resS[k] with
 *        rl1 * (h * resS[k] - znS[1][k]).
 *
 * Returns CV_SUCCESS, CV_VECTOROP_ERR, CV_SRHSFUNC_FAIL (rhs < 0),
 * SRHSFUNC_RECVR (rhs > 0), or CV_MEM_NULL if cvode_mem is NULL.
 * ---------------------------------------------------------------------------*/
static int cvNlsFPFunctionSensStg(N_Vector ycorStg, N_Vector resStg, void* cvode_mem)
{
  CVodeMem cv_mem;
  N_Vector *corr, *res;
  int ier, k;

  if (cvode_mem == NULL) {
    cvProcessError(NULL, CV_MEM_NULL, "CVODES",
                   "cvNlsFPFunctionSensStg", MSGCV_NO_MEM);
    return(CV_MEM_NULL);
  }
  cv_mem = (CVodeMem) cvode_mem;

  /* unwrap the per-sensitivity vectors */
  corr = NV_VECS_SW(ycorStg);
  res  = NV_VECS_SW(resStg);

  /* current sensitivity iterate: yS = znS[0] + ycorS */
  ier = N_VLinearSumVectorArray(cv_mem->cv_Ns,
                                ONE, cv_mem->cv_znS[0],
                                ONE, corr, cv_mem->cv_yS);
  if (ier != CV_SUCCESS) return(CV_VECTOROP_ERR);

  /* sensitivity right-hand side at (tn, y, yS), written straight into resS */
  ier = cvSensRhsWrapper(cv_mem, cv_mem->cv_tn,
                         cv_mem->cv_y, cv_mem->cv_ftemp,
                         cv_mem->cv_yS, res,
                         cv_mem->cv_vtemp1, cv_mem->cv_vtemp2);
  if (ier != 0) return (ier < 0) ? CV_SRHSFUNC_FAIL : SRHSFUNC_RECVR;

  /* fixed-point map, in place: resS[k] = rl1 * (h*resS[k] - znS[1][k]) */
  for (k = 0; k < cv_mem->cv_Ns; k++) {
    N_Vector r = res[k];
    N_VLinearSum(cv_mem->cv_h, r, -ONE, cv_mem->cv_znS[1][k], r);
    N_VScale(cv_mem->cv_rl1, r, r);
  }

  return(CV_SUCCESS);
}
450