1from mpi4py import MPI
2import numpy as np
3
# start the per-task wall-clock timer
tic = MPI.Wtime()

# region of the complex plane to render: real part in [x1, x2],
# imaginary part in [y1, y2]
x1, x2 = -2.0, 1.0
y1, y2 = -1.0, 1.0

# image width/height in pixels and the iteration cap per pixel
w, h = 150, 100
maxit = 127
14
def mandelbrot(x, y, maxit):
    """Return the escape-time iteration count for the point x + y*i.

    Starting from z = 0, repeatedly applies z <- z**2 + c with
    c = x + y*1j, counting the steps taken while |z| < 2.  At most
    ``maxit`` steps are taken, so a point that never escapes yields
    ``maxit``.
    """
    c = x + y*1j
    z = 0 + 0j
    steps = 0
    while steps < maxit:
        # "not <" rather than ">=" so that a NaN magnitude also stops the loop
        if not abs(z) < 2:
            break
        z = z**2 + c
        steps += 1
    return steps
23
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# Root owns the problem parameters and broadcasts them to every task.
# The window coordinates travel in double precision: sending them as
# float32 ('f' / MPI.FLOAT) would round every rank's window (root
# included, since it also reads back from rmsg) to ~7 significant digits,
# which destroys zoomed-in renders.
rmsg = np.empty(4, dtype='d')
imsg = np.empty(3, dtype='i')

if rank == 0:
    rmsg[:] = [x1, x2, y1, y2]
    imsg[:] = [w, h, maxit]

comm.Bcast([rmsg, MPI.DOUBLE], root=0)
comm.Bcast([imsg, MPI.INT], root=0)

x1, x2, y1, y2 = [float(r) for r in rmsg]
w, h, maxit = [int(i) for i in imsg]
# pixel spacing along the real (x) and imaginary (y) axes
dx = (x2 - x1) / w
dy = (y2 - y1) / h
42
# Rows are dealt out round-robin: this task owns rows rank, rank+size, ...
# N counts them; it is a 0-d int array so it can be sent with Gather below.
N = np.array(h // size + (h % size > rank), dtype='i')
# global row indices owned by this task
I = np.arange(rank, h, size, dtype='i')
# x coordinate of every column, identical for all rows, computed once
xs = x1 + np.arange(w) * dx
# escape counts for the owned rows, one image row per entry of I
C = np.empty([N, w], dtype='i')
for k in np.arange(N):
    y = y1 + I[k] * dy
    C[k, :] = [mandelbrot(x, y, maxit) for x in xs]
# Collect everything on the root task: how many rows each task computed,
# which global row indices those were, and finally the row data itself.
if rank == 0:
    counts = np.empty(size, dtype='i')
    indices = np.empty(h, dtype='i')
    cdata = np.empty([h, w], dtype='i')
else:
    # receive buffers only matter at the root; counts stays a number so
    # that counts*w below is still a valid expression on every task
    counts = 0
    indices = None
    cdata = None
comm.Gather(sendbuf=[N, MPI.INT], recvbuf=[counts, MPI.INT], root=0)
comm.Gatherv(sendbuf=[I, MPI.INT],
             recvbuf=[indices, (counts, None), MPI.INT], root=0)
comm.Gatherv(sendbuf=[C, MPI.INT],
             recvbuf=[cdata, (counts * w, None), MPI.INT], root=0)
# rows arrive grouped by task; put each one back at its image row
if rank == 0:
    M = np.zeros([h, w], dtype='i')
    M[indices] = cdata
76
toc = MPI.Wtime()
# every task's elapsed wall-clock time, collected as a list on the root
wct = comm.gather(toc - tic, root=0)
if rank == 0:
    for task, elapsed in enumerate(wct):
        print('wall clock time: %8.2f seconds (task %d)' % (elapsed, task))
    total = sum(wct)
    print('all tasks, mean: %8.2f seconds' % (total / len(wct)))
    print('all tasks, min:  %8.2f seconds' % min(wct))
    print('all tasks, max:  %8.2f seconds' % max(wct))
    print('all tasks, sum:  %8.2f seconds' % total)
87
# eye candy (requires matplotlib): strictly best effort -- if matplotlib is
# missing or plotting fails for any ordinary reason, the run still finishes
if rank == 0:
    try:
        from matplotlib import pyplot as plt
        plt.imshow(M, aspect='equal')
        try:
            plt.nipy_spectral()
        except AttributeError:
            # fallback for matplotlib versions that only provide plt.spectral()
            plt.spectral()
        plt.pause(2)
    except Exception:
        # deliberately ignored, but unlike a bare "except:" this no longer
        # swallows KeyboardInterrupt / SystemExit
        pass
# keep the other tasks alive until the root has finished displaying
comm.Barrier()
101