import numpy as np
from mpi4py import MPI

# Rank treated as the root: source of ``bcast`` and the rank for which
# ``is_root_proc`` returns True.
ROOT_PROC_RANK = 0


def get_num_procs():
    """Return the number of processes in ``MPI.COMM_WORLD``."""
    return MPI.COMM_WORLD.Get_size()


def get_proc_rank():
    """Return this process's rank in ``MPI.COMM_WORLD``."""
    return MPI.COMM_WORLD.Get_rank()


def is_root_proc():
    """Return True if this process's rank equals ``ROOT_PROC_RANK``."""
    return get_proc_rank() == ROOT_PROC_RANK


def bcast(x):
    """Broadcast ``x`` in place from ``ROOT_PROC_RANK`` to every process.

    Uses mpi4py's buffer-based ``Bcast``, so ``x`` must be a writable
    buffer (e.g. a NumPy array). It presumably needs matching shape and
    dtype on every rank. On non-root ranks its contents are overwritten
    with the root's values.

    Returns:
        None; the result is delivered through ``x`` itself.
    """
    MPI.COMM_WORLD.Bcast(x, root=ROOT_PROC_RANK)


def reduce_sum(x):
    """Return the element-wise sum of ``x`` over all processes."""
    return reduce_all(x, MPI.SUM)


def reduce_prod(x):
    """Return the element-wise product of ``x`` over all processes."""
    return reduce_all(x, MPI.PROD)


def reduce_avg(x):
    """Return the element-wise mean of ``x`` over all processes.

    Integer inputs produce a floating-point result.
    """
    total = reduce_sum(x)
    # Divide out of place: ``total /= n`` on an integer ndarray raises a
    # casting error because the float quotient cannot be written back
    # into the integer buffer.
    return total / get_num_procs()


def reduce_min(x):
    """Return the element-wise minimum of ``x`` over all processes."""
    return reduce_all(x, MPI.MIN)


def reduce_max(x):
    """Return the element-wise maximum of ``x`` over all processes."""
    return reduce_all(x, MPI.MAX)


def reduce_all(x, op):
    """Combine ``x`` across all processes with ``op`` (MPI Allreduce).

    Every rank receives the same combined result.

    Args:
        x: A NumPy array, or a scalar that is wrapped in a 1-element array
            for the reduction.
        op: An MPI reduction operation, e.g. ``MPI.SUM`` or ``MPI.MAX``.

    Returns:
        A new array shaped like ``x`` when ``x`` is an ndarray, otherwise
        the single reduced element as a NumPy scalar.
    """
    is_array = isinstance(x, np.ndarray)
    if is_array:
        # mpi4py's buffer interface rejects non-contiguous memory (e.g. a
        # column slice); copy only in that case.
        contiguous = x.flags.c_contiguous or x.flags.f_contiguous
        send_buf = x if contiguous else x.copy()
    else:
        send_buf = np.array([x])
    recv_buf = np.zeros_like(send_buf)
    MPI.COMM_WORLD.Allreduce(send_buf, recv_buf, op=op)
    return recv_buf if is_array else recv_buf[0]


def gather_all(x):
    """Collect ``x`` from every process onto every process (MPI Allgather).

    Returns:
        A list with one entry per process, in rank order. Each entry is an
        array shaped like ``x`` when ``x`` is an ndarray, otherwise a NumPy
        scalar.
    """
    # Add a leading axis so each rank contributes exactly one "row".
    send_buf = np.array([x])
    recv_buf = np.zeros(
        (get_num_procs(),) + send_buf.shape[1:], dtype=send_buf.dtype
    )
    MPI.COMM_WORLD.Allgather(send_buf, recv_buf)
    return list(recv_buf)