# http://mvapich.cse.ohio-state.edu/benchmarks/

from mpi4py import MPI


def osu_multi_lat(
    BENCHMARH="MPI Multi Latency Test",
    skip_small=100,
    loop_small=10000,
    skip_large=10,
    loop_large=1000,
    large_message_size=8192,
    MAX_MSG_SIZE=1 << 22,
):
    """Run the multi-pair ping-pong latency benchmark on ``MPI.COMM_WORLD``.

    The ranks are split into two halves: rank ``i`` of the lower half is
    paired with rank ``i + nprocs // 2`` of the upper half. For message size
    0 and every power of two up to ``MAX_MSG_SIZE``, each pair performs
    ``skip`` untimed warm-up round trips followed by ``loop`` timed round
    trips using blocking ``Send``/``Recv``. Rank 0 then prints the one-way
    latency in microseconds, averaged over all ranks.

    Parameters:
        BENCHMARH: title printed in the header by rank 0 (spelling kept for
            compatibility with existing keyword callers).
        skip_small, loop_small: warm-up / timed iteration counts for sizes
            ``<= large_message_size``.
        skip_large, loop_large: warm-up / timed iteration counts for sizes
            ``> large_message_size``.
        large_message_size: size threshold (bytes) between the two regimes.
        MAX_MSG_SIZE: largest message size in bytes; also the size of the
            send and receive buffers.

    Raises:
        ValueError: if a timed loop count is < 1 or a warm-up count is < 0
            (the timer would otherwise never start).

    The benchmark needs an even number of processes (at least two). Otherwise
    rank 0 prints a message and every rank returns without communicating.
    """
    if loop_small < 1 or loop_large < 1:
        raise ValueError("loop_small and loop_large must be >= 1")
    if skip_small < 0 or skip_large < 0:
        raise ValueError("skip_small and skip_large must be >= 0")

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    nprocs = comm.Get_size()

    # With an odd count the last rank would wait forever on a partner that
    # never sends to it; with one rank the partner would be negative. All
    # ranks see the same nprocs, so they all take this exit together.
    if nprocs < 2 or nprocs % 2 != 0:
        if myid == 0:
            print('# This test requires an even number of processes (>= 2)')
        return

    # Integer division: on Python 3 ``nprocs / 2`` is a float, and a float
    # partner rank is rejected by Send/Recv.
    pairs = nprocs // 2

    s_buf = allocate(MAX_MSG_SIZE)
    r_buf = allocate(MAX_MSG_SIZE)

    if myid == 0:
        print('# %s' % (BENCHMARH,))
        print('# %-8s%20s' % ("Size [B]", "Latency [us]"))

    message_sizes = [0] + [2**i for i in range(30)]
    for size in message_sizes:
        if size > MAX_MSG_SIZE:
            break
        if size > large_message_size:
            skip, loop = skip_large, loop_large
        else:
            skip, loop = skip_small, loop_small
        s_msg = [s_buf, size, MPI.BYTE]
        r_msg = [r_buf, size, MPI.BYTE]
        #
        comm.Barrier()
        if myid < pairs:
            # Lower half initiates each round trip: send, then await the echo.
            partner = myid + pairs
            for i in range(loop + skip):
                if i == skip:
                    # Start timing only once the warm-up iterations are done.
                    t_start = MPI.Wtime()
                comm.Send(s_msg, partner, 1)
                comm.Recv(r_msg, partner, 1)
            t_end = MPI.Wtime()
        else:
            # Upper half echoes: receive, then send back.
            partner = myid - pairs
            for i in range(loop + skip):
                if i == skip:
                    t_start = MPI.Wtime()
                comm.Recv(r_msg, partner, 1)
                comm.Send(s_msg, partner, 1)
            t_end = MPI.Wtime()
        #
        # Each timed iteration is one round trip = two one-way messages.
        latency = (t_end - t_start) * 1e6 / (2 * loop)
        total_lat = comm.reduce(latency, root=0, op=MPI.SUM)
        if myid == 0:
            avg_lat = total_lat / (pairs * 2)
            print('%-10d%20.2f' % (size, avg_lat))


def allocate(n):
    """Return a writable buffer of ``n`` bytes usable as an MPI message buffer.

    Tries an anonymous ``mmap`` first. If the mmap import fails or mapping
    raises ``OSError``, falls back to a zero-filled NumPy ``'B'`` array, and
    finally to a zero-filled stdlib ``array('B')``.
    """
    try:
        import mmap
        return mmap.mmap(-1, n)
    except (ImportError, EnvironmentError):
        try:
            from numpy import zeros
            return zeros(n, 'B')
        except ImportError:
            from array import array
            return array('B', [0]) * n


if __name__ == '__main__':
    osu_multi_lat()