1from mpi4py import MPI
2import mpiunittest as unittest
3import arrayimpl
4
5
class BaseTestP2PBuf(object):
    """Point-to-point tests that pass array buffers via ``as_mpi()``.

    Concrete subclasses mix this class with ``unittest.TestCase`` and set
    ``COMM`` to the communicator under test.  Each data-carrying test loops
    over the ``(array, typecode)`` pairs yielded by ``arrayimpl.subTest``.
    """

    # Communicator under test; overridden by the concrete subclasses.
    COMM = MPI.COMM_NULL

    def testSendrecv(self):
        """Ring exchange with ``Sendrecv``.

        Each rank sends ``s`` copies of ``s`` to its successor and receives
        from its predecessor into a buffer one element longer; that extra
        trailing element must keep its initial value ``-1``.
        """
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        dest = (rank + 1) % size
        source = (rank - 1) % size
        for array, typecode in arrayimpl.subTest(self):
            for s in range(0, size):
                sbuf = array( s, typecode, s)
                rbuf = array(-1, typecode, s+1)
                self.COMM.Sendrecv(sbuf.as_mpi(), dest,   0,
                                   rbuf.as_mpi(), source, 0)
                for value in rbuf[:-1]:
                    self.assertEqual(value, s)
                self.assertEqual(rbuf[-1], -1)

    def testSendrecvReplace(self):
        """Ring exchange with ``Sendrecv_replace``.

        Each rank fills a buffer with its own rank and exchanges it in
        place; afterwards the buffer must hold the predecessor's rank.
        """
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        dest = (rank + 1) % size
        source = (rank - 1) % size
        for array, typecode in arrayimpl.subTest(self):
            for s in range(0, size):
                buf = array(rank, typecode, s)
                self.COMM.Sendrecv_replace(buf.as_mpi(), dest, 0, source, 0)
                for value in buf:
                    self.assertEqual(value, source)

    def testSendRecv(self):
        """Blocking and nonblocking send modes matched by ``Recv``.

        Ranks 0 and 1 exchange four messages in each direction using
        buffered (``Ibsend``/``Bsend``), standard (``Send``) and synchronous
        (``Ssend``) sends.  Other ranks, and a one-process communicator,
        only check their own send buffer.  Every rank then sends to itself
        in ready mode (``Rsend``/``Irsend``).
        """
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        for array, typecode in arrayimpl.subTest(self):
            if unittest.is_mpi_gpu('openmpi', array):
                continue
            if unittest.is_mpi_gpu('mvapich2', array):
                continue
            for s in range(0, size):
                sbuf = array( s, typecode, s)
                rbuf = array(-1, typecode, s)
                # Room for the two buffered sends (Ibsend + Bsend) that are
                # outstanding before Detach_buffer() is called.
                mem = array( 0, typecode, 2*(s+MPI.BSEND_OVERHEAD)).as_raw()
                if size == 1:
                    # No peer available: only exercise attach/detach.
                    MPI.Attach_buffer(mem)
                    rbuf = sbuf
                    MPI.Detach_buffer()
                elif rank == 0:
                    MPI.Attach_buffer(mem)
                    self.COMM.Ibsend(sbuf.as_mpi(), 1, 0).Wait()
                    self.COMM.Bsend(sbuf.as_mpi(), 1, 0)
                    MPI.Detach_buffer()
                    self.COMM.Send(sbuf.as_mpi(), 1, 0)
                    self.COMM.Ssend(sbuf.as_mpi(), 1, 0)
                    self.COMM.Recv(rbuf.as_mpi(), 1, 0)
                    self.COMM.Recv(rbuf.as_mpi(), 1, 0)
                    self.COMM.Recv(rbuf.as_mpi(), 1, 0)
                    self.COMM.Recv(rbuf.as_mpi(), 1, 0)
                elif rank == 1:
                    # Receive rank 0's four messages first, then mirror them.
                    self.COMM.Recv(rbuf.as_mpi(), 0, 0)
                    self.COMM.Recv(rbuf.as_mpi(), 0, 0)
                    self.COMM.Recv(rbuf.as_mpi(), 0, 0)
                    self.COMM.Recv(rbuf.as_mpi(), 0, 0)
                    MPI.Attach_buffer(mem)
                    self.COMM.Ibsend(sbuf.as_mpi(), 0, 0).Wait()
                    self.COMM.Bsend(sbuf.as_mpi(), 0, 0)
                    MPI.Detach_buffer()
                    self.COMM.Send(sbuf.as_mpi(), 0, 0)
                    self.COMM.Ssend(sbuf.as_mpi(), 0, 0)
                else:
                    rbuf = sbuf
                for value in rbuf:
                    self.assertEqual(value, s)
                # Ready mode requires the matching receive to be posted
                # before the send starts, hence Irecv() comes first.
                sbuf = array( s, typecode, s)
                rbuf = array(-1, typecode, s)
                rreq = self.COMM.Irecv(rbuf.as_mpi(), rank, 0)
                self.COMM.Rsend(sbuf.as_mpi(), rank, 0)
                rreq.Wait()
                for value in rbuf:
                    self.assertEqual(value, s)
                rbuf = array(-1, typecode, s)
                rreq = self.COMM.Irecv(rbuf.as_mpi(), rank, 0)
                self.COMM.Irsend(sbuf.as_mpi(), rank, 0).Wait()
                rreq.Wait()
                for value in rbuf:
                    self.assertEqual(value, s)

    def testProcNull(self):
        """Every blocking/nonblocking send and receive accepts PROC_NULL."""
        comm = self.COMM
        #
        comm.Sendrecv(None, MPI.PROC_NULL, 0,
                      None, MPI.PROC_NULL, 0)
        comm.Sendrecv_replace(None,
                              MPI.PROC_NULL, 0,
                              MPI.PROC_NULL, 0)
        #
        comm.Send(None, MPI.PROC_NULL)
        comm.Isend(None, MPI.PROC_NULL).Wait()
        #
        comm.Ssend(None, MPI.PROC_NULL)
        comm.Issend(None, MPI.PROC_NULL).Wait()
        #
        # Only one send buffer may be attached at a time, so always detach
        # it (and release the MPI-allocated memory) even if a send fails;
        # otherwise later tests could not attach their own buffer.
        buf = MPI.Alloc_mem(MPI.BSEND_OVERHEAD)
        try:
            MPI.Attach_buffer(buf)
            try:
                comm.Bsend(None, MPI.PROC_NULL)
                comm.Ibsend(None, MPI.PROC_NULL).Wait()
            finally:
                MPI.Detach_buffer()
        finally:
            MPI.Free_mem(buf)
        #
        comm.Rsend(None, MPI.PROC_NULL)
        comm.Irsend(None, MPI.PROC_NULL).Wait()
        #
        comm.Recv(None, MPI.PROC_NULL)
        comm.Irecv(None, MPI.PROC_NULL).Wait()

    @unittest.skipMPI('mpich(==3.4.1)')
    def testProcNullPersistent(self):
        """Every persistent send/receive request accepts PROC_NULL."""
        comm = self.COMM
        #
        req = comm.Send_init(None, MPI.PROC_NULL)
        req.Start(); req.Wait(); req.Free()
        #
        req = comm.Ssend_init(None, MPI.PROC_NULL)
        req.Start(); req.Wait(); req.Free()
        #
        # Detach and free the buffer even on failure so that a single
        # failing send does not break Attach_buffer() in later tests.
        buf = MPI.Alloc_mem(MPI.BSEND_OVERHEAD)
        try:
            MPI.Attach_buffer(buf)
            try:
                req = comm.Bsend_init(None, MPI.PROC_NULL)
                req.Start(); req.Wait(); req.Free()
            finally:
                MPI.Detach_buffer()
        finally:
            MPI.Free_mem(buf)
        #
        req = comm.Rsend_init(None, MPI.PROC_NULL)
        req.Start(); req.Wait(); req.Free()
        #
        req = comm.Recv_init(None, MPI.PROC_NULL)
        req.Start(); req.Wait(); req.Free()

    def testPersistent(self):
        """Persistent requests in every send mode.

        For each message size ``s`` and receive-buffer slack ``xs``, the
        test exchanges around the ring with ``Send_init``, with
        ``Startall`` + ``Waitany``, with ``Ssend_init`` and with
        ``Bsend_init``, and then sends to itself with ``Rsend_init``.  A
        persistent request must stay non-null after completing and become
        ``REQUEST_NULL`` only once freed.  The ``xs`` trailing receive
        elements must keep their initial ``-1``.
        """
        size = self.COMM.Get_size()
        rank = self.COMM.Get_rank()
        dest = (rank + 1) % size
        source = (rank - 1) % size
        for array, typecode in arrayimpl.subTest(self):
            if unittest.is_mpi_gpu('openmpi', array):
                continue
            if unittest.is_mpi_gpu('mvapich2', array):
                continue
            for s in range(size):
                for xs in range(3):
                    # Standard mode, requests started/waited individually.
                    sbuf = array( s, typecode, s)
                    rbuf = array(-1, typecode, s+xs)
                    sendreq = self.COMM.Send_init(sbuf.as_mpi(), dest, 0)
                    recvreq = self.COMM.Recv_init(rbuf.as_mpi(), source, 0)
                    sendreq.Start()
                    recvreq.Start()
                    sendreq.Wait()
                    recvreq.Wait()
                    self.assertNotEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertNotEqual(recvreq, MPI.REQUEST_NULL)
                    sendreq.Free()
                    recvreq.Free()
                    self.assertEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertEqual(recvreq, MPI.REQUEST_NULL)
                    for value in rbuf[:s]:
                        self.assertEqual(value, s)
                    for value in rbuf[s:]:
                        self.assertEqual(value, -1)
                    # Standard mode via Startall/Waitany; once both requests
                    # have completed, a third Waitany reports UNDEFINED.
                    sbuf = array(s,  typecode, s)
                    rbuf = array(-1, typecode, s+xs)
                    sendreq = self.COMM.Send_init(sbuf.as_mpi(), dest, 0)
                    recvreq = self.COMM.Recv_init(rbuf.as_mpi(), source, 0)
                    reqlist = [sendreq, recvreq]
                    MPI.Prequest.Startall(reqlist)
                    index1 = MPI.Prequest.Waitany(reqlist)
                    self.assertIn(index1, [0, 1])
                    self.assertNotEqual(reqlist[index1], MPI.REQUEST_NULL)
                    index2 = MPI.Prequest.Waitany(reqlist)
                    self.assertIn(index2, [0, 1])
                    self.assertNotEqual(reqlist[index2], MPI.REQUEST_NULL)
                    self.assertNotEqual(index1, index2)
                    index3 = MPI.Prequest.Waitany(reqlist)
                    self.assertEqual(index3, MPI.UNDEFINED)
                    for preq in reqlist:
                        self.assertNotEqual(preq, MPI.REQUEST_NULL)
                        preq.Free()
                        self.assertEqual(preq, MPI.REQUEST_NULL)
                    for value in rbuf[:s]:
                        self.assertEqual(value, s)
                    for value in rbuf[s:]:
                        self.assertEqual(value, -1)
                    # Synchronous mode.
                    sbuf = array( s, typecode, s)
                    rbuf = array(-1, typecode, s+xs)
                    sendreq = self.COMM.Ssend_init(sbuf.as_mpi(), dest, 0)
                    recvreq = self.COMM.Recv_init(rbuf.as_mpi(), source, 0)
                    sendreq.Start()
                    recvreq.Start()
                    sendreq.Wait()
                    recvreq.Wait()
                    self.assertNotEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertNotEqual(recvreq, MPI.REQUEST_NULL)
                    sendreq.Free()
                    recvreq.Free()
                    self.assertEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertEqual(recvreq, MPI.REQUEST_NULL)
                    for value in rbuf[:s]:
                        self.assertEqual(value, s)
                    for value in rbuf[s:]:
                        self.assertEqual(value, -1)
                    # Buffered mode; the attached buffer holds one message.
                    mem = array( 0, typecode, s+MPI.BSEND_OVERHEAD).as_raw()
                    sbuf = array( s, typecode, s)
                    rbuf = array(-1, typecode, s+xs)
                    MPI.Attach_buffer(mem)
                    sendreq = self.COMM.Bsend_init(sbuf.as_mpi(), dest, 0)
                    recvreq = self.COMM.Recv_init(rbuf.as_mpi(), source, 0)
                    sendreq.Start()
                    recvreq.Start()
                    sendreq.Wait()
                    recvreq.Wait()
                    MPI.Detach_buffer()
                    self.assertNotEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertNotEqual(recvreq, MPI.REQUEST_NULL)
                    sendreq.Free()
                    recvreq.Free()
                    self.assertEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertEqual(recvreq, MPI.REQUEST_NULL)
                    for value in rbuf[:s]:
                        self.assertEqual(value, s)
                    for value in rbuf[s:]:
                        self.assertEqual(value, -1)
                    # Ready mode to self: the receive is started before the
                    # send because ready mode needs a posted matching receive.
                    sbuf = array( s, typecode, s)
                    rbuf = array(-1, typecode, s+xs)
                    recvreq = self.COMM.Recv_init(rbuf.as_mpi(), rank, 0)
                    sendreq = self.COMM.Rsend_init(sbuf.as_mpi(), rank, 0)
                    recvreq.Start()
                    sendreq.Start()
                    recvreq.Wait()
                    sendreq.Wait()
                    self.assertNotEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertNotEqual(recvreq, MPI.REQUEST_NULL)
                    sendreq.Free()
                    recvreq.Free()
                    self.assertEqual(sendreq, MPI.REQUEST_NULL)
                    self.assertEqual(recvreq, MPI.REQUEST_NULL)
                    for value in rbuf[:s]:
                        self.assertEqual(value, s)
                    for value in rbuf[s:]:
                        self.assertEqual(value, -1)

    def testProbe(self):
        """``Probe`` sees a pending synchronous send to self without matching it."""
        comm = self.COMM.Dup()
        try:
            request = comm.Issend([None, 0, MPI.BYTE], comm.rank, 123)
            self.assertTrue(request)
            status = MPI.Status()
            comm.Probe(MPI.ANY_SOURCE, MPI.ANY_TAG, status)
            self.assertEqual(status.source, comm.rank)
            self.assertEqual(status.tag, 123)
            self.assertTrue(request)
            # A synchronous send cannot complete before a receive matches
            # it, and Probe() does not receive.
            flag = request.Test()
            self.assertTrue(request)
            self.assertFalse(flag)
            comm.Recv([None, 0, MPI.BYTE], comm.rank, 123)
            self.assertTrue(request)
            # Spin until the send completes; the request then turns falsy.
            flag = False
            while not flag:
                flag = request.Test()
            self.assertFalse(request)
            self.assertTrue(flag)
        finally:
            comm.Free()

    @unittest.skipMPI('MPICH1')
    @unittest.skipMPI('LAM/MPI')
    def testProbeCancel(self):
        """Cancel a probed but unreceived synchronous send to self."""
        comm = self.COMM.Dup()
        try:
            request = comm.Issend([None, 0, MPI.BYTE], comm.rank, 123)
            status = MPI.Status()
            comm.Probe(MPI.ANY_SOURCE, MPI.ANY_TAG, status)
            self.assertEqual(status.source, comm.rank)
            self.assertEqual(status.tag, 123)
            request.Cancel()
            self.assertTrue(request)
            status = MPI.Status()
            request.Get_status(status)
            cancelled = status.Is_cancelled()
            if not cancelled:
                # Cancellation did not take effect: drain the message so
                # the send can complete.
                comm.Recv([None, 0, MPI.BYTE], comm.rank, 123)
                request.Wait()
            else:
                request.Free()
        finally:
            comm.Free()

    def testIProbe(self):
        """``Iprobe`` on a fresh duplicate finds nothing and leaves status wildcards."""
        comm = self.COMM.Dup()
        try:
            f = comm.Iprobe()
            self.assertFalse(f)
            f = comm.Iprobe(MPI.ANY_SOURCE)
            self.assertFalse(f)
            f = comm.Iprobe(MPI.ANY_SOURCE, MPI.ANY_TAG)
            self.assertFalse(f)
            status = MPI.Status()
            f = comm.Iprobe(MPI.ANY_SOURCE, MPI.ANY_TAG, status)
            self.assertFalse(f)
            self.assertEqual(status.source, MPI.ANY_SOURCE)
            self.assertEqual(status.tag,    MPI.ANY_TAG)
            self.assertEqual(status.error,  MPI.SUCCESS)
        finally:
            comm.Free()
