# (C) Copyright 2007
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
# Authors: Andreas Kloeckner

# Non-blocking communication test for the Boost.MPI Python bindings.
#
# Rank 0 sends empty lists ("histories") to random worker ranks.  Each worker
# reports every history it receives back to rank 0 (TAG_PROGRESS_REPORT),
# appends its own rank and forwards the history to another random worker,
# until the history holds MAX_GENERATIONS entries and is returned to rank 0.
#
# print(...) is always called with exactly one argument, so it behaves the
# same as a Python 2 print statement and the file also runs under Python 3.

import boost.mpi as mpi
import random
import sys

MAX_GENERATIONS = 20
TAG_DEBUG = 0
TAG_DATA = 1
TAG_TERMINATE = 2
TAG_PROGRESS_REPORT = 3


class TagGroupListener(object):
    """Class to help listen for only a given set of tags.

    Keeps one outstanding non-blocking receive per tag and waits for
    whichever of them completes first.

    This is contrived: Typically you could just listen for
    mpi.any_tag and filter."""

    def __init__(self, comm, tags):
        self.tags = tags
        self.comm = comm
        # tag -> pending irecv request for that tag
        self.active_requests = {}

    def wait(self):
        """Block until a message arrives on any of the listened-for tags.

        Returns a ``(status, data)`` pair for the receive that completed.
        Only the completed request is dropped; the requests for the other
        tags stay pending and are reused by the next call.
        """
        # Post a receive for every tag that does not already have one.
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _ = mpi.wait_any(requests)
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every pending receive and forget about it."""
        for r in self.active_requests.values():
            r.cancel()
            # NOTE(review): the original had ``r.wait()`` commented out here.
            # MPI expects a cancelled request to still be completed with a
            # wait/test -- confirm whether that is needed with these bindings.
        self.active_requests = {}


def rank0():
    """Drive the test from rank 0.

    Sends ``(mpi.size - 1) * 15`` empty histories to random workers, then
    handles returned histories, progress reports, debug messages and
    termination acknowledgements.  Once every history has come back it
    tells all workers to terminate.  It stops when every worker has
    acknowledged and every expected progress report has been received,
    then prints "OK".
    """
    sent_histories = (mpi.size - 1) * 15
    print("sending %d packets on their way" % sent_histories)
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))

    mpi.wait_all(send_reqs)

    completed_histories = []
    # history length at report time -> number of reports seen at that length
    progress_reports = {}
    # ranks that have acknowledged TAG_TERMINATE
    dead_kids = []

    tgl = TagGroupListener(mpi.world,
            [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        # comm_rank reports each history before appending its rank, so every
        # history yields exactly one report at each length
        # 0 .. MAX_GENERATIONS-1.  A finished run therefore has
        # MAX_GENERATIONS buckets, each holding sent_histories reports.
        #
        # The bucket count is checked too because each tag has its own
        # pending receive.  A worker's TAG_TERMINATE can then be matched
        # while a whole bucket of its earlier TAG_PROGRESS_REPORTs is still
        # unreceived.  Checking only the buckets already present would then
        # finish too early.
        if len(dead_kids) != mpi.size - 1:
            return False
        if len(progress_reports) != MAX_GENERATIONS:
            return False
        return all(count == sent_histories
                   for count in progress_reports.values())

    while True:
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

        if is_complete():
            break

    # NOTE(review): the listener's remaining irecv requests are left pending
    # here (TagGroupListener.cancel() is never called) -- confirm intended.
    print("OK")


def comm_rank():
    """Worker loop run on every rank other than 0.

    For each TAG_DATA history, the worker does three things:

    1. Reports the history to rank 0.
    2. Appends its own rank.
    3. Forwards the history to a random worker, or back to rank 0 once it
       has MAX_GENERATIONS entries.

    On TAG_TERMINATE it sends an acknowledgement to rank 0 and returns.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            # Report before appending: rank 0 counts reports by this length.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d"
                  % (mpi.rank, status.tag, status.source))


def main():
    # this program sends around messages consisting of lists of visited nodes
    # randomly. After MAX_GENERATIONS, they are returned to rank 0.

    if mpi.rank == 0:
        rank0()
    else:
        comm_rank()


if __name__ == "__main__":
    main()