1from mpi4py import MPI
2import sys
3
# This script calls into MPI from two threads at once, so it only goes on
# when the library reports at least MPI.THREAD_MULTIPLE. Otherwise it prints
# a notice and stops with exit status 0 (presumably so a harness treats the
# run as skipped rather than failed -- confirm against the caller).
if not (MPI.Query_thread() >= MPI.THREAD_MULTIPLE):
    sys.stderr.write("MPI does not provide enough thread support\n")
    raise SystemExit(0)
7
# The threading module is required for the worker threads; when it cannot
# be imported, report that on stderr and stop with exit status 0.
try:
    import threading
except ImportError:
    sys.stderr.write("threading module not available\n")
    raise SystemExit(0)
13
# NumPy supplies the message buffers; when it cannot be imported, report
# that on stderr and stop with exit status 0.
try:
    import numpy
except ImportError:
    sys.stderr.write("NumPy package not available\n")
    raise SystemExit(0)
19
# Gate that keeps both worker threads parked until the main thread is ready.
start_event = threading.Event()

# Outgoing payload: the integers 0..999999 stored as C ints (typecode 'i').
send_msg = numpy.arange(10**6, dtype=numpy.intc)
# Landing buffer with the same shape and dtype, initially all zeros.
recv_msg = numpy.zeros(send_msg.shape, dtype=send_msg.dtype)
24
def self_send():
    """Send ``send_msg`` to this process's own rank once released.

    Waits on ``start_event`` first, then calls ``Send`` on
    ``MPI.COMM_WORLD`` with ``send_msg`` described as ``MPI.INT`` data,
    using the caller's own rank as destination and tag 0.
    """
    start_event.wait()
    world = MPI.COMM_WORLD
    me = world.Get_rank()
    outgoing = [send_msg, MPI.INT]
    world.Send(outgoing, dest=me, tag=0)
30
def self_recv():
    """Receive into ``recv_msg`` from this process's own rank once released.

    Waits on ``start_event`` first, then calls ``Recv`` on
    ``MPI.COMM_WORLD`` with ``recv_msg`` described as ``MPI.INT`` data,
    using the caller's own rank as source and tag 0.
    """
    start_event.wait()
    world = MPI.COMM_WORLD
    me = world.Get_rank()
    incoming = [recv_msg, MPI.INT]
    world.Recv(incoming, source=me, tag=0)
36
# One worker per direction. Each blocks on start_event before touching MPI,
# so starting them does not begin the transfer.
send_thread = threading.Thread(target=self_send)
recv_thread = threading.Thread(target=self_recv)

for t in (recv_thread, send_thread):
    t.start()
# Nothing has been transferred yet, so the buffers must still differ.
# The comparison is exact on purpose: the payload is integer data, and a
# tolerance-based check such as numpy.allclose (default rtol=1e-5) would
# accept elements that are off by several units near the top of the range.
assert not numpy.array_equal(send_msg, recv_msg)

# Release both workers so the Send and the matching Recv run concurrently.
start_event.set()

for t in (recv_thread, send_thread):
    t.join()
# After both workers finish, every received element must match exactly.
assert numpy.array_equal(send_msg, recv_msg)
49