"""Solve a 2D nonlinear system with PETSc SNES using external kernels.

The initial guess, residual and Jacobian are computed by the ``Bratu2D``
extension module (presumably f2py-wrapped Fortran -- confirm), which is
handed the ``.fortran`` handles of the petsc4py DA/Vec/Mat objects.

Options read from the PETSc options database:
  -N <int>        grid points per direction (default 16)
  -lambda <real>  parameter passed to every Bratu2D kernel (default 6.0)
  -plot <bool>    contour-plot the solution on rank 0 (default False)
"""

import sys

import petsc4py
# petsc4py must be initialised with the command line *before* PETSc is
# imported, so these statements cannot be reordered.
petsc4py.init(sys.argv)

from petsc4py import PETSc  # noqa: E402  (must follow petsc4py.init)
import Bratu2D  # noqa: E402


class App(object):
    """SNES callbacks binding a 2D DA and ``lambda_`` to the Bratu2D kernels."""

    def __init__(self, da, lambda_):
        """Store the DA and parameter.

        Raises:
            ValueError: if ``da`` is not two-dimensional.
        """
        # Explicit check rather than `assert`, which is stripped under -O.
        if da.getDim() != 2:
            raise ValueError(
                "App requires a 2-dimensional DA, got dim=%d" % da.getDim())
        self.da = da
        self.lambda_ = lambda_

    @staticmethod
    def _check(ierr):
        """Raise PETSc.Error for a nonzero error code returned by Bratu2D."""
        if ierr:
            raise PETSc.Error(ierr)

    def formInitGuess(self, snes, X):
        """Fill X with the initial guess computed by Bratu2D.FormInitGuess."""
        X.zeroEntries()  # just in case
        self._check(Bratu2D.FormInitGuess(
            self.da.fortran, X.fortran, self.lambda_))

    def formFunction(self, snes, X, F):
        """SNES residual callback: evaluate F(X) via Bratu2D.FormFunction."""
        F.zeroEntries()  # just in case
        self._check(Bratu2D.FormFunction(
            self.da.fortran, X.fortran, F.fortran, self.lambda_))

    def formJacobian(self, snes, X, J, P):
        """SNES Jacobian callback: assemble P at X via Bratu2D.FormJacobian.

        Returns:
            ``PETSc.Mat.Structure.SAME_NONZERO_PATTERN`` (old-style petsc4py
            callback API) to report that P's nonzero pattern is unchanged.
        """
        P.zeroEntries()  # just in case
        self._check(Bratu2D.FormJacobian(
            self.da.fortran, X.fortran, P.fortran, self.lambda_))
        if J != P:
            # J is a separate (e.g. matrix-free) operator: it only needs to
            # be (re)assembled, its entries are not filled here.
            J.assemble()
        return PETSc.Mat.Structure.SAME_NONZERO_PATTERN


OptDB = PETSc.Options()

N = OptDB.getInt('N', 16)
lambda_ = OptDB.getReal('lambda', 6.0)
do_plot = OptDB.getBool('plot', False)

da = PETSc.DA().create([N, N], stencil_width=1)
app = App(da, lambda_)

snes = PETSc.SNES().create()
F = da.createGlobalVec()
snes.setFunction(app.formFunction, F)
J = da.createMat()
snes.setJacobian(app.formJacobian, J)

snes.setFromOptions()

X = da.createGlobalVec()
app.formInitGuess(snes, X)
snes.solve(None, X)

# Copy the solution from the DA's global ordering into natural ordering.
U = da.createNaturalVec()
da.globalToNatural(X, U)


def plot(da, U):
    """Gather U onto rank 0 and, if matplotlib is usable, contour-plot it.

    Every rank of ``da``'s communicator takes part in the scatter and the
    final barrier, so all ranks must call this. Plotting itself is best
    effort: failures are reported on stderr instead of raised, so rank 0
    always reaches the barrier the other ranks are waiting on.
    """
    comm = da.getComm()
    scatter, U0 = PETSc.Scatter.toZero(U)
    scatter.scatter(U, U0, False, PETSc.Scatter.Mode.FORWARD)
    if comm.getRank() == 0:
        solution = U0[...]
        solution = solution.reshape(da.sizes, order='f').copy()
        try:
            from matplotlib import pyplot
            pyplot.contourf(solution)
            pyplot.axis('equal')
            pyplot.show()
        except Exception as exc:
            # Not a bare `except:`: let KeyboardInterrupt/SystemExit through,
            # and say why the plot was skipped instead of failing silently.
            sys.stderr.write("plot: skipped (%s: %s)\n"
                             % (type(exc).__name__, exc))
    comm.barrier()
    scatter.destroy()
    U0.destroy()


if do_plot:
    plot(da, U)

# Release PETSc objects; the solver first, since it holds references to
# the residual vector F and the Jacobian matrix J.
snes.destroy()
U.destroy()
X.destroy()
F.destroy()
J.destroy()
da.destroy()