# Python (mpi4py) port of the OSU Micro-Benchmarks "osu_multi_lat" test.
# http://mvapich.cse.ohio-state.edu/benchmarks/
2
3from mpi4py import MPI
4
def osu_multi_lat(
    BENCHMARH="MPI Multi Latency Test",
    skip_small=100,
    loop_small=10000,
    skip_large=10,
    loop_large=1000,
    large_message_size=8192,
    MAX_MSG_SIZE=1 << 22,
):
    """Measure average point-to-point latency over many concurrent rank pairs.

    The ranks of ``MPI.COMM_WORLD`` are split into two halves: rank ``i`` in
    the lower half is paired with rank ``i + pairs`` in the upper half, and
    all pairs run a blocking ping-pong simultaneously.  For each message size
    (0 bytes, then powers of two up to ``MAX_MSG_SIZE``) rank 0 prints the
    one-way latency in microseconds, averaged over every participating rank.

    Parameters
    ----------
    BENCHMARH : str
        Title printed in the output header.  (The misspelt name is kept so
        existing keyword callers keep working.)
    skip_small, loop_small : int
        Warm-up and timed round trips for sizes ``<= large_message_size``.
    skip_large, loop_large : int
        Warm-up and timed round trips for sizes ``> large_message_size``.
    large_message_size : int
        Size threshold in bytes above which the "large" counts are used.
    MAX_MSG_SIZE : int
        Largest message size in bytes to test; also each buffer's size.

    Raises
    ------
    ValueError
        If the timed iteration count for a tested size is less than 1.

    Notes
    -----
    At least two processes are required; with a single process a message is
    printed on rank 0 and the function returns.  With an odd process count
    the highest rank has no partner: it still joins the collectives but is
    excluded from the average instead of deadlocking.
    """
    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    nprocs = comm.Get_size()
    # Floor division: ``nprocs / 2`` is a float under Python 3, and a float
    # is not a valid destination/source rank for Send/Recv.
    pairs = nprocs // 2

    if pairs == 0:
        if myid == 0:
            print('# %s requires at least two processes' % (BENCHMARH,))
        return

    # Lower-half ranks send first and upper-half ranks receive first, so the
    # blocking Send/Recv calls of each pair always match up.
    if myid < pairs:
        partner, sends_first = myid + pairs, True
    elif myid < 2 * pairs:
        partner, sends_first = myid - pairs, False
    else:
        # Odd process count: this rank has no partner and sits out.
        partner, sends_first = None, False

    s_buf = allocate(MAX_MSG_SIZE)
    r_buf = allocate(MAX_MSG_SIZE)

    if myid == 0:
        print('# %s' % (BENCHMARH,))
        print('# %-8s%20s' % ("Size [B]", "Latency [us]"))

    for size in _message_sizes(MAX_MSG_SIZE):
        if size > large_message_size:
            skip, loop = skip_large, loop_large
        else:
            skip, loop = skip_small, loop_small
        if loop < 1:
            # Otherwise the timer start is never taken and the latency
            # computation divides by zero.
            raise ValueError(
                'timed iteration count must be >= 1, got %r for size %d'
                % (loop, size))
        s_msg = [s_buf, size, MPI.BYTE]
        r_msg = [r_buf, size, MPI.BYTE]

        comm.Barrier()
        if partner is None:
            latency = 0.0
        else:
            # Untimed warm-up round trips, then the timed ones.
            _ping_pong(comm, s_msg, r_msg, partner, sends_first, skip)
            t_start = MPI.Wtime()
            _ping_pong(comm, s_msg, r_msg, partner, sends_first, loop)
            t_end = MPI.Wtime()
            # Each iteration is one round trip = two one-way messages.
            latency = (t_end - t_start) * 1e6 / (2 * loop)

        # Every rank takes part in the reduction; an idle rank adds 0.0, so
        # dividing by the number of paired ranks gives their mean.
        total_lat = comm.reduce(latency, root=0, op=MPI.SUM)
        if myid == 0:
            avg_lat = total_lat / (pairs * 2)
            print('%-10d%20.2f' % (size, avg_lat))


def _message_sizes(max_size):
    """Yield 0, then successive powers of two, all not exceeding *max_size*."""
    if max_size >= 0:
        yield 0
    size = 1
    while size <= max_size:
        yield size
        size <<= 1


def _ping_pong(comm, s_msg, r_msg, partner, sends_first, count):
    """Perform *count* blocking round trips with *partner* using tag 1."""
    # Branch once outside the loop to keep per-iteration overhead minimal.
    if sends_first:
        for _ in range(count):
            comm.Send(s_msg, partner, 1)
            comm.Recv(r_msg, partner, 1)
    else:
        for _ in range(count):
            comm.Recv(r_msg, partner, 1)
            comm.Send(s_msg, partner, 1)
65
66
def allocate(n):
    """Return a writable buffer of *n* bytes for use as a message buffer.

    Backends are tried in order of preference:

    1. an anonymous ``mmap`` region (skipped if the module cannot be imported
       or the mapping raises ``OSError``);
    2. ``numpy.zeros(n, 'B')`` if NumPy is importable;
    3. a stdlib ``array('B')`` of *n* zero bytes.
    """
    try:
        import mmap
        return mmap.mmap(-1, n)
    except (ImportError, OSError):
        # No usable mmap: try NumPy, and fall back to the array module if
        # NumPy is not installed.
        try:
            from numpy import zeros
            return zeros(n, 'B')
        except ImportError:
            from array import array
            zero_byte = array('B', [0])
            return zero_byte * n
78
79
if __name__ == '__main__':
    # Script entry point: run the benchmark with its default parameters.
    osu_multi_lat()
82