# petsc4py must be initialised with the command-line arguments *before*
# the PETSc module is imported; the script later reads its run-time
# options through PETSc.Options().
import sys

import petsc4py

petsc4py.init(sys.argv)

from petsc4py import PETSc

# Bratu2D supplies FormInitGuess / FormFunction / FormJacobian; they are
# called below with the Fortran handles of PETSc objects.
import Bratu2D as Bratu2D
6
class App(object):
    """Callbacks that connect a PETSc SNES solver to the Bratu2D routines.

    Every callback hands the Fortran handles (``.fortran``) of the PETSc
    objects it receives to the matching ``Bratu2D`` routine, together with
    ``lambda_``, and raises ``PETSc.Error`` when that routine returns a
    nonzero error code.
    """

    def __init__(self, da, lambda_):
        """Remember the distributed array and the ``lambda_`` parameter.

        ``da`` must report a dimension of 2 through ``getDim()``.
        """
        assert da.getDim() == 2, "App requires a two-dimensional DA"
        self.da = da
        self.lambda_ = lambda_

    @staticmethod
    def _check(ierr):
        """Raise ``PETSc.Error(ierr)`` when *ierr* is a nonzero code."""
        if ierr:
            raise PETSc.Error(ierr)

    def formInitGuess(self, snes, X):
        """Fill ``X`` with the initial guess computed by Bratu2D.

        ``snes`` is accepted for signature symmetry with the other
        callbacks and is not used.
        """
        X.zeroEntries()  # just in case
        ierr = Bratu2D.FormInitGuess(
            self.da.fortran, X.fortran, self.lambda_)
        self._check(ierr)

    def formFunction(self, snes, X, F):
        """Evaluate the nonlinear residual at ``X`` and store it in ``F``."""
        F.zeroEntries()  # just in case
        ierr = Bratu2D.FormFunction(
            self.da.fortran, X.fortran, F.fortran, self.lambda_)
        self._check(ierr)

    def formJacobian(self, snes, X, J, P):
        """Assemble the matrix ``P`` at ``X`` through Bratu2D.

        Only ``P`` is filled by the Bratu2D routine; when ``J`` is a
        different matrix it is merely assembled.  Returns
        ``PETSc.Mat.Structure.SAME_NONZERO_PATTERN``.
        """
        P.zeroEntries()  # just in case
        ierr = Bratu2D.FormJacobian(
            self.da.fortran, X.fortran, P.fortran, self.lambda_)
        self._check(ierr)
        if J != P:
            J.assemble()  # matrix-free operator
        return PETSc.Mat.Structure.SAME_NONZERO_PATTERN
38
39
# Run-time options, read from the PETSc options database.
OptDB = PETSc.Options()

N = OptDB.getInt('N', 16)               # grid size used in both directions
lambda_ = OptDB.getReal('lambda', 6.0)  # parameter passed to the Bratu2D routines
do_plot = OptDB.getBool('plot', False)  # whether to plot the solution afterwards

# N x N distributed array with a stencil width of one.
da = PETSc.DA().create([N, N], stencil_width=1)
app = App(da, lambda_)

# Nonlinear solver wired to the application callbacks.
snes = PETSc.SNES().create()
F = da.createGlobalVec()
snes.setFunction(app.formFunction, F)
J = da.createMat()
snes.setJacobian(app.formJacobian, J)

# Let command-line options configure the solver.
snes.setFromOptions()

# Solve, starting from the initial guess provided by the application.
X = da.createGlobalVec()
app.formInitGuess(snes, X)
snes.solve(None, X)

# Copy the solution into a vector in natural ordering.
U = da.createNaturalVec()
da.globalToNatural(X, U)
63
def plot(da, U):
    """Gather ``U`` onto rank 0 and show it there as a filled contour plot.

    ``U`` is scattered to a vector on rank 0, reshaped to ``da.sizes`` in
    Fortran order and drawn with matplotlib.  Plotting is best effort:
    a missing matplotlib or a failure while drawing is ignored.  All ranks
    meet at a barrier before the temporary scatter objects are destroyed,
    even when rank 0 fails before plotting.
    """
    comm = da.getComm()
    scatter, U0 = PETSc.Scatter.toZero(U)
    try:
        scatter.scatter(U, U0, False, PETSc.Scatter.Mode.FORWARD)
        if comm.getRank() == 0:
            solution = U0[...]
            solution = solution.reshape(da.sizes, order='f').copy()
            try:
                from matplotlib import pyplot
                pyplot.contourf(solution)
                pyplot.axis('equal')
                pyplot.show()
            except Exception:
                # Deliberately best effort, but no longer a bare except:
                # KeyboardInterrupt and SystemExit now propagate.
                pass
    finally:
        # Run on every path so the other ranks are not left waiting at the
        # barrier and the temporaries are always released.
        comm.barrier()
        scatter.destroy()
        U0.destroy()
82
if do_plot:
    plot(da, U)


# Explicitly release the PETSc objects created by this script.
U.destroy()
X.destroy()
F.destroy()
J.destroy()
da.destroy()
snes.destroy()
92