"""Parallel Mandelbrot set renderer built on mpi4py.

Image rows are dealt out cyclically across the MPI ranks: rank ``r`` of
``size`` computes rows ``r, r + size, r + 2*size, ...``.  Rank 0 broadcasts
the problem parameters, every rank computes its rows, and rank 0 gathers the
row indices and pixel data, reassembles the full image, prints per-task wall
clock times and, when matplotlib is available, shows the picture.

Run under MPI, e.g.::

    mpiexec -n 4 python mandelbrot.py
"""
import sys

import numpy as np

# Rectangle of the complex plane to render: the real part spans x1..x2 and
# the imaginary part spans y1..y2.
x1 = -2.0
x2 = 1.0
y1 = -1.0
y2 = 1.0

# Image size in pixels (w columns by h rows) and the iteration cap.
w = 150
h = 100
maxit = 127


def mandelbrot(x, y, maxit):
    """Return the escape iteration count of the point ``x + y*i``.

    Iterates ``z -> z**2 + c`` starting from ``z = 0`` and stops as soon as
    ``abs(z) >= 2`` or ``maxit`` iterations have been performed, so the
    result lies in ``[0, maxit]``; a result of ``maxit`` means the point did
    not escape within the cap.
    """
    c = x + y * 1j
    z = 0 + 0j
    it = 0
    while abs(z) < 2 and it < maxit:
        z = z**2 + c
        it += 1
    return it


def row_indices(height, size, rank):
    """Return the image rows assigned to ``rank`` out of ``size`` ranks.

    Rows are distributed cyclically (``rank, rank + size, ...``).  The dtype
    ``'i'`` (C int) matches ``MPI.INT`` so the array can be sent directly.
    """
    return np.arange(rank, height, size, dtype='i')


def compute_rows(rows, x0, y0, dx, dy, width, maxit):
    """Compute Mandelbrot iteration counts for the given image rows.

    Pixel ``(row, j)`` samples the point ``(x0 + j*dx) + (y0 + row*dy)*i``.

    Returns:
        A C-contiguous array of shape ``(len(rows), width)`` and dtype
        ``'i'``, whose k-th row holds the counts for image row ``rows[k]``.
    """
    pixels = np.empty((len(rows), width), dtype='i')
    for k, row in enumerate(rows):
        y = y0 + row * dy
        for j in range(width):
            pixels[k, j] = mandelbrot(x0 + j * dx, y, maxit)
    return pixels


def report_times(times):
    """Print per-task wall clock times followed by their mean/min/max/sum.

    ``times`` must be a non-empty sequence of elapsed seconds, one per task.
    """
    for task, elapsed in enumerate(times):
        print('wall clock time: %8.2f seconds (task %d)' % (elapsed, task))
    print('all tasks, mean: %8.2f seconds' % (sum(times) / len(times)))
    print('all tasks, min: %8.2f seconds' % min(times))
    print('all tasks, max: %8.2f seconds' % max(times))
    print('all tasks, sum: %8.2f seconds' % sum(times))


def show_image(image, seconds=2):
    """Display ``image`` with matplotlib for ``seconds`` seconds, best effort.

    matplotlib is optional: if it cannot be imported nothing is shown.  Any
    other plotting failure is reported on stderr instead of aborting the run
    (the image is eye candy, not a result).
    """
    try:
        from matplotlib import pyplot as plt
    except ImportError:
        return
    try:
        plt.imshow(image, aspect='equal')
        try:
            plt.nipy_spectral()
        except AttributeError:
            # Older matplotlib releases named this colormap helper `spectral`.
            plt.spectral()
        plt.pause(seconds)
    except Exception as exc:
        print('could not display image: %s' % exc, file=sys.stderr)


def main():
    """Render the Mandelbrot image in parallel and report timings.

    Must be run under MPI; every rank calls the same sequence of collective
    operations, and only rank 0 prints and displays the result.
    """
    # Imported here rather than at module level: importing mpi4py.MPI
    # initializes MPI, which should only happen when the script is run,
    # not when the helpers above are imported.
    from mpi4py import MPI

    tic = MPI.Wtime()

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    # Rank 0 broadcasts the parameters.  The real-valued bounds travel as
    # doubles: a float32 ('f' / MPI.FLOAT) buffer would silently round them.
    rmsg = np.empty(4, dtype='d')
    imsg = np.empty(3, dtype='i')
    if rank == 0:
        rmsg[:] = [x1, x2, y1, y2]
        imsg[:] = [w, h, maxit]
    comm.Bcast([rmsg, MPI.DOUBLE], root=0)
    comm.Bcast([imsg, MPI.INT], root=0)

    x_start, x_end, y_start, y_end = (float(v) for v in rmsg)
    width, height, max_iter = (int(v) for v in imsg)
    dx = (x_end - x_start) / width
    dy = (y_end - y_start) / height

    # This rank's rows and their pixel data.
    rows = row_indices(height, size, rank)
    nrows = np.array(len(rows), dtype='i')
    pixels = compute_rows(rows, x_start, y_start, dx, dy, width, max_iter)

    # Gather row counts, then row indices, then pixel rows at rank 0.
    # Receive buffers are only significant at the root; on the other ranks
    # the placeholders just have to survive the ``counts * width`` below.
    counts = 0
    indices = None
    cdata = None
    if rank == 0:
        counts = np.empty(size, dtype='i')
        indices = np.empty(height, dtype='i')
        cdata = np.empty([height, width], dtype='i')
    comm.Gather(sendbuf=[nrows, MPI.INT],
                recvbuf=[counts, MPI.INT],
                root=0)
    comm.Gatherv(sendbuf=[rows, MPI.INT],
                 recvbuf=[indices, (counts, None), MPI.INT],
                 root=0)
    comm.Gatherv(sendbuf=[pixels, MPI.INT],
                 recvbuf=[cdata, (counts * width, None), MPI.INT],
                 root=0)

    image = None
    if rank == 0:
        # cdata holds the rows grouped by rank; put each back in its place.
        image = np.zeros([height, width], dtype='i')
        image[indices, :] = cdata

    toc = MPI.Wtime()
    wct = comm.gather(toc - tic, root=0)
    if rank == 0:
        report_times(wct)
        show_image(image)
    comm.Barrier()


if __name__ == "__main__":
    main()