"""Check that two threads of one MPI process can exchange a message.

One thread sends a NumPy int array to the process's own rank while another
thread receives it. This requires MPI_THREAD_MULTIPLE. The script exits with
status 0 (after a note on stderr) when that level of thread support, the
``threading`` module, or NumPy is unavailable.
"""
from mpi4py import MPI
import sys

if MPI.Query_thread() < MPI.THREAD_MULTIPLE:
    sys.stderr.write("MPI does not provide enough thread support\n")
    sys.exit(0)

try:
    import threading
except ImportError:
    sys.stderr.write("threading module not available\n")
    sys.exit(0)

try:
    import numpy
except ImportError:
    sys.stderr.write("NumPy package not available\n")
    sys.exit(0)

send_msg = numpy.arange(1000000, dtype='i')
recv_msg = numpy.zeros_like(send_msg)

# Both workers block on this event so that the Send and Recv are released
# together, after the main thread has verified the initial buffer state.
start_event = threading.Event()


def self_send():
    """Wait for the start signal, then Send ``send_msg`` to this rank (tag 0)."""
    start_event.wait()
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    comm.Send([send_msg, MPI.INT], dest=rank, tag=0)


def self_recv():
    """Wait for the start signal, then Recv into ``recv_msg`` from this rank (tag 0)."""
    start_event.wait()
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    comm.Recv([recv_msg, MPI.INT], source=rank, tag=0)


send_thread = threading.Thread(target=self_send)
recv_thread = threading.Thread(target=self_recv)

# Verify the precondition BEFORE starting the workers: if this check failed
# after they were started, both non-daemon threads would stay blocked in
# start_event.wait() and the interpreter would hang at shutdown.
# The buffers hold integers, so compare them exactly rather than with a
# floating-point tolerance.
assert not numpy.array_equal(send_msg, recv_msg)

for t in (recv_thread, send_thread):
    t.start()

start_event.set()

for t in (recv_thread, send_thread):
    t.join()
assert numpy.array_equal(send_msg, recv_msg)