from mpi4py import MPI


def _collective_tag():
    """
    Return the message tag used by the point-to-point collectives below.

    The value is one less than the ``MPI.TAG_UB`` attribute queried on
    ``MPI.COMM_WORLD``.
    """
    # NOTE(review): a tag near the upper bound is presumably chosen to stay
    # clear of tags used by application traffic. It does not isolate these
    # messages from receives posted with MPI.ANY_TAG on the same
    # communicator -- confirm that is acceptable for callers.
    return MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) - 1


class Intracomm(MPI.Intracomm):
    """
    Intracommunicator class with scalable, point-to-point based
    implementations of global reduction operations.

    Every operation must be called by all ranks of the communicator.
    ``op`` is invoked as ``op(a, b)`` on Python objects, with ``a`` always
    covering lower-numbered ranks than ``b``, so operands are combined in
    ascending rank order.
    """

    def __new__(cls, comm=None):
        """Forward ``comm`` to ``MPI.Intracomm.__new__``."""
        return super(Intracomm, cls).__new__(cls, comm)

    def reduce(self, sendobj=None, recvobj=None, op=MPI.SUM, root=0):
        """
        Reduce ``sendobj`` from all ranks onto ``root``.

        A binomial tree gathers the result on rank 0; if ``root`` is not 0,
        rank 0 then forwards it to ``root``.

        sendobj -- this rank's contribution.
        recvobj -- ignored; kept for signature compatibility.
        op      -- binary callable, applied as ``op(lower, higher)``.
        root    -- rank receiving the result; must be in ``[0, size)``.
                   Checked with ``assert`` (skipped under ``python -O``).

        Returns the reduced object on ``root`` and ``None`` elsewhere.
        """
        size = self.size
        rank = self.rank
        assert 0 <= root < size
        tag = _collective_tag()

        recvobj = sendobj
        mask = 1
        while mask < size:
            if rank & mask:
                # This rank's subtree is complete: hand the partial result to
                # its parent (the rank with this bit cleared) and leave the
                # tree. Without the break the rank would keep sending and
                # receiving at higher levels, producing redundant messages
                # and op() calls whose results are discarded.
                self.send(recvobj, dest=rank & ~mask, tag=tag)
                break
            child = rank | mask
            if child < size:
                # The child's subtree covers higher ranks, so it goes on the
                # right to keep operands in rank order.
                tmp = self.recv(None, source=child, tag=tag)
                recvobj = op(recvobj, tmp)
            mask <<= 1

        # The tree is rooted at 0; relay the result if another root was asked.
        if root != 0:
            if rank == 0:
                self.send(recvobj, dest=root, tag=tag)
            elif rank == root:
                recvobj = self.recv(None, source=0, tag=tag)

        if rank != root:
            recvobj = None
        return recvobj

    def allreduce(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Reduce ``sendobj`` across all ranks and return the result everywhere.

        Implemented as ``reduce`` onto rank 0 followed by ``bcast`` from
        rank 0. ``recvobj`` is ignored; kept for signature compatibility.
        """
        result = self.reduce(sendobj, recvobj, op, 0)
        return self.bcast(result, 0)

    def scan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Return the inclusive prefix reduction of ranks ``0 .. rank``.

        Recursive doubling: in each round a rank exchanges its running block
        total with the rank whose index differs in the current bit.
        ``recvobj`` is ignored; kept for signature compatibility.
        """
        size = self.size
        rank = self.rank
        tag = _collective_tag()

        recvobj = sendobj   # prefix ending at this rank
        partial = sendobj   # total over the block this rank has merged into
        mask = 1
        while mask < size:
            peer = rank ^ mask
            if peer < size:
                tmp = self.sendrecv(partial, dest=peer, source=peer,
                                    sendtag=tag, recvtag=tag)
                if rank > peer:
                    # Peer's block precedes ours: prepend it to both.
                    partial = op(tmp, partial)
                    recvobj = op(tmp, recvobj)
                else:
                    # Peer's block follows ours: only the block total grows.
                    partial = op(partial, tmp)
            mask <<= 1

        return recvobj

    def exscan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Return the exclusive prefix reduction of ranks ``0 .. rank-1``.

        Uses the same recursive-doubling exchange as ``scan``. Rank 0 has no
        predecessors and returns ``None``. ``recvobj`` is ignored; kept for
        signature compatibility.
        """
        size = self.size
        rank = self.rank
        tag = _collective_tag()

        # Prefix over lower ranks. It stays None on rank 0, which never has
        # a lower-ranked peer. Every other rank meets one at its lowest set
        # bit, so it is always assigned there.
        recvobj = None
        have_prefix = False
        partial = sendobj
        mask = 1
        while mask < size:
            peer = rank ^ mask
            if peer < size:
                tmp = self.sendrecv(partial, dest=peer, source=peer,
                                    sendtag=tag, recvtag=tag)
                if rank > peer:
                    partial = op(tmp, partial)
                    # First lower block seen is the prefix itself; later
                    # (lower) blocks are prepended.
                    if have_prefix:
                        recvobj = op(tmp, recvobj)
                    else:
                        recvobj = tmp
                        have_prefix = True
                else:
                    partial = op(partial, tmp)
            mask <<= 1

        return recvobj