1from mpi4py import MPI
2import mpiunittest as unittest
3from arrayimpl import allclose
4from arrayimpl import typestr
5import sys
6
# Mapping from array typecode characters (e.g. 'i', 'd') to MPI datatypes;
# the checks below use it to pass explicit datatypes in message specs.
typemap = MPI._typedict
8
9try:
10    import array
11except ImportError:
12    array = None
13try:
14    import numpy
15except ImportError:
16    numpy = None
17try:
18    import cupy
19except ImportError:
20    cupy = None
def version_tuple(text):
    """Return the leading ``(major, minor, patch)`` integers of *text*.

    Missing components are reported as ``0`` and any trailing
    pre-release, dev or local suffix is ignored, so ``'0.57.0rc1'`` parses
    as ``(0, 57, 0)``.  A string without a leading number yields
    ``(0, 0, 0)``.  This replaces ``distutils.version.StrictVersion``,
    which no longer exists on Python 3.12 and rejects such suffixes with
    ``ValueError``.
    """
    import re
    match = re.match(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?', text)
    if match is None:
        return (0, 0, 0)
    return tuple(int(g) if g is not None else 0 for g in match.groups())


try:
    import numba
    import numba.cuda
    numba_version = version_tuple(numba.__version__)
    if numba_version < (0, 48):
        import warnings
        warnings.warn('To test Numba GPU arrays, use Numba v0.48.0+.',
                      RuntimeWarning)
        numba = None
except ImportError:
    numba = None
33
34
# Interpreter flags consulted by the skip decorators in this module.
py2 = sys.version_info[:1] == (2,)
py3 = sys.version_info >= (3,)
pypy = getattr(sys, 'pypy_version_info', None) is not None
pypy2 = py2 if pypy else False
pypy_lt_53 = sys.pypy_version_info < (5, 3) if pypy else False
40
41
42# ---
43
class BaseBuf(object):
    """Minimal sequence wrapper around an ``array.array`` buffer.

    Subclasses in this module expose the wrapped ``_buf`` through a
    buffer protocol (DLPack, CUDA array interface) and are used by the
    tests as message buffers.
    """

    def __init__(self, typecode, initializer):
        self._buf = array.array(typecode, initializer)

    def __eq__(self, other):
        # Let Python fall back to its default for foreign operands
        # instead of failing on a missing ``_buf`` attribute.
        if not isinstance(other, BaseBuf):
            return NotImplemented
        return self._buf == other._buf

    def __ne__(self, other):
        if not isinstance(other, BaseBuf):
            return NotImplemented
        return self._buf != other._buf

    def __len__(self):
        return len(self._buf)

    def __getitem__(self, item):
        return self._buf[item]

    def __setitem__(self, item, value):
        # Accept another BaseBuf (as in ``r[:] = z``) or anything the
        # wrapped buffer accepts directly (a scalar, an array.array).
        if isinstance(value, BaseBuf):
            value = value._buf
        self._buf[item] = value
63
64# ---
65
66try:
67    import dlpackimpl as dlpack
68except ImportError:
69    dlpack = None
70
class DLPackCPUBuf(BaseBuf):
    """Host ``array.array`` buffer exported through the DLPack protocol.

    The managed tensor comes from the ``dlpackimpl`` test helper; tests
    mutate ``self.managed.dl_tensor`` to drive error paths.
    """

    def __init__(self, typecode, initializer):
        super(DLPackCPUBuf, self).__init__(typecode, initializer)
        self.managed = dlpack.make_dl_managed_tensor(self._buf)

    def __del__(self):
        self.managed = None
        # ``_buf`` is missing if BaseBuf.__init__ raised.
        if not hasattr(self, '_buf'):
            return
        # After dropping the managed tensor, only this attribute and
        # getrefcount's own argument should reference the array.
        if not pypy and sys.getrefcount(self._buf) > 2:
            raise RuntimeError('dlpack: possible reference leak')

    def __dlpack_device__(self):
        device = self.managed.dl_tensor.device
        return (device.device_type, device.device_id)

    def __dlpack__(self, stream=None):
        managed = self.managed
        if managed.dl_tensor.device.device_type == \
           dlpack.DLDeviceType.kDLCPU:
            # Host memory has no stream; consumers must not pass one.
            assert stream is None
        capsule = dlpack.make_py_capsule(managed)
        return capsule
93
94
if cupy is not None:

    class DLPackGPUBuf(BaseBuf):
        """CuPy device array exported through the DLPack protocol.

        Uses CuPy's own ``__dlpack__``/``__dlpack_device__`` when the
        installed CuPy provides them. Otherwise it falls back to
        ``toDlpack()`` and a device tuple built from the runtime (CUDA or
        ROCm).
        """

        has_dlpack = None
        dev_type = None

        def __init__(self, typecode, initializer):
            # BaseBuf.__init__ is not called: the payload lives on the
            # device instead of in an array.array.
            self._buf = cupy.array(initializer, dtype=typecode)
            self.has_dlpack = hasattr(self._buf, '__dlpack_device__')
            # TODO(leofang): test CUDA managed memory?
            if cupy.cuda.runtime.is_hip:
                self.dev_type = dlpack.DLDeviceType.kDLROCM
            else:
                self.dev_type = dlpack.DLDeviceType.kDLCUDA

        def __del__(self):
            # ``_buf`` is missing if __init__ failed before assigning it.
            if not hasattr(self, '_buf'):
                return
            # Only this attribute and getrefcount's argument should remain.
            if not pypy and sys.getrefcount(self._buf) > 2:
                raise RuntimeError('dlpack: possible reference leak')

        def __dlpack_device__(self):
            if self.has_dlpack:
                return self._buf.__dlpack_device__()
            else:
                return (self.dev_type, self._buf.device.id)

        def __dlpack__(self, stream=None):
            # Synchronize here, then pass stream=-1, which DLPack defines
            # as "the producer must not synchronize"; the consumer's
            # ``stream`` argument is therefore not used.
            cupy.cuda.get_current_stream().synchronize()
            if self.has_dlpack:
                return self._buf.__dlpack__(stream=-1)
            else:
                return self._buf.toDlpack()
127
128
129# ---
130
class CAIBuf(BaseBuf):
    """Host array exposed through a hand-built ``__cuda_array_interface__``.

    The interface dict points at the host address of the wrapped
    ``array.array``. Its shape is ``(n, 1, 1)`` and every stride equals
    the item size. Tests mutate or replace entries of this dict to
    exercise the interface validation.
    """

    def __init__(self, typecode, initializer, readonly=False):
        super(CAIBuf, self).__init__(typecode, initializer)
        host = self._buf
        data_ptr, length = host.buffer_info()
        item_typestr = typestr(host.typecode, host.itemsize)
        self.__cuda_array_interface__ = {
            'version': 0,
            'data': (data_ptr, readonly),
            'typestr': item_typestr,
            'shape': (length, 1, 1),
            'strides': (host.itemsize,) * 3,
            'descr': [('', item_typestr)],
        }
146
147
# Probe for CuPy issue #2259: when present, the 'strides' entry in the
# __cuda_array_interface__ of a transposed array is not a tuple, and the
# Fortran-order / non-contiguous CuPy tests below are skipped.
if cupy is None:
    cupy_issue_2259 = False
else:
    _transposed = cupy.zeros((2,2)).T
    cupy_issue_2259 = not isinstance(
        _transposed.__cuda_array_interface__['strides'], tuple)
    del _transposed
154
155# ---
156
def Sendrecv(smsg, rmsg):
    """Exchange *smsg* -> *rmsg* with this same process.

    Uses ``MPI.COMM_SELF`` with rank 0 and tag 0 on both sides, so each
    test exercises how mpi4py interprets the two message specifications.
    """
    comm = MPI.COMM_SELF
    comm.Sendrecv(
        sendbuf=smsg, dest=0, sendtag=0,
        recvbuf=rmsg, source=0, recvtag=0,
        status=MPI.Status(),
    )
161
162
class TestMessageSimple(unittest.TestCase):
    """Message-specification parsing for point-to-point ``Sendrecv``."""

    def testMessageBad(self):
        # Malformed specs over raw MPI memory.  The memory and the resized
        # datatype are released in ``finally`` blocks so that a failing
        # assertion does not leak them.
        buf = MPI.Alloc_mem(5)
        empty = [None, 0, "B"]
        try:
            self.assertRaises(ValueError, Sendrecv,
                              [buf, 0, 0, "i", None], empty)
            self.assertRaises(KeyError, Sendrecv,
                              [buf,  0, "\0"], empty)
            self.assertRaises(ValueError, Sendrecv,
                              [buf, -1, "i"], empty)
            self.assertRaises(ValueError, Sendrecv,
                              [buf, 0, -1, "i"], empty)
            self.assertRaises(ValueError, Sendrecv,
                              [buf, 0, +2, "i"], empty)
            self.assertRaises(ValueError, Sendrecv,
                              [None, 1,  0, "i"], empty)
            self.assertRaises(ValueError, Sendrecv,
                              [buf, None,  0, "i"], empty)
            self.assertRaises(ValueError, Sendrecv,
                              [buf, 0, 1, MPI.DATATYPE_NULL], empty)
            self.assertRaises(ValueError, Sendrecv,
                              [buf, None, 0, MPI.DATATYPE_NULL], empty)
            try:
                # Datatype with a negative extent; this part is skipped
                # when it raises NotImplementedError.
                t = MPI.INT.Create_resized(0, -4).Commit()
                try:
                    self.assertRaises(ValueError, Sendrecv,
                                      [buf, None, t], empty)
                    self.assertRaises(ValueError, Sendrecv,
                                      [buf, 0, 1, t], empty)
                finally:
                    t.Free()
            except NotImplementedError:
                pass
        finally:
            MPI.Free_mem(buf)
        # Python containers that do not expose a buffer.
        buf = [1,2,3,4]
        self.assertRaises(TypeError, Sendrecv, [buf, 4,  0, "i"], empty)
        buf = {1:2,3:4}
        self.assertRaises(TypeError, Sendrecv, [buf, 4,  0, "i"], empty)
        # Immutable bytes used as the receive buffer.
        self.assertRaises((BufferError, TypeError, ValueError),
                          Sendrecv, b"abc", b"abc")

    def testMessageNone(self):
        empty = [None, 0, "B"]
        Sendrecv(empty, empty)
        empty = [None, "B"]
        Sendrecv(empty, empty)

    def testMessageBottom(self):
        empty = [MPI.BOTTOM, 0, "B"]
        Sendrecv(empty, empty)
        empty = [MPI.BOTTOM, "B"]
        Sendrecv(empty, empty)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytes(self):
        sbuf = b"abc"
        rbuf = bytearray(3)
        Sendrecv([sbuf, "c"], [rbuf, MPI.CHAR])
        self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytearray(self):
        sbuf = bytearray(b"abc")
        rbuf = bytearray(3)
        Sendrecv([sbuf, "c"], [rbuf, MPI.CHAR])
        self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(py3, 'python3')
    @unittest.skipIf(pypy2, 'pypy2')
    @unittest.skipIf(hasattr(MPI, 'ffi'), 'mpi4py-cffi')
    def testMessageUnicode(self):  # Test for Issue #120
        # Python 2 only: ``unicode`` and ``buffer`` are Python 2 builtins.
        sbuf = unicode("abc")
        rbuf = bytearray(len(buffer(sbuf)))
        Sendrecv([sbuf, MPI.BYTE], [rbuf, MPI.BYTE])

    @unittest.skipIf(py3, 'python3')
    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBuffer(self):
        sbuf = buffer(b"abc")
        rbuf = bytearray(3)
        Sendrecv([sbuf, "c"], [rbuf, MPI.CHAR])
        self.assertEqual(sbuf, rbuf)
        # A read-only buffer object cannot be used as the receive side.
        self.assertRaises((BufferError, TypeError, ValueError),
                          Sendrecv, [rbuf, "c"], [sbuf, "c"])

    @unittest.skipIf(pypy2, 'pypy2')
    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageMemoryView(self):
        sbuf = memoryview(b"abc")
        rbuf = bytearray(3)
        Sendrecv([sbuf, "c"], [rbuf, MPI.CHAR])
        self.assertEqual(sbuf, rbuf)
        # A memoryview over bytes is read-only, so it cannot receive.
        self.assertRaises((BufferError, TypeError, ValueError),
                          Sendrecv, [rbuf, "c"], [sbuf, "c"])
258
259
@unittest.skipMPI('msmpi(<8.0.0)')
class TestMessageBlock(unittest.TestCase):
    """Message-specification parsing for block collectives (Alltoall)."""

    @unittest.skipIf(MPI.COMM_WORLD.Get_size() < 2, 'mpi-world-size<2')
    def testMessageBad(self):
        # The spec [buf, None, "i"] over raw MPI memory must be rejected.
        # Free_mem runs in ``finally`` so a failing assertion does not
        # leak the allocation.
        comm = MPI.COMM_WORLD
        buf = MPI.Alloc_mem(4)
        empty = [None, 0, "B"]
        try:
            self.assertRaises(ValueError, comm.Alltoall,
                              [buf, None, "i"], empty)
        finally:
            MPI.Free_mem(buf)
271
272
class BaseTestMessageSimpleArray(object):
    """Mixin that round-trips typed buffers through ``Sendrecv``.

    Concrete test classes combine it with ``unittest.TestCase`` and
    implement :meth:`array`.  Each ``checkN`` method exercises one family
    of message specifications, given a zero buffer ``z``, a source ``s``
    and a destination ``r`` of the same length.  :meth:`check` runs a
    ``checkN`` over every typecode in ``TYPECODES`` and lengths 1 to 9.
    """

    TYPECODES = "bhil"+"BHIL"+"fd"

    def array(self, typecode, initializer):
        """Return a buffer of *typecode* filled from *initializer*."""
        raise NotImplementedError

    def check1(self, z, s, r, typecode):
        # Bare buffers: no explicit count or datatype.
        r[:] = z
        Sendrecv(s, r)
        for a, b in zip(s, r):
            self.assertEqual(a, b)

    def check2(self, z, s, r, typecode):
        # [buffer, datatype], with the datatype given as None, as a
        # typecode string, or as the matching MPI datatype.
        datatype = typemap[typecode]
        for dtype in (None, typecode, datatype):
            r[:] = z
            Sendrecv([s, dtype],
                     [r, dtype])
            for a, b in zip(s, r):
                self.assertEqual(a, b)

    def check3(self, z, s, r, typecode):
        # [buffer, count], [buffer, (count, None)], [buffer, (None, disp)]
        # and [buffer, (count, disp)]: only elements in [disp, disp+count)
        # are transferred; the others keep their zero value.
        size = len(r)
        for count in range(size):
            r[:] = z
            Sendrecv([s, count],
                     [r, count])
            for i in range(count):
                self.assertEqual(r[i], s[i])
            for i in range(count, size):
                self.assertEqual(r[i], z[0])
        for count in range(size):
            r[:] = z
            Sendrecv([s, (count, None)],
                     [r, (count, None)])
            for i in range(count):
                self.assertEqual(r[i], s[i])
            for i in range(count, size):
                self.assertEqual(r[i], z[0])
        for disp in range(size):
            r[:] = z
            Sendrecv([s, (None, disp)],
                     [r, (None, disp)])
            for i in range(disp):
                self.assertEqual(r[i], z[0])
            for i in range(disp, size):
                self.assertEqual(r[i], s[i])
        for disp in range(size):
            for count in range(size-disp):
                r[:] = z
                Sendrecv([s, (count, disp)],
                         [r, (count, disp)])
                for i in range(0, disp):
                    self.assertEqual(r[i], z[0])
                for i in range(disp, disp+count):
                    self.assertEqual(r[i], s[i])
                for i in range(disp+count, size):
                    self.assertEqual(r[i], z[0])

    def check4(self, z, s, r, typecode):
        # [buffer, count, datatype] with count None or the full length.
        datatype = typemap[typecode]
        for dtype in (None, typecode, datatype):
            for count in (None, len(s)):
                r[:] = z
                Sendrecv([s, count, dtype],
                         [r, count, dtype])
                for a, b in zip(s, r):
                    self.assertEqual(a, b)

    def check5(self, z, s, r, typecode):
        # [buffer, (count, displ), datatype]: first (p, None), i.e. the
        # leading p elements, then every window of q-p elements at p.
        datatype = typemap[typecode]
        for dtype in (None, typecode, datatype):
            for p in range(len(s)):
                r[:] = z
                Sendrecv([s, (p, None), dtype],
                         [r, (p, None), dtype])
                for a, b in zip(s[:p], r[:p]):
                    self.assertEqual(a, b)
                for q in range(p, len(s)):
                    count, displ = q-p, p
                    r[:] = z
                    Sendrecv([s, (count, displ), dtype],
                             [r, (count, displ), dtype])
                    for a, b in zip(r[:p], z[:p]):
                        self.assertEqual(a, b)
                    for a, b in zip(r[p:q], s[p:q]):
                        self.assertEqual(a, b)
                    for a, b in zip(r[q:], z[q:]):
                        self.assertEqual(a, b)

    def check6(self, z, s, r, typecode):
        # The same windows as check5, spelled [buffer, count, displ, dtype].
        datatype = typemap[typecode]
        for dtype in (None, typecode, datatype):
            for p in range(len(s)):
                r[:] = z
                Sendrecv([s, p, None, dtype],
                         [r, p, None, dtype])
                for a, b in zip(s[:p], r[:p]):
                    self.assertEqual(a, b)
                for q in range(p, len(s)):
                    count, displ = q-p, p
                    r[:] = z
                    Sendrecv([s, count, displ, dtype],
                             [r, count, displ, dtype])
                    for a, b in zip(r[:p], z[:p]):
                        self.assertEqual(a, b)
                    for a, b in zip(r[p:q], s[p:q]):
                        self.assertEqual(a, b)
                    for a, b in zip(r[q:], z[q:]):
                        self.assertEqual(a, b)

    def check(self, test):
        """Run ``test(z, s, r, typecode)`` for every typecode and length.

        For each typecode in ``TYPECODES`` and each length n in 1..9,
        ``z`` holds n zeros, ``s`` holds ``0..n-1`` and ``r`` is a zeroed
        receive buffer.
        """
        for t in self.TYPECODES:
            for n in range(1, 10):
                z = self.array(t, [0]*n)
                s = self.array(t, list(range(n)))
                r = self.array(t, [0]*n)
                test(z, s, r, t)

    def testArray1(self):
        self.check(self.check1)

    def testArray2(self):
        self.check(self.check2)

    def testArray3(self):
        self.check(self.check3)

    def testArray4(self):
        self.check(self.check4)

    def testArray5(self):
        self.check(self.check5)

    def testArray6(self):
        self.check(self.check6)
410
411
@unittest.skipIf(array is None, 'array')
class TestMessageSimpleArray(unittest.TestCase,
                             BaseTestMessageSimpleArray):
    """Simple-message checks on plain ``array.array`` buffers."""

    def array(self, typecode, initializer):
        host_buf = array.array(typecode, initializer)
        return host_buf
418
419
@unittest.skipIf(numpy is None, 'numpy')
class TestMessageSimpleNumPy(unittest.TestCase,
                             BaseTestMessageSimpleArray):
    """Simple-message checks on NumPy arrays, plus layout/flag cases."""

    def array(self, typecode, initializer):
        ndarr = numpy.array(initializer, dtype=typecode)
        return ndarr

    def testOrderC(self):
        sbuf, rbuf = numpy.ones([3,2]), numpy.zeros([3,2])
        Sendrecv(sbuf, rbuf)
        self.assertTrue((sbuf == rbuf).all())

    def testOrderFortran(self):
        # Transposing a C-ordered 2-D array gives a Fortran-ordered view.
        sbuf, rbuf = numpy.ones([3,2]).T, numpy.zeros([3,2]).T
        Sendrecv(sbuf, rbuf)
        self.assertTrue((sbuf == rbuf).all())

    def testReadonly(self):
        # A read-only array is acceptable as the send buffer.
        sbuf, rbuf = numpy.ones([3]), numpy.zeros([3])
        sbuf.flags.writeable = False
        Sendrecv(sbuf, rbuf)
        self.assertTrue((sbuf == rbuf).all())

    def testNotWriteable(self):
        # ...but not as the receive buffer.
        sbuf, rbuf = numpy.ones([3]), numpy.zeros([3])
        rbuf.flags.writeable = False
        self.assertRaises((BufferError, ValueError),
                          Sendrecv, sbuf, rbuf)

    def testNotContiguous(self):
        # One column of a C-ordered 2-D array is a strided view.
        sbuf = numpy.ones([3,2])[:,0]
        rbuf = numpy.zeros([3])
        sbuf.flags.writeable = False
        self.assertRaises((BufferError, ValueError),
                          Sendrecv, sbuf, rbuf)
459
460
@unittest.skipIf(array is None, 'array')
@unittest.skipIf(dlpack is None, 'dlpack')
class TestMessageSimpleDLPackCPUBuf(unittest.TestCase,
                                    BaseTestMessageSimpleArray):
    """Simple-message checks on host buffers exported through DLPack."""

    def array(self, typecode, initializer):
        wrapped = DLPackCPUBuf(typecode, initializer)
        return wrapped
468
469
@unittest.skipIf(cupy is None, 'cupy')
@unittest.skipIf(dlpack is None, 'dlpack')
class TestMessageSimpleDLPackGPUBuf(unittest.TestCase,
                                    BaseTestMessageSimpleArray):
    """Simple-message checks on CuPy arrays exported through DLPack."""

    def array(self, typecode, initializer):
        # DLPackGPUBuf reads dlpack.DLDeviceType in __init__, so the test
        # must be skipped (not errored) when dlpackimpl is unavailable.
        return DLPackGPUBuf(typecode, initializer)
476
477
@unittest.skipIf(array is None, 'array')
class TestMessageSimpleCAIBuf(unittest.TestCase,
                              BaseTestMessageSimpleArray):
    """Simple-message checks on buffers exposing a CUDA array interface."""

    def array(self, typecode, initializer):
        wrapped = CAIBuf(typecode, initializer)
        return wrapped
484
485
@unittest.skipIf(cupy is None, 'cupy')
class TestMessageSimpleCuPy(unittest.TestCase,
                            BaseTestMessageSimpleArray):
    """Simple-message checks on CuPy device arrays, plus layout cases."""

    def array(self, typecode, initializer):
        device_arr = cupy.array(initializer, dtype=typecode)
        return device_arr

    def testOrderC(self):
        sbuf, rbuf = cupy.ones([3,2]), cupy.zeros([3,2])
        Sendrecv(sbuf, rbuf)
        self.assertTrue((sbuf == rbuf).all())

    @unittest.skipIf(cupy_issue_2259, 'cupy-issue-2259')
    def testOrderFortran(self):
        # Transposing a C-ordered 2-D array gives a Fortran-ordered view.
        sbuf, rbuf = cupy.ones([3,2]).T, cupy.zeros([3,2]).T
        Sendrecv(sbuf, rbuf)
        self.assertTrue((sbuf == rbuf).all())

    @unittest.skipIf(cupy_issue_2259, 'cupy-issue-2259')
    def testNotContiguous(self):
        # One column of a C-ordered 2-D array is a strided view.
        sbuf = cupy.ones([3,2])[:,0]
        rbuf = cupy.zeros([3])
        self.assertRaises((BufferError, ValueError),
                          Sendrecv, sbuf, rbuf)
512
513
@unittest.skipIf(numba is None, 'numba')
class TestMessageSimpleNumba(unittest.TestCase,
                             BaseTestMessageSimpleArray):
    """Simple-message checks on Numba CUDA device arrays."""

    def array(self, typecode, initializer):
        n = len(initializer)
        arr = numba.cuda.device_array((n,), dtype=typecode)
        arr[:] = initializer
        return arr

    def testOrderC(self):
        sbuf = numba.cuda.device_array((6,))
        sbuf[:] = 1
        sbuf = sbuf.reshape(3,2)
        rbuf = numba.cuda.device_array((6,))
        rbuf[:] = 0
        # Reshape rbuf itself: reshaping sbuf here would make both messages
        # alias the same device memory and the comparison below trivial.
        rbuf = rbuf.reshape(3,2)
        Sendrecv(sbuf, rbuf)
        # numba arrays do not have the .all() method
        for i in range(3):
            for j in range(2):
                self.assertTrue(sbuf[i,j] == rbuf[i,j])

    def testOrderFortran(self):
        sbuf = numba.cuda.device_array((6,))
        sbuf[:] = 1
        sbuf = sbuf.reshape(3,2,order='F')
        rbuf = numba.cuda.device_array((6,))
        rbuf[:] = 0
        # Same aliasing concern as in testOrderC.
        rbuf = rbuf.reshape(3,2,order='F')
        Sendrecv(sbuf, rbuf)
        # numba arrays do not have the .all() method
        for i in range(3):
            for j in range(2):
                self.assertTrue(sbuf[i,j] == rbuf[i,j])

    def testNotContiguous(self):
        # One column of a C-ordered 2-D array is a strided view.
        sbuf = numba.cuda.device_array((6,))
        sbuf[:] = 1
        sbuf = sbuf.reshape(3,2)[:,0]
        rbuf = numba.cuda.device_array((3,))
        rbuf[:] = 0
        self.assertRaises((BufferError, ValueError),
                          Sendrecv, sbuf, rbuf)
558
559
560# ---
561
@unittest.skipIf(array is None, 'array')
@unittest.skipIf(dlpack is None, 'dlpack')
class TestMessageDLPackCPUBuf(unittest.TestCase):
    """Validation of DLPack exports by ``MPI.Get_address``/``MPI.memory``.

    Each test corrupts one part of a ``DLPackCPUBuf``: either its
    ``__dlpack_device__``/``__dlpack__`` hooks or fields of the managed
    ``dl_tensor``. It then checks which exception mpi4py raises.
    """

    def testDevice(self):
        # Malformed __dlpack_device__: non-callable, non-tuple result or
        # non-integer entries -> TypeError; wrong tuple length -> ValueError.
        buf = DLPackCPUBuf('i', [0,1,2,3])
        buf.__dlpack_device__ = None
        self.assertRaises(TypeError, MPI.Get_address, buf)
        buf.__dlpack_device__ = lambda: None
        self.assertRaises(TypeError, MPI.Get_address, buf)
        buf.__dlpack_device__ = lambda: (None, 0)
        self.assertRaises(TypeError, MPI.Get_address, buf)
        buf.__dlpack_device__ = lambda: (1, None)
        self.assertRaises(TypeError, MPI.Get_address, buf)
        buf.__dlpack_device__ = lambda: (1,)
        self.assertRaises(ValueError, MPI.Get_address, buf)
        buf.__dlpack_device__ = lambda: (1, 0, 1)
        self.assertRaises(ValueError, MPI.Get_address, buf)
        # Removing the instance override restores the class method.
        del buf.__dlpack_device__
        MPI.Get_address(buf)

    def testCapsule(self):
        buf = DLPackCPUBuf('i', [0,1,2,3])
        # Holding one capsule does not prevent further exports.
        capsule = buf.__dlpack__()
        MPI.Get_address(buf)
        MPI.Get_address(buf)
        del capsule
        #
        # Hand out the same capsule object twice: the first export works,
        # the second raises BufferError (presumably because the capsule was
        # already consumed by the first call).
        capsule = buf.__dlpack__()
        retvals = [capsule] * 2
        buf.__dlpack__ = lambda *args, **kwargs: retvals.pop()
        MPI.Get_address(buf)
        self.assertRaises(BufferError, MPI.Get_address, buf)
        del buf.__dlpack__
        del capsule
        #
        # A __dlpack__ that returns something other than a capsule.
        buf.__dlpack__ = lambda *args, **kwargs:  None
        self.assertRaises(BufferError, MPI.Get_address, buf)
        del buf.__dlpack__

    def testNdim(self):
        # ndim values 2, 1 and 0 are accepted; a negative ndim is not.
        buf = DLPackCPUBuf('i', [0,1,2,3])
        dltensor = buf.managed.dl_tensor
        #
        for ndim in (2, 1, 0):
            dltensor.ndim = ndim
            MPI.Get_address(buf)
        #
        dltensor.ndim = -1
        self.assertRaises(BufferError, MPI.Get_address, buf)
        #
        # Drop the local reference to the tensor struct before returning
        # (presumably so buf's teardown sees no extra references).
        del dltensor

    def testShape(self):
        buf = DLPackCPUBuf('i', [0,1,2,3])
        dltensor = buf.managed.dl_tensor
        #
        # A negative extent is rejected.
        dltensor.ndim = 1
        dltensor.shape[0] = -1
        self.assertRaises(BufferError, MPI.Get_address, buf)
        #
        # A 0-d tensor may have no shape array...
        dltensor.ndim = 0
        dltensor.shape = None
        MPI.Get_address(buf)
        #
        # ...but a 1-d tensor may not.
        dltensor.ndim = 1
        dltensor.shape = None
        self.assertRaises(BufferError, MPI.Get_address, buf)
        #
        del dltensor

    def testStrides(self):
        # Valid C- and Fortran-ordered 2x2x2 layouts are accepted; a
        # negative first stride is rejected.
        buf = DLPackCPUBuf('i', range(8))
        dltensor = buf.managed.dl_tensor
        #
        for order in ('C', 'F'):
            dltensor.ndim, dltensor.shape, dltensor.strides = \
                dlpack.make_dl_shape([2, 2, 2], order=order)
            MPI.Get_address(buf)
            dltensor.strides[0] = -1
            self.assertRaises(BufferError, MPI.Get_address, buf)
        #
        del dltensor

    def testContiguous(self):
        # Starting from C-order strides, the identity permutation and its
        # reverse (Fortran order) are accepted; mixed permutations are
        # rejected as non-contiguous.
        buf = DLPackCPUBuf('i', range(8))
        dltensor = buf.managed.dl_tensor
        #
        dltensor.ndim, dltensor.shape, dltensor.strides = \
            dlpack.make_dl_shape([2, 2, 2], order='C')
        s = dltensor.strides
        strides = [s[i] for i in range(dltensor.ndim)]
        s[0], s[1], s[2] = [strides[i] for i in [0, 1, 2]]
        MPI.Get_address(buf)
        s[0], s[1], s[2] = [strides[i] for i in [2, 1, 0]]
        MPI.Get_address(buf)
        s[0], s[1], s[2] = [strides[i] for i in [0, 2, 1]]
        self.assertRaises(BufferError, MPI.Get_address, buf)
        s[0], s[1], s[2] = [strides[i] for i in [1, 0, 2]]
        self.assertRaises(BufferError, MPI.Get_address, buf)
        del s
        #
        del dltensor

    def testByteOffset(self):
        # The exported memory starts byte_offset bytes into the buffer;
        # typecode 'B' makes that offset equal to the element index.
        buf = DLPackCPUBuf('B', [0,1,2,3])
        dltensor = buf.managed.dl_tensor
        #
        dltensor.ndim = 1
        for i in range(len(buf)):
            dltensor.byte_offset = i
            mem = MPI.memory(buf)
            self.assertEqual(mem[0], buf[i])
        #
        del dltensor
678
679# ---
680
@unittest.skipIf(array is None, 'array')
class TestMessageCAIBuf(unittest.TestCase):
    """Validation of ``__cuda_array_interface__`` dicts in message specs.

    Each test builds a pair of ``CAIBuf`` objects and corrupts one entry
    of the receive buffer's interface dict. It then checks the exception
    raised or, for optional entries, that the exchange still succeeds.
    """

    def testNonReadonly(self):
        # Both buffers are flagged read-only; presumably it is the
        # read-only receive side that makes the exchange fail.
        smsg = CAIBuf('i', [1,2,3], readonly=True)
        rmsg = CAIBuf('i', [0,0,0], readonly=True)
        self.assertRaises(BufferError, Sendrecv, smsg, rmsg)

    def testNonContiguous(self):
        # An inconsistent last stride breaks contiguity.
        smsg = CAIBuf('i', [1,2,3])
        rmsg = CAIBuf('i', [0,0,0])
        strides = rmsg.__cuda_array_interface__['strides']
        bad_strides = strides[:-1] + (7,)
        rmsg.__cuda_array_interface__['strides'] = bad_strides
        self.assertRaises(BufferError, Sendrecv, smsg, rmsg)

    # -- the interface attribute itself --

    def testAttrNone(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__ = None
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testAttrEmpty(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__ = dict()
        self.assertRaises(KeyError, Sendrecv, smsg, rmsg)

    def testAttrType(self):
        # A list of (key, value) pairs is not accepted in place of a dict.
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        items = list(rmsg.__cuda_array_interface__.items())
        rmsg.__cuda_array_interface__ = items
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    # -- 'data': required (pointer, readonly) pair --

    def testDataMissing(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        del rmsg.__cuda_array_interface__['data']
        self.assertRaises(KeyError, Sendrecv, smsg, rmsg)

    def testDataNone(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['data'] = None
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testDataType(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['data'] = 0
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testDataValue(self):
        # Tuples of the wrong length are rejected.
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        dev_ptr = rmsg.__cuda_array_interface__['data'][0]
        rmsg.__cuda_array_interface__['data'] = (dev_ptr, )
        self.assertRaises(ValueError, Sendrecv, smsg, rmsg)
        rmsg.__cuda_array_interface__['data'] = ( )
        self.assertRaises(ValueError, Sendrecv, smsg, rmsg)
        rmsg.__cuda_array_interface__['data'] = (dev_ptr, False, None)
        self.assertRaises(ValueError, Sendrecv, smsg, rmsg)

    # -- 'typestr': required string --

    def testTypestrMissing(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        del rmsg.__cuda_array_interface__['typestr']
        self.assertRaises(KeyError, Sendrecv, smsg, rmsg)

    def testTypestrNone(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['typestr'] = None
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testTypestrType(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['typestr'] = 42
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testTypestrItemsize(self):
        # Replace the itemsize digits with a non-numeric character.
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        item_typestr = rmsg.__cuda_array_interface__['typestr']
        rmsg.__cuda_array_interface__['typestr'] = item_typestr[:2]+'X'
        self.assertRaises(ValueError, Sendrecv, smsg, rmsg)

    # -- 'shape': required tuple --

    def testShapeMissing(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        del rmsg.__cuda_array_interface__['shape']
        self.assertRaises(KeyError, Sendrecv, smsg, rmsg)

    def testShapeNone(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['shape'] = None
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testShapeType(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['shape'] = 3
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testShapeValue(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['shape'] = (3, -1)
        rmsg.__cuda_array_interface__['strides'] = None
        self.assertRaises(BufferError, Sendrecv, smsg, rmsg)

    # -- 'strides': optional; missing or None means contiguous --

    def testStridesMissing(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        del rmsg.__cuda_array_interface__['strides']
        Sendrecv(smsg, rmsg)
        self.assertEqual(smsg, rmsg)

    def testStridesNone(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['strides'] = None
        Sendrecv(smsg, rmsg)
        self.assertEqual(smsg, rmsg)

    def testStridesType(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['strides'] = 42
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    # -- 'descr': optional --

    def testDescrMissing(self):
        smsg = CAIBuf('d', [1,2,3])
        rmsg = CAIBuf('d', [0,0,0])
        del rmsg.__cuda_array_interface__['descr']
        Sendrecv(smsg, rmsg)
        self.assertEqual(smsg, rmsg)

    def testDescrNone(self):
        smsg = CAIBuf('d', [1,2,3])
        rmsg = CAIBuf('d', [0,0,0])
        rmsg.__cuda_array_interface__['descr'] = None
        Sendrecv(smsg, rmsg)
        self.assertEqual(smsg, rmsg)

    def testDescrType(self):
        smsg = CAIBuf('B', [1,2,3])
        rmsg = CAIBuf('B', [0,0,0])
        rmsg.__cuda_array_interface__['descr'] = 42
        self.assertRaises(TypeError, Sendrecv, smsg, rmsg)

    def testDescrWarning(self):
        # Reinterpret the m*n doubles as m records of n fields, using a
        # void typestr plus a multi-field descr.  The exchange still
        # transfers the data but emits a RuntimeWarning.
        m, n = 5, 3
        smsg = CAIBuf('d', list(range(m*n)))
        rmsg = CAIBuf('d', [0]*(m*n))
        item_typestr = rmsg.__cuda_array_interface__['typestr']
        itemsize = int(item_typestr[2:])
        new_typestr = "|V"+str(itemsize*n)
        new_descr = [('', item_typestr)]*n
        rmsg.__cuda_array_interface__['shape'] = (m,)
        rmsg.__cuda_array_interface__['strides'] = (itemsize*n,)
        rmsg.__cuda_array_interface__['typestr'] = new_typestr
        rmsg.__cuda_array_interface__['descr'] = new_descr
        import warnings
        # With warnings turned into errors, the warning surfaces as an
        # exception.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            self.assertRaises(RuntimeWarning, Sendrecv, smsg, rmsg)
        # Test for assertWarns directly instead of catching AttributeError,
        # which could also come from Sendrecv itself.
        if hasattr(self, 'assertWarns'):  # Python 3.2+
            self.assertWarns(RuntimeWarning, Sendrecv, smsg, rmsg)
        else:  # Python 2
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                Sendrecv(smsg, rmsg)
                self.assertEqual(len(w), 1)
                self.assertEqual(w[-1].category, RuntimeWarning)
        self.assertEqual(smsg, rmsg)
860
861
862# ---
863
def Alltoallv(smsg, rmsg):
    """Run ``Alltoallv`` on ``MPI.COMM_SELF`` with the given specs.

    With a single-process communicator, each test exercises how mpi4py
    interprets vector message specifications.
    """
    MPI.COMM_SELF.Alltoallv(smsg, rmsg)
867
868
@unittest.skipMPI('msmpi(<8.0.0)')
class TestMessageVector(unittest.TestCase):
    """Message-specification parsing for vector collectives (Alltoallv)."""

    def testMessageBad(self):
        # Malformed vector specs over raw MPI memory.  The memory and the
        # resized datatype are released in ``finally`` blocks so that a
        # failing assertion does not leak them.
        buf = MPI.Alloc_mem(5)
        empty = [None, 0, [0], "B"]
        try:
            self.assertRaises(ValueError, Alltoallv,
                              [buf, 0, [0], "i", None], empty)
            self.assertRaises(KeyError, Alltoallv,
                              [buf, 0, [0], "\0"], empty)
            self.assertRaises(ValueError, Alltoallv,
                              [buf, None, [0], MPI.DATATYPE_NULL], empty)
            self.assertRaises(ValueError, Alltoallv,
                              [buf, None, [0], "i"], empty)
            try:
                # Datatype with a negative extent; this part is skipped
                # when it raises NotImplementedError.
                t = MPI.INT.Create_resized(0, -4).Commit()
                try:
                    self.assertRaises(ValueError, Alltoallv,
                                      [buf, None, [0], t], empty)
                finally:
                    t.Free()
            except NotImplementedError:
                pass
        finally:
            MPI.Free_mem(buf)
        # Python containers that do not expose a buffer.
        buf = [1,2,3,4]
        self.assertRaises(TypeError, Alltoallv, [buf, 0,  0, "i"], empty)
        buf = {1:2,3:4}
        self.assertRaises(TypeError, Alltoallv, [buf, 0,  0, "i"], empty)

    def testMessageNone(self):
        empty = [None, 0, "B"]
        Alltoallv(empty, empty)
        empty = [None, "B"]
        Alltoallv(empty, empty)

    def testMessageBottom(self):
        empty = [MPI.BOTTOM, 0, [0], "B"]
        Alltoallv(empty, empty)
        empty = [MPI.BOTTOM, 0, "B"]
        Alltoallv(empty, empty)
        empty = [MPI.BOTTOM, "B"]
        Alltoallv(empty, empty)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytes(self):
        sbuf = b"abc"
        rbuf = bytearray(3)
        Alltoallv([sbuf, "c"], [rbuf, MPI.CHAR])
        self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytearray(self):
        sbuf = bytearray(b"abc")
        rbuf = bytearray(3)
        Alltoallv([sbuf, "c"], [rbuf, MPI.CHAR])
        self.assertEqual(sbuf, rbuf)
925
926
@unittest.skipMPI('msmpi(<8.0.0)')
class BaseTestMessageVectorArray(object):
    """Mixin checking the vector message forms accepted by ``Alltoallv``.

    Concrete subclasses also derive from ``unittest.TestCase`` and
    implement :meth:`array` for one array flavor.  Every ``checkN`` method
    receives ``z`` (all zeros), ``s`` (the values ``0 .. n-1``), ``r`` (the
    receive buffer, reset from ``z`` before each exchange) and the
    ``typecode`` they were built with; see :meth:`check`.
    """

    TYPECODES = "bhil"+"BHIL"+"fd"

    def array(self, typecode, initializer):
        """Return a new buffer of *typecode* items built from *initializer*."""
        raise NotImplementedError

    def _checkSpans(self, z, s, r, p, q):
        # After transferring s[p:q]: r[p:q] must match s[p:q], while the
        # untouched head r[:p] and tail r[q:] keep the fill values from z.
        for a, b in zip(r[:p], z[:p]):
            self.assertEqual(a, b)
        for a, b in zip(r[p:q], s[p:q]):
            self.assertEqual(a, b)
        for a, b in zip(r[q:], z[q:]):
            self.assertEqual(a, b)

    def check1(self, z, s, r, typecode):
        # bare buffers: count and datatype are inferred
        r[:] = z
        Alltoallv(s, r)
        for a, b in zip(s, r):
            self.assertEqual(a, b)

    def check2(self, z, s, r, typecode):
        # [buffer, datatype], the datatype given as None, as a typecode
        # string, or as an MPI.Datatype
        datatype = typemap[typecode]
        for mpitype in (None, typecode, datatype):
            r[:] = z
            Alltoallv([s, mpitype],
                      [r, mpitype])
            for a, b in zip(s, r):
                self.assertEqual(a, b)

    def check3(self, z, s, r, typecode):
        # [buffer, count], [buffer, (count, None)] and
        # [buffer, ([count], [displ])]: only the selected items move,
        # everything else keeps the fill value z[0]
        size = len(r)
        for count in range(size):
            r[:] = z
            Alltoallv([s, count],
                      [r, count])
            for i in range(count):
                self.assertEqual(r[i], s[i])
            for i in range(count, size):
                self.assertEqual(r[i], z[0])
        for count in range(size):
            r[:] = z
            Alltoallv([s, (count, None)],
                      [r, (count, None)])
            for i in range(count):
                self.assertEqual(r[i], s[i])
            for i in range(count, size):
                self.assertEqual(r[i], z[0])
        for disp in range(size):
            for count in range(size-disp):
                r[:] = z
                Alltoallv([s, ([count], [disp])],
                          [r, ([count], [disp])])
                for i in range(0, disp):
                    self.assertEqual(r[i], z[0])
                for i in range(disp, disp+count):
                    self.assertEqual(r[i], s[i])
                for i in range(disp+count, size):
                    self.assertEqual(r[i], z[0])

    def check4(self, z, s, r, typecode):
        # [buffer, count, datatype] with an omitted (None) or full count
        datatype = typemap[typecode]
        for mpitype in (None, typecode, datatype):
            for count in (None, len(s)):
                r[:] = z
                Alltoallv([s, count, mpitype],
                          [r, count, mpitype])
                for a, b in zip(s, r):
                    self.assertEqual(a, b)

    def check5(self, z, s, r, typecode):
        # [buffer, (count, displs), datatype]: first a count-only prefix,
        # then every (count, [displ]) window s[p:q]
        datatype = typemap[typecode]
        for mpitype in (None, typecode, datatype):
            for p in range(len(s)):
                r[:] = z
                Alltoallv([s, (p, None), mpitype],
                          [r, (p, None), mpitype])
                for a, b in zip(s[:p], r[:p]):
                    self.assertEqual(a, b)
                for q in range(p, len(s)):
                    count, displ = q-p, p
                    r[:] = z
                    Alltoallv([s, (count, [displ]), mpitype],
                              [r, (count, [displ]), mpitype])
                    self._checkSpans(z, s, r, p, q)

    def check6(self, z, s, r, typecode):
        # [buffer, count, displs, datatype]: same windows as check5, with
        # count and displacements given as separate items
        datatype = typemap[typecode]
        for mpitype in (None, typecode, datatype):
            for p in range(len(s)):
                r[:] = z
                Alltoallv([s, p, None, mpitype],
                          [r, p, None, mpitype])
                for a, b in zip(s[:p], r[:p]):
                    self.assertEqual(a, b)
                for q in range(p, len(s)):
                    count, displ = q-p, p
                    r[:] = z
                    Alltoallv([s, count, [displ], mpitype],
                              [r, count, [displ], mpitype])
                    self._checkSpans(z, s, r, p, q)

    def check(self, test):
        # Run ``test`` for every typecode in TYPECODES and lengths 1..9.
        for t in self.TYPECODES:
            for n in range(1, 10):
                z = self.array(t, [0]*n)
                s = self.array(t, list(range(n)))
                r = self.array(t, [0]*n)
                test(z, s, r, t)

    def testArray1(self):
        self.check(self.check1)

    def testArray2(self):
        self.check(self.check2)

    def testArray3(self):
        self.check(self.check3)

    def testArray4(self):
        self.check(self.check4)

    def testArray5(self):
        self.check(self.check5)

    def testArray6(self):
        self.check(self.check6)
1057
1058
@unittest.skipIf(array is None, 'array')
class TestMessageVectorArray(unittest.TestCase,
                             BaseTestMessageVectorArray):
    """Vector-message checks backed by ``array.array`` buffers."""

    def array(self, typecode, initializer):
        # Inside this method ``array`` still names the module imported at
        # file level, not this method.
        buf = array.array(typecode, initializer)
        return buf
1065
1066
@unittest.skipIf(numpy is None, 'numpy')
class TestMessageVectorNumPy(unittest.TestCase,
                             BaseTestMessageVectorArray):
    """Vector-message checks backed by NumPy host arrays."""

    def array(self, typecode, initializer):
        # The typecode selects the NumPy element type.
        dtype = numpy.dtype(typecode)
        return numpy.array(initializer, dtype=dtype)
1073
1074
@unittest.skipIf(array is None, 'array')
class TestMessageVectorCAIBuf(unittest.TestCase,
                              BaseTestMessageVectorArray):
    """Vector-message checks backed by this file's ``CAIBuf`` buffers."""

    def array(self, typecode, initializer):
        # NOTE(review): CAIBuf presumably stores its data in an
        # ``array.array``, hence the skip condition on the array module.
        buf = CAIBuf(typecode, initializer)
        return buf
1081
1082
@unittest.skipIf(cupy is None, 'cupy')
class TestMessageVectorCuPy(unittest.TestCase,
                            BaseTestMessageVectorArray):
    """Vector-message checks backed by CuPy device arrays."""

    def array(self, typecode, initializer):
        # Build the array on the device with the requested element type.
        devbuf = cupy.array(initializer, dtype=typecode)
        return devbuf
1089
1090
@unittest.skipIf(numba is None, 'numba')
class TestMessageVectorNumba(unittest.TestCase,
                             BaseTestMessageVectorArray):
    """Vector-message checks backed by Numba CUDA device arrays."""

    def array(self, typecode, initializer):
        # Allocate an uninitialized 1-D device array, then fill it.
        shape = (len(initializer),)
        devbuf = numba.cuda.device_array(shape, dtype=typecode)
        devbuf[:] = initializer
        return devbuf
1100
1101
1102# ---
1103
def Alltoallw(smsg, rmsg):
    """Exchange generalized vector messages with ``Alltoallw`` on COMM_SELF.

    If the MPI call raises NotImplementedError, fall back to a best-effort
    local copy: message specifications given as lists/tuples are unwrapped
    to their first item (the buffer) and ``rmsg[:] = smsg`` is attempted.
    Failures of that copy are ignored so the caller's assertions report
    the mismatch instead.
    """
    try:
        MPI.COMM_SELF.Alltoallw(smsg, rmsg)
    except NotImplementedError:
        if isinstance(smsg, (list, tuple)):
            smsg = smsg[0]
        if isinstance(rmsg, (list, tuple)):
            rmsg = rmsg[0]
        try:
            rmsg[:] = smsg
        except Exception:
            # Best effort only (e.g. MPI.BOTTOM is not sliceable).  Catch
            # Exception rather than everything so that KeyboardInterrupt
            # and SystemExit are not swallowed.
            pass
1112
1113
class TestMessageVectorW(unittest.TestCase):
    """Message-specification tests for ``Comm.Alltoallw``."""

    def testMessageBad(self):
        # Malformed specifications must raise ValueError.  The MPI-allocated
        # buffers are released in ``finally`` clauses so that a failing
        # assertion does not leak them.
        sbuf = MPI.Alloc_mem(4)
        try:
            rbuf = MPI.Alloc_mem(4)
            try:
                # buffers given without counts, displacements and datatypes
                self.assertRaises(ValueError, Alltoallw, [sbuf], [rbuf])
                # extra trailing item in the send specification
                self.assertRaises(ValueError, Alltoallw,
                                  [sbuf, [0], [0], [MPI.BYTE], None],
                                  [rbuf, [0], [0], [MPI.BYTE]])
                # extra trailing item in the receive specification
                self.assertRaises(ValueError, Alltoallw,
                                  [sbuf, [0], [0], [MPI.BYTE]],
                                  [rbuf, [0], [0], [MPI.BYTE], None])
            finally:
                MPI.Free_mem(rbuf)
        finally:
            MPI.Free_mem(sbuf)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBottom(self):
        # With MPI.BOTTOM as the buffer, the absolute addresses obtained
        # from MPI.Get_address are carried by struct datatypes.
        sbuf = b"abcxyz"
        rbuf = bytearray(6)
        saddr = MPI.Get_address(sbuf)
        raddr = MPI.Get_address(rbuf)
        stype = MPI.Datatype.Create_struct([6], [saddr], [MPI.CHAR]).Commit()
        rtype = MPI.Datatype.Create_struct([6], [raddr], [MPI.CHAR]).Commit()
        smsg = [MPI.BOTTOM, [1], [0], [stype]]
        rmsg = [MPI.BOTTOM, ([1], [0]), [rtype]]
        try:
            Alltoallw(smsg, rmsg)
            self.assertEqual(sbuf, rbuf)
        finally:
            stype.Free()
            rtype.Free()

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytes(self):
        # counts/displs as separate items (send) or as a tuple (receive)
        sbuf = b"abc"
        rbuf = bytearray(3)
        smsg = [sbuf, [3], [0], [MPI.CHAR]]
        rmsg = [rbuf, ([3], [0]), [MPI.CHAR]]
        Alltoallw(smsg, rmsg)
        self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytearray(self):
        sbuf = bytearray(b"abc")
        rbuf = bytearray(3)
        smsg = [sbuf, [3], [0], [MPI.CHAR]]
        rmsg = [rbuf, ([3], [0]), [MPI.CHAR]]
        Alltoallw(smsg, rmsg)
        self.assertEqual(sbuf, rbuf)
        # With counts/displs omitted only the first item is expected to
        # arrive; the rest of rbuf must stay zero.
        sbuf = bytearray(b"abc")
        rbuf = bytearray(3)
        smsg = [sbuf, None, None, [MPI.CHAR]]
        rmsg = [rbuf, [MPI.CHAR]]
        Alltoallw(smsg, rmsg)
        self.assertEqual(sbuf[0], rbuf[0])
        self.assertEqual(bytearray(2), rbuf[1:])

    @unittest.skipIf(array is None, 'array')
    def testMessageArray(self):
        sbuf = array.array('i', [1,2,3])
        rbuf = array.array('i', [0,0,0])
        smsg = [sbuf, [3], [0], [MPI.INT]]
        rmsg = [rbuf, ([3], [0]), [MPI.INT]]
        Alltoallw(smsg, rmsg)
        self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(numpy is None, 'numpy')
    def testMessageNumPy(self):
        sbuf = numpy.array([1,2,3], dtype='i')
        rbuf = numpy.array([0,0,0], dtype='i')
        smsg = [sbuf, [3], [0], [MPI.INT]]
        rmsg = [rbuf, ([3], [0]), [MPI.INT]]
        Alltoallw(smsg, rmsg)
        self.assertTrue((sbuf == rbuf).all())

    @unittest.skipIf(array is None, 'array')
    def testMessageCAIBuf(self):
        sbuf = CAIBuf('i', [1,2,3], readonly=True)
        rbuf = CAIBuf('i', [0,0,0], readonly=False)
        smsg = [sbuf, [3], [0], [MPI.INT]]
        rmsg = [rbuf, ([3], [0]), [MPI.INT]]
        Alltoallw(smsg, rmsg)
        self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(cupy is None, 'cupy')
    def testMessageCuPy(self):
        sbuf = cupy.array([1,2,3], 'i')
        rbuf = cupy.array([0,0,0], 'i')
        smsg = [sbuf, [3], [0], [MPI.INT]]
        rmsg = [rbuf, ([3], [0]), [MPI.INT]]
        Alltoallw(smsg, rmsg)
        self.assertTrue((sbuf == rbuf).all())

    @unittest.skipIf(numba is None, 'numba')
    def testMessageNumba(self):
        sbuf = numba.cuda.device_array((3,), 'i')
        sbuf[:] = [1,2,3]
        rbuf = numba.cuda.device_array((3,), 'i')
        rbuf[:] = [0,0,0]
        smsg = [sbuf, [3], [0], [MPI.INT]]
        rmsg = [rbuf, ([3], [0]), [MPI.INT]]
        Alltoallw(smsg, rmsg)
        # numba arrays do not have the .all() method; compare element by
        # element with assertEqual so a failure reports both values
        for i in range(3):
            self.assertEqual(sbuf[i], rbuf[i])
1220
1221
1222# ---
1223
def PutGet(smsg, rmsg, target=None):
    """Put *smsg* into an RMA window and Get it back into *rmsg*.

    A 256-byte window (displacement unit 1) is allocated on MPI.COMM_SELF,
    and the Put and Get calls are separated by fences.  Any step that
    raises NotImplementedError is skipped.  When Get is unavailable, a
    best-effort local copy from the send buffer to the receive buffer is
    attempted instead.  The window, if one was allocated, is freed on exit.

    *target* is forwarded unchanged to both ``Win.Put`` and ``Win.Get``
    (with target rank 0).
    """
    try:
        win = MPI.Win.Allocate(256, 1, MPI.INFO_NULL, MPI.COMM_SELF)
    except NotImplementedError:
        win = MPI.WIN_NULL

    def fence():
        # Synchronize the window, tolerating MPIs without RMA support.
        try:
            win.Fence()
        except NotImplementedError:
            pass

    try:
        fence()
        try:
            win.Put(smsg, 0, target)
        except NotImplementedError:
            pass
        fence()
        try:
            win.Get(rmsg, 0, target)
        except NotImplementedError:
            # Fall back to copying the buffers directly; list/tuple message
            # specifications are unwrapped to their first item (the buffer).
            if isinstance(smsg, (list, tuple)):
                smsg = smsg[0]
            if isinstance(rmsg, (list, tuple)):
                rmsg = rmsg[0]
            try:
                rmsg[:] = smsg
            except Exception:
                # Best effort only; the caller's assertions report any
                # mismatch.  Exception (not a bare except) so that
                # KeyboardInterrupt and SystemExit still propagate.
                pass
        fence()
    finally:
        if win != MPI.WIN_NULL:
            win.Free()
1244
1245
class TestMessageRMA(unittest.TestCase):
    """Message-specification tests for one-sided ``Win.Put``/``Win.Get``."""

    def testMessageBad(self):
        # Malformed origin or target specifications must raise ValueError.
        # extra trailing item in the Put origin specification
        self.assertRaises(ValueError, PutGet,
                          [None, 0, 0, "B", None], [None, 0, 0, "B"],
                          (0, 0, MPI.BYTE))
        # extra trailing item in the Get origin specification
        self.assertRaises(ValueError, PutGet,
                          [None, 0, 0, "B"], [None, 0, 0, "B", None],
                          (0, 0, MPI.BYTE))
        # extra trailing item in the target specification
        self.assertRaises(ValueError, PutGet,
                          [None, 0, "B"], [None, 0, "B"],
                          (0, 0, MPI.BYTE, None))
        # a dict as the target specification
        self.assertRaises(ValueError, PutGet,
                          [None, 0, "B"], [None, 0, "B"],
                          {1: 2, 3: 4})

    def testMessageNone(self):
        # ``None`` buffers with zero or omitted counts, for each target form
        for empty in ([None, 0, 0, MPI.BYTE],
                      [None, 0, MPI.BYTE],
                      [None, MPI.BYTE]):
            for target in (None, 0, [0, 0, MPI.BYTE]):
                PutGet(empty, empty, target)

    def testMessageBottom(self):
        # MPI.BOTTOM buffers with zero or omitted counts, for each target form
        for empty in ([MPI.BOTTOM, 0, 0, MPI.BYTE],
                      [MPI.BOTTOM, 0, MPI.BYTE],
                      [MPI.BOTTOM, MPI.BYTE]):
            for target in (None, 0, [0, 0, MPI.BYTE]):
                PutGet(empty, empty, target)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytes(self):
        for target in (None, 0, [0, 3, MPI.BYTE]):
            sbuf = b"abc"
            rbuf = bytearray(3)
            PutGet(sbuf, rbuf, target)
            self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(pypy_lt_53, 'pypy(<5.3)')
    def testMessageBytearray(self):
        for target in (None, 0, [0, 3, MPI.BYTE]):
            sbuf = bytearray(b"abc")
            rbuf = bytearray(3)
            PutGet(sbuf, rbuf, target)
            self.assertEqual(sbuf, rbuf)

    @unittest.skipIf(py3, 'python3')
    @unittest.skipIf(pypy2, 'pypy2')
    @unittest.skipIf(hasattr(MPI, 'ffi'), 'mpi4py-cffi')
    def testMessageUnicode(self):  # Test for Issue #120
        # Python 2 only: a unicode object used as a raw byte buffer must
        # not raise.
        sbuf = unicode("abc")
        rbuf = bytearray(len(buffer(sbuf)))
        PutGet([sbuf, MPI.BYTE], [rbuf, MPI.BYTE], None)

    @unittest.skipMPI('msmpi')
    @unittest.skipIf(array is None, 'array')
    def testMessageArray(self):
        sbuf = array.array('i', [1,2,3])
        rbuf = array.array('i', [0,0,0])
        PutGet(sbuf, rbuf)
        self.assertEqual(sbuf, rbuf)

    @unittest.skipMPI('msmpi')
    @unittest.skipIf(numpy is None, 'numpy')
    def testMessageNumPy(self):
        sbuf = numpy.array([1,2,3], dtype='i')
        rbuf = numpy.array([0,0,0], dtype='i')
        PutGet(sbuf, rbuf)
        self.assertTrue((sbuf == rbuf).all())

    @unittest.skipMPI('msmpi')
    @unittest.skipIf(array is None, 'array')
    def testMessageCAIBuf(self):
        sbuf = CAIBuf('i', [1,2,3], readonly=True)
        rbuf = CAIBuf('i', [0,0,0], readonly=False)
        PutGet(sbuf, rbuf)
        self.assertEqual(sbuf, rbuf)

    @unittest.skipMPI('msmpi')
    @unittest.skipMPI('mvapich2')
    @unittest.skipIf(cupy is None, 'cupy')
    def testMessageCuPy(self):
        sbuf = cupy.array([1,2,3], 'i')
        rbuf = cupy.array([0,0,0], 'i')
        PutGet(sbuf, rbuf)
        self.assertTrue((sbuf == rbuf).all())

    @unittest.skipMPI('msmpi')
    @unittest.skipMPI('mvapich2')
    @unittest.skipIf(numba is None, 'numba')
    def testMessageNumba(self):
        sbuf = numba.cuda.device_array((3,), 'i')
        sbuf[:] = [1,2,3]
        rbuf = numba.cuda.device_array((3,), 'i')
        rbuf[:] = [0,0,0]
        PutGet(sbuf, rbuf)
        # numba arrays do not have the .all() method; compare element by
        # element with assertEqual so a failure reports both values
        for i in range(3):
            self.assertEqual(sbuf[i], rbuf[i])
1353
1354
1355# ---
1356
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main()
1359