# --------------------------------------------------------------------
"""Unit tests for the petsc4py ``PETSc.SNES`` nonlinear solver wrapper.

The solve tests use the 2x2 nonlinear system

    x0*x0 + x0*x1 - 3 = 0
    x0*x1 + x1*x1 - 6 = 0

whose roots are +/-(1, 2); the solutions are compared via ``abs()``
against (1, 2).  The remaining tests exercise the getter/setter API:
tolerances, properties, user callbacks (and the reference counts the
wrapper holds on them), monitors, failure counters, Eisenstat-Walker
parameters, and the matrix-free / finite-difference Jacobian options.
"""

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------

class Function:
    """Residual callback F(x) for the 2x2 test system."""

    def __call__(self, snes, x, f):
        # .item() converts the value computed from the Vec entries into a
        # plain Python scalar before storing it (presumably the entries
        # are NumPy scalars).
        f[0] = (x[0]*x[0] + x[0]*x[1] - 3.0).item()
        f[1] = (x[0]*x[1] + x[1]*x[1] - 6.0).item()
        f.assemble()

class Jacobian:
    """Analytic Jacobian callback dF/dx for the 2x2 test system.

    The entries are written into ``P``; ``J`` is assembled as well when
    it is a different matrix from ``P``.
    """

    def __call__(self, snes, x, J, P):
        P[0,0] = (2.0*x[0] + x[1]).item()
        P[0,1] = (x[0]).item()
        P[1,0] = (x[1]).item()
        P[1,1] = (x[0] + 2.0*x[1]).item()
        P.assemble()
        if J != P: J.assemble()

# --------------------------------------------------------------------

class BaseTestSNES(object):
    """Mixin holding the SNES test cases.

    Concrete subclasses combine it with ``unittest.TestCase`` and set
    ``SNES_TYPE`` to the solver type under test.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        self.snes = None

    @staticmethod
    def _newJacobianMat():
        """Return a set-up, empty 2x2 sequential AIJ matrix on COMM_SELF."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        return J

    def testGetSetType(self):
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """getTolerances() round-trips and matches the tolerance properties."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Read/write the Python-level properties exposed on SNES."""
        snes = self.snes
        #
        snes.appctx = (1,2,3)
        self.assertEqual(snes.appctx, (1,2,3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        # Each reason category must set exactly one of the three flags.
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertTrue(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertTrue(snes.iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        """setFunction stores the callable while holding one extra reference.

        Setting the same callable twice, or reading it back with
        getFunction(), must not leak additional references.
        """
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """computeFunction at the root (1, 2) yields a zero residual."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0)
        self.assertAlmostEqual(abs(y[1]), 0.0)

    def testGetSetJac(self):
        """setJacobian with one matrix uses it for both J and P.

        The callable gains exactly one reference regardless of how many
        times it is set or read back.
        """
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = self._newJacobianMat()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        J = self._newJacobianMat()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # Result intentionally discarded; presumably this forces creation
        # of the KSP/PC before computeJacobian -- TODO confirm intent.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """setUpdate holds one reference and releases it on replace/clear."""
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        # +2: one reference held by the SNES, one by the local ``tmp``.
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve the test system from (2, 3) and check it reaches +/-(1, 2)."""
        J = self._newJacobianMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2,3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # History was enabled before the solve, so it must have recorded
        # at least one residual norm.
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0)
        self.assertAlmostEqual(abs(x[1]), 2.0)
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

    def testResetAndSolve(self):
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """A Python monitor is called during solve and released on cancel."""
        reshist = {}
        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm
        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # ``monitor`` looks ``reshist`` up through its closure cell, so
        # rebinding the name here redirects all later monitor writes to
        # this fresh dict.
        reshist = {}
        self.snes.cancelMonitor()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Toggle Eisenstat-Walker and round-trip its parameters."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT must leave the previously set version alone.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Solve with the matrix-free option enabled."""
        #self.snes.setOptionsPrefix('MF-')
        #opts = PETSc.Options(self.snes)
        #opts['mat_mffd_type'] = 'ds'
        #opts['snes_monitor'] = 'stdout'
        #opts['ksp_monitor'] = 'stdout'
        #opts['snes_view'] = 'stdout'
        J = self._newJacobianMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        x.setArray([2,3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0)
        self.assertAlmostEqual(abs(x[1]), 2.0)

    def testFDColor(self):
        """Solve with the finite-difference Jacobian option enabled."""
        J = self._newJacobianMat()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # Evaluate the Jacobian once so every entry of J is inserted and
        # assembled; presumably the FD coloring needs this nonzero
        # pattern -- TODO confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2,3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0)
        self.assertAlmostEqual(abs(x[1]), 2.0)

# --------------------------------------------------------------------

class TestSNESLS(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS

class TestSNESTR(BaseTestSNES, unittest.TestCase):
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()