1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class Function:
    """Residual callback for a 2x2 nonlinear test system.

    Writes ``F(x) = [x0**2 + x0*x1 - 3, x0*x1 + x1**2 - 6]`` into ``f`` and
    assembles it. ``x = (1, 2)`` is a root (as is ``(-1, -2)``).
    """

    def __call__(self, snes, x, f):
        # Each component is computed and stored before the next one is
        # evaluated, preserving the read/write order of the residual.
        first = x[0]*x[0] + x[0]*x[1] - 3.0
        f[0] = first.item()
        second = x[0]*x[1] + x[1]*x[1] - 6.0
        f[1] = second.item()
        f.assemble()
14
class Jacobian:
    """Jacobian callback for the 2x2 residual ``Function``.

    Fills ``P`` with the analytic derivative
    ``[[2*x0 + x1, x0], [x1, x0 + 2*x1]]`` and assembles it. ``J`` is also
    assembled when it compares unequal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        x0 = x[0]
        x1 = x[1]
        # Insertion order of this dict fixes the order of the writes into P.
        entries = {
            (0, 0): 2.0*x0 + x1,
            (0, 1): x0,
            (1, 0): x1,
            (1, 1): x0 + 2.0*x1,
        }
        for (row, col), value in entries.items():
            P[row, col] = value.item()
        P.assemble()
        if J != P:
            J.assemble()
23
24# --------------------------------------------------------------------
25
class BaseTestSNES(object):
    """Shared SNES test cases, parameterized by the ``SNES_TYPE`` attribute.

    Concrete subclasses mix this class with ``unittest.TestCase`` and set
    ``SNES_TYPE``. The solve-based tests use the ``Function``/``Jacobian``
    callbacks, start from ``x = (2, 3)`` and expect ``|x| == (1, 2)``.
    """

    SNES_TYPE = None

    def setUp(self):
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        # Drop the reference to the SNES created in setUp.
        self.snes = None

    def testGetSetType(self):
        """The type set in setUp is reported and survives a re-set."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """Tolerances round-trip and agree with the attribute accessors."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))

    def testProperties(self):
        """Python-level properties read back what was assigned."""
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertIsNone(snes.appctx)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertTrue(snes.diverged)
        self.assertFalse(snes.iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.converged)
        self.assertFalse(snes.diverged)
        self.assertTrue(snes.iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)

    def testGetSetFunc(self):
        """setFunction/getFunction round-trip without leaking references."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must hold exactly one extra ref.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """computeFunction evaluates the residual at the root (1, 2)."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0)
        self.assertAlmostEqual(abs(y[1]), 0.0)

    def testGetSetJac(self):
        """setJacobian/getJacobian round-trip without leaking references."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        # With only J given, the preconditioning matrix is J itself.
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """computeJacobian fills J with the analytic Jacobian."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): presumably forces creation of the KSP/PC before the
        # Jacobian is computed -- confirm whether this is still required.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)
        # At x = (1, 2) the analytic Jacobian is [[4, 1], [2, 5]].
        self.assertAlmostEqual(J.getValue(0, 0), 4.0)
        self.assertAlmostEqual(J.getValue(0, 1), 1.0)
        self.assertAlmostEqual(J.getValue(1, 0), 2.0)
        self.assertAlmostEqual(J.getValue(1, 1), 5.0)

    def testGetSetUpd(self):
        """Update callbacks are stored, replaced and released correctly."""
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        # Replacing the callback must release the previous one.
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """getKSP returns a KSP with the expected reference count."""
        ksp = self.snes.getKSP()
        # NOTE(review): presumably one reference is held by the SNES and one
        # by the Python wrapper ``ksp`` -- confirm.
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve F(x) = 0 from (2, 3) and check history and solution."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)
        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        # The solve must have recorded a convergence history...
        rh, ih = self.snes.getConvergenceHistory()
        self.assertGreater(len(rh), 0)
        # ...and re-setting it with zero length and reset=True clears it.
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        self.assertAlmostEqual(abs(x[0]), 1.0)
        self.assertAlmostEqual(abs(x[1]), 2.0)
        # XXX: this convergence-test check does not belong in testSolve.
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

    def testResetAndSolve(self):
        """The SNES can be reset and reused for repeated solves."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """Monitors are called during solves and released when cancelled."""
        reshist = {}
        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm
        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Clear in place: the monitor closure writes into this same dict.
        reshist.clear()
        self.snes.cancelMonitor()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        # Disabled checks for built-in monitors, kept for reference:
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """Nonlinear step-failure counters and limits round-trip."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """Linear-solve failure counters and limits round-trip."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Eisenstat-Walker flag and parameters round-trip."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # Passing PETSc.DEFAULT is expected to leave the version unchanged.
        params['version'] = PETSc.DEFAULT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Solve with the matrix-free flag enabled."""
        # Debugging aid: uncomment to inspect the matrix-free solve.
        #self.snes.setOptionsPrefix('MF-')
        #opts = PETSc.Options(self.snes)
        #opts['mat_mffd_type'] = 'ds'
        #opts['snes_monitor']  = 'stdout'
        #opts['ksp_monitor']   = 'stdout'
        #opts['snes_view']     = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0)
        self.assertAlmostEqual(abs(x[1]), 2.0)

    def testFDColor(self):
        """Solve with the finite-difference (coloring) flag enabled."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # NOTE(review): evaluating the Jacobian once populates J's nonzero
        # entries; presumably the coloring needs that structure -- confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        self.assertAlmostEqual(abs(x[0]), 1.0)
        self.assertAlmostEqual(abs(x[1]), 2.0)
347
348# --------------------------------------------------------------------
349
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared BaseTestSNES tests with SNES type NEWTONLS."""
    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
352
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared BaseTestSNES tests with SNES type NEWTONTR."""
    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
355
356# --------------------------------------------------------------------
357
if __name__ == '__main__':
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()
360