1# http://mvapich.cse.ohio-state.edu/benchmarks/
2
3from mpi4py import MPI
4
def osu_scatter(
    BENCHMARH="MPI Scatter Latency Test",
    skip=1000,
    loop=10000,
    skip_large=10,
    loop_large=100,
    large_message_size=8192,
    MAX_MSG_SIZE=1 << 20,
):
    """Run an OSU-style ``MPI_Scatter`` latency benchmark on ``COMM_WORLD``.

    For every size in ``message_sizes(MAX_MSG_SIZE)``, rank 0 scatters
    ``size`` bytes to each rank ``loop + skip`` times. The first ``skip``
    iterations are warm-up and are not timed. Once a size exceeds
    ``large_message_size``, the iteration counts switch to ``skip_large`` /
    ``loop_large`` for that size and every larger one.

    The printed latency (microseconds per call) is the mean of the per-rank
    timings, as in the OSU Micro-Benchmarks. The scatter root alone can
    return well before the receivers are done, so its own timing would
    understate the cost of the collective.

    Parameters
    ----------
    BENCHMARH : str
        Title printed in the output header (parameter spelling kept for
        backward compatibility).
    skip, loop : int
        Warm-up and timed iteration counts for sizes up to
        ``large_message_size``.
    skip_large, loop_large : int
        Warm-up and timed iteration counts used once a size exceeds
        ``large_message_size``.
    large_message_size : int
        Size threshold (bytes) above which the "large" counts apply.
    MAX_MSG_SIZE : int
        Largest per-rank message size (bytes) to measure.

    Raises
    ------
    SystemExit
        If fewer than two processes are running (message shown on rank 0).
    ValueError
        If the iteration counts in effect for a size are not usable
        (``loop <= 0`` or ``skip < 0``).
    """
    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    numprocs = comm.Get_size()

    if numprocs < 2:
        # Only rank 0 reports the reason; the other ranks exit silently.
        errmsg = (
            "This test requires at least two processes" if myid == 0 else None
        )
        raise SystemExit(errmsg)

    # The root scatters from one MAX_MSG_SIZE slot per rank. Every other
    # rank only needs room for its own slot.
    s_buf = allocate(MAX_MSG_SIZE * numprocs) if myid == 0 else None
    r_buf = None if myid == 0 else allocate(MAX_MSG_SIZE)

    if myid == 0:
        print('# %s' % (BENCHMARH,))
        print('# %-8s%20s' % ("Size [B]", "Latency [us]"))

    for size in message_sizes(MAX_MSG_SIZE):
        # Deliberately sticky: once switched, all larger sizes keep using
        # the reduced iteration counts.
        if size > large_message_size:
            skip = skip_large
            loop = loop_large
        # Otherwise the timed region is never entered: t_start would be
        # unbound, or the latency division would be by zero.
        if loop <= 0 or skip < 0:
            raise ValueError(
                "need loop > 0 and skip >= 0, got loop=%r, skip=%r"
                % (loop, skip)
            )
        if myid == 0:
            # With MPI.IN_PLACE the root's own segment stays in s_buf.
            s_msg = [s_buf, size, MPI.BYTE]
            r_msg = MPI.IN_PLACE
        else:
            s_msg = None
            r_msg = [r_buf, size, MPI.BYTE]

        comm.Barrier()
        for i in range(loop + skip):
            if i == skip:
                t_start = MPI.Wtime()
            comm.Scatter(s_msg, r_msg, 0)
        t_end = MPI.Wtime()
        comm.Barrier()

        # Average the per-rank latency on the root (OSU convention).
        latency = (t_end - t_start) * 1e6 / loop
        total = comm.reduce(latency, op=MPI.SUM, root=0)
        if myid == 0:
            print('%-10d%20.2f' % (size, total / numprocs))
59
60
def message_sizes(max_size):
    """Return the message sizes to benchmark, in ascending order.

    The list is ``0`` followed by every power of two ``1, 2, 4, ...`` that
    does not exceed *max_size*. There is no fixed upper cap, so sizes of
    ``2**30`` bytes and beyond are included when *max_size* allows them.

    >>> message_sizes(10)
    [0, 1, 2, 4, 8]
    """
    sizes = [0]
    size = 1
    while size <= max_size:
        sizes.append(size)
        size <<= 1
    return sizes
64
def allocate(n):
    """Return a zero-filled, writable byte buffer of length *n*.

    Backends are tried in order:

    1. An anonymous ``mmap``. Skipped if the module cannot be imported or
       the mapping fails with an OS error.
    2. A NumPy ``'B'`` (unsigned byte) array. Skipped if NumPy is not
       installed.
    3. A stdlib ``array('B')`` as the last resort.
    """
    try:
        import mmap
        mapped = mmap.mmap(-1, n)
    except (ImportError, EnvironmentError):
        pass
    else:
        return mapped

    try:
        from numpy import zeros
        return zeros(n, 'B')
    except ImportError:
        from array import array
        return array('B', [0]) * n
76
77
if __name__ == '__main__':
    # Script entry point: run the benchmark with its default settings.
    osu_scatter()
80