1import numpy as np
2from mpi4py import MPI
3
# Rank of the process treated as the root: used as the broadcast source in
# bcast() and as the comparison target in is_root_proc().
ROOT_PROC_RANK = 0
5
6
def get_num_procs():
  """Return the size of MPI.COMM_WORLD (the number of participating processes)."""
  world = MPI.COMM_WORLD
  return world.Get_size()
9
10
def get_proc_rank():
  """Return the calling process's rank within MPI.COMM_WORLD."""
  world = MPI.COMM_WORLD
  return world.Get_rank()
13
14
def is_root_proc():
  """Return True when the calling process's rank equals ROOT_PROC_RANK."""
  return get_proc_rank() == ROOT_PROC_RANK
18
19
def bcast(x):
  """Broadcast `x` from the root process over MPI.COMM_WORLD.

  `x` is handed directly to the buffer-based `Bcast`, so it should be a
  buffer-like object (e.g. a NumPy array) that mpi4py can send from on the
  root and receive into on the other ranks.

  Args:
    x: Buffer to broadcast; ROOT_PROC_RANK is used as the source rank.

  Returns:
    None.
  """
  world = MPI.COMM_WORLD
  world.Bcast(x, root=ROOT_PROC_RANK)
23
24
def reduce_sum(x):
  """All-reduce `x` with MPI.SUM via reduce_all and return the result."""
  summed = reduce_all(x, op=MPI.SUM)
  return summed
27
28
def reduce_prod(x):
  """All-reduce `x` with MPI.PROD via reduce_all and return the result."""
  product = reduce_all(x, op=MPI.PROD)
  return product
31
32
def reduce_avg(x):
  """Return the average of `x` across all processes.

  The value is summed with reduce_sum and then divided by the number of
  processes reported by get_num_procs.

  Args:
    x: A scalar or NumPy array, as accepted by reduce_sum.

  Returns:
    The sum across processes divided by the process count. Integer inputs
    produce a floating-point result.
  """
  total = reduce_sum(x)
  # Divide out-of-place: an in-place `/=` on an integer array raises, because
  # the float result of true division cannot be cast back to the integer dtype.
  return total / get_num_procs()
37
38
def reduce_min(x):
  """All-reduce `x` with MPI.MIN via reduce_all and return the result."""
  smallest = reduce_all(x, op=MPI.MIN)
  return smallest
41
42
def reduce_max(x):
  """All-reduce `x` with MPI.MAX via reduce_all and return the result."""
  largest = reduce_all(x, op=MPI.MAX)
  return largest
45
46
def reduce_all(x, op):
  """Combine `x` across all processes with `op` using Allreduce.

  Args:
    x: A NumPy array, or a scalar (anything that is not an ndarray), which
      is wrapped in a one-element array for the call.
    op: The MPI reduction operation (e.g. MPI.SUM) passed to Allreduce.

  Returns:
    A new array with the reduced values when `x` is an ndarray; otherwise
    the single reduced element taken from the one-element result array.
  """
  is_array = isinstance(x, np.ndarray)
  if is_array:
    send_buf = x
    # Allreduce needs a contiguous buffer; copy only when the input is a
    # strided view. 0-d and already-contiguous arrays are passed unchanged.
    if not (x.flags.c_contiguous or x.flags.f_contiguous):
      send_buf = np.ascontiguousarray(x)
  else:
    send_buf = np.array([x])

  recv_buf = np.zeros_like(send_buf)
  MPI.COMM_WORLD.Allreduce(send_buf, recv_buf, op=op)
  return recv_buf if is_array else recv_buf[0]
54
55
def gather_all(x):
  """Gather `x` from every process with Allgather.

  `x` is wrapped in a leading axis of length one, and the receive buffer
  gets one slot along that axis per process.

  Args:
    x: A scalar or array-like value; assumed to have the same shape and
      dtype on every process, since each one sizes its receive buffer from
      its own `x`.

  Returns:
    A list with get_num_procs() entries, obtained by iterating the receive
    buffer along its first axis.
  """
  send_buf = np.array([x])
  # One row per process, each shaped like the local contribution.
  recv_shape = (get_num_procs(),) + send_buf.shape[1:]
  recv_buf = np.zeros(recv_shape, dtype=send_buf.dtype)
  MPI.COMM_WORLD.Allgather(send_buf, recv_buf)
  return list(recv_buf)
64