1# (C) Copyright 2007
2# Andreas Kloeckner <inform -at- tiker.net>
3#
4# Use, modification and distribution is subject to the Boost Software
5# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
6# http://www.boost.org/LICENSE_1_0.txt)
7#
8#  Authors: Andreas Kloeckner
9
10
11
12
13import boost.mpi as mpi
14import random
15import sys
16
# Length a history (the list of visited ranks) must reach before a worker
# returns it to rank 0 instead of forwarding it to another worker.
MAX_GENERATIONS = 20

# Message tags used to multiplex the different kinds of traffic.
TAG_DEBUG, TAG_DATA, TAG_TERMINATE, TAG_PROGRESS_REPORT = 0, 1, 2, 3
22
23
24
25
class TagGroupListener:
    """Class to help listen for only a given set of tags.

    One non-blocking receive (``comm.irecv``) is kept outstanding per tag;
    ``wait`` blocks until any of them completes and returns that message.

    This is contrived: Typically you could just listen for
    mpi.any_tag and filter."""
    def __init__(self, comm, tags):
        """Create a listener.

        :param comm: communicator to receive on (e.g. ``mpi.world``).
        :param tags: iterable of message tags to listen for.
        """
        self.tags = tags
        self.comm = comm
        # Outstanding receive request per tag, keyed by the tag it was
        # posted with.
        self.active_requests = {}

    def wait(self):
        """Block until a message with one of ``self.tags`` arrives.

        Posts an ``irecv`` for every tag that has no request outstanding,
        then waits for the first one to complete. Requests that did not
        complete stay posted and are reused by the next call.

        :returns: ``(status, data)`` of the received message.
        """
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        # Freeze the key order so the index reported by wait_any maps back
        # to the key the request was stored under. status.tag is the tag
        # actually received, which would not match the key if a wildcard
        # tag such as mpi.any_tag were being listened for.
        pending_tags = list(self.active_requests)
        requests = mpi.RequestList(
            [self.active_requests[tag] for tag in pending_tags])
        data, status, index = mpi.wait_any(requests)
        del self.active_requests[pending_tags[index]]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about it."""
        # values() instead of the Python-2-only itervalues().
        for request in self.active_requests.values():
            request.cancel()
            # NOTE(review): the MPI standard expects a cancelled request to
            # still be completed with a wait/test to release it; the original
            # code had request.wait() disabled here. Confirm the behaviour of
            # the Boost.MPI version in use before enabling it.
        self.active_requests = {}
50
51
52
def rank0():
    """Coordinate the test from rank 0.

    Seeds ``15 * (mpi.size - 1)`` empty histories to random worker ranks,
    then receives completed histories (TAG_DATA), progress reports,
    debug messages and termination acknowledgements. Once every history
    has come back, every worker is told to terminate. Returns after all
    progress reports and all acknowledgements have been received.
    """
    sent_histories = (mpi.size - 1) * 15
    print("sending %d packets on their way" % sent_histories)
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))

    mpi.wait_all(send_reqs)

    completed_histories = []
    # Maps history length at report time -> number of reports received.
    progress_reports = {}
    dead_kids = []

    tgl = TagGroupListener(mpi.world,
            [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    # Workers report len(history) before appending themselves and forward
    # until the length reaches MAX_GENERATIONS. Each history therefore
    # yields one report per length 0 .. max(MAX_GENERATIONS, 1) - 1.
    expected_lengths = range(max(MAX_GENERATIONS, 1))

    def is_complete():
        # Require every expected length, not just the ones seen so far, so
        # no progress report is still in flight when we stop listening.
        for length in expected_lengths:
            if progress_reports.get(length, 0) != sent_histories:
                return False
        return len(dead_kids) == mpi.size - 1

    # Test before waiting: with a single process there is nothing to
    # receive, and blocking in tgl.wait() would hang forever.
    while not is_complete():
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

    # Withdraw the receives the listener still has posted before returning.
    tgl.cancel()

    print("OK")
99
def comm_rank():
    """Worker loop run on every rank other than 0.

    Handles messages until told to stop:

    * TAG_DATA: sends the received history to rank 0 as a progress report,
      appends this rank to it, then forwards it to a random worker. Once
      the history's length reaches MAX_GENERATIONS, it goes to rank 0.
    * TAG_TERMINATE: acknowledges to rank 0 with TAG_TERMINATE and returns.
    * anything else: prints a diagnostic and keeps going.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                # NOTE(review): this range includes mpi.rank itself, so the
                # blocking send below may target this rank; that relies on
                # MPI buffering the small message. Confirm this is acceptable.
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source))
117
118
def main():
    """Run this process's role: coordinator on rank 0, worker elsewhere.

    The program passes messages (lists of visited nodes) around at random.
    After MAX_GENERATIONS they are returned to rank 0.
    """
    role = rank0 if mpi.rank == 0 else comm_rank
    role()
127
128
129
if __name__ == "__main__":
    # Only run when executed as a script, not when imported as a module.
    main()
132