323
324
class TestP2PBufSelf(BaseTestP2PBuf, unittest.TestCase):
    """Run the buffer point-to-point tests on ``MPI.COMM_SELF``."""

    COMM = MPI.COMM_SELF
327
class TestP2PBufWorld(BaseTestP2PBuf, unittest.TestCase):
    """Run the buffer point-to-point tests on ``MPI.COMM_WORLD``."""

    COMM = MPI.COMM_WORLD
330
class TestP2PBufSelfDup(TestP2PBufSelf):
    """``TestP2PBufSelf`` run on a per-test duplicate of ``COMM_SELF``."""

    def setUp(self):
        # The instance attribute shadows the class-level COMM.
        dup = MPI.COMM_SELF.Dup()
        self.COMM = dup

    def tearDown(self):
        # Release the duplicate created in setUp().
        self.COMM.Free()
336
# NOTE(review): the second skipMPI argument appears to be an extra skip
# condition (thread level above THREAD_SINGLE) for Open MPI < 1.4.0.
@unittest.skipMPI('openmpi(<1.4.0)', MPI.Query_thread() > MPI.THREAD_SINGLE)
class TestP2PBufWorldDup(TestP2PBufWorld):
    """``TestP2PBufWorld`` run on a per-test duplicate of ``COMM_WORLD``."""

    def setUp(self):
        # The instance attribute shadows the class-level COMM.
        dup = MPI.COMM_WORLD.Dup()
        self.COMM = dup

    def tearDown(self):
        # Release the duplicate created in setUp().
        self.COMM.Free()
343
344
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
347