1from mpi4py import MPI
2
class Intracomm(MPI.Intracomm):
    """
    Intracommunicator class with scalable, point-to-point based
    implementations of global reduction operations.

    Every collective here is built only from ``send``/``recv``/``sendrecv``
    on this communicator and combines values by calling ``op(a, b)``.
    The ``recvobj`` argument of each method is accepted for signature
    compatibility and ignored.

    All internal messages use the tag ``TAG_UB - 1``.  User point-to-point
    traffic on the same communicator that uses that tag while one of these
    collectives is running could be matched by mistake.
    """

    def __new__(cls, comm=None):
        # Construction is delegated unchanged to MPI.Intracomm.__new__.
        return super(Intracomm, cls).__new__(cls, comm)

    @staticmethod
    def _tag():
        """Return the message tag reserved for these collectives."""
        # One below the largest valid tag, to stay clear of typical user tags.
        return MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) - 1

    def reduce(self, sendobj=None, recvobj=None, op=MPI.SUM, root=0):
        """
        Reduce *sendobj* from all processes onto *root* (binomial tree).

        :param sendobj: this process's contribution.
        :param recvobj: ignored (signature compatibility only).
        :param op: two-argument combining callable.  Operands are combined
            with lower ranks on the left, so the tree does not reorder them.
        :param root: rank that receives the result.
        :returns: the reduced value on *root*, ``None`` on every other rank.
        :raises ValueError: if *root* is not a valid rank of this communicator.
        """
        size = self.size
        rank = self.rank
        # A real check rather than ``assert`` so it survives ``python -O``.
        if not 0 <= root < size:
            raise ValueError("root %d out of range for communicator of size %d"
                             % (root, size))
        tag = self._tag()

        result = sendobj
        mask = 1
        while mask < size:
            if rank & mask:
                # Lowest set bit of rank: hand the accumulated subtree value
                # to the parent and leave the tree; continuing would only
                # produce redundant messages and wasted ``op`` calls.
                self.send(result, dest=rank & ~mask, tag=tag)
                break
            child = rank | mask
            if child < size:
                tmp = self.recv(None, source=child, tag=tag)
                # The child's subtree covers higher ranks, so it goes right.
                result = op(result, tmp)
            mask <<= 1

        # The tree is rooted at rank 0; forward to the requested root.
        if root != 0:
            if rank == 0:
                self.send(result, dest=root, tag=tag)
            elif rank == root:
                result = self.recv(None, source=0, tag=tag)

        return result if rank == root else None

    def allreduce(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Combine *sendobj* from all processes and return it on every rank.

        Implemented as ``reduce`` onto rank 0 followed by ``bcast`` from rank 0.
        *recvobj* is ignored (signature compatibility only).
        """
        result = self.reduce(sendobj, recvobj, op, 0)
        return self.bcast(result, 0)

    def scan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Inclusive prefix reduction (recursive doubling).

        Rank ``r`` returns ``op`` applied over the contributions of ranks
        ``0..r``, lower ranks on the left.  *recvobj* is ignored.
        """
        size = self.size
        rank = self.rank
        tag = self._tag()

        result = sendobj   # prefix over ranks <= rank gathered so far
        partial = sendobj  # combination of every value this rank has seen
        mask = 1
        while mask < size:
            partner = rank ^ mask
            if partner < size:
                tmp = self.sendrecv(partial, dest=partner, source=partner,
                                    sendtag=tag, recvtag=tag)
                if rank > partner:
                    # Partner's values come from lower ranks: prepend them.
                    partial = op(tmp, partial)
                    result = op(tmp, result)
                else:
                    # Partner's values come from higher ranks: they extend
                    # the running total but not this rank's prefix.
                    partial = op(partial, tmp)
            mask <<= 1

        return result

    def exscan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Exclusive prefix reduction (recursive doubling).

        Rank ``r > 0`` returns ``op`` applied over the contributions of
        ranks ``0..r-1``; rank 0 returns ``None``.  *recvobj* is ignored.
        """
        size = self.size
        rank = self.rank
        tag = self._tag()

        # Rank 0 never takes the ``rank > partner`` branch, so it keeps None.
        result = None
        have_result = False
        partial = sendobj  # combination of every value this rank has seen
        mask = 1
        while mask < size:
            partner = rank ^ mask
            if partner < size:
                tmp = self.sendrecv(partial, dest=partner, source=partner,
                                    sendtag=tag, recvtag=tag)
                if rank > partner:
                    partial = op(tmp, partial)
                    # The exclusive prefix never includes this rank's own
                    # value, so the first lower block received seeds it.
                    if have_result:
                        result = op(tmp, result)
                    else:
                        result = tmp
                        have_result = True
                else:
                    partial = op(partial, tmp)
            mask <<= 1

        return result
104