1#import mpi4py
2#mpi4py.profile("mpe")
3from mpi4py import MPI
4
5import unittest
6
7import sys, os
8sys.path.insert(0, os.path.dirname(__file__))
9from reductions import Intracomm
10del sys.path[0]
11
class BaseTest(object):
    """Object-reduction checks shared by every communicator test case.

    This is a mixin: each concrete subclass also derives from
    ``unittest.TestCase`` and assigns ``self.comm`` in ``setUp`` to an
    ``Intracomm`` wrapper from the local ``reductions`` module.

    Every rank contributes its own rank number (or the pair
    ``(rank, rank)`` for the MINLOC/MAXLOC operations), so each expected
    result can be written directly in terms of ``rank`` and ``size``.
    When no ``op`` is passed, the wrapper is expected to sum the
    contributions.
    """

    # -- helpers ---------------------------------------------------------

    def _check_reduce(self, sendobj, expected, **kwargs):
        """Call ``reduce`` once with every rank acting as root.

        The root must receive *expected*; every other rank must receive
        None. Extra keyword arguments (e.g. ``op``) are forwarded to
        ``reduce`` unchanged, so the call matches what each test
        spelled out before.
        """
        comm = self.comm
        for root in range(comm.size):
            res = comm.reduce(sendobj=sendobj, root=root, **kwargs)
            if comm.rank == root:
                self.assertEqual(res, expected)
            else:
                self.assertEqual(res, None)

    def _check_exscan(self, sendobj, expected, **kwargs):
        """Call ``exscan`` and check the per-rank result.

        Rank 0 has no preceding ranks and must receive None. Every other
        rank must receive *expected*. Extra keyword arguments are
        forwarded to ``exscan`` unchanged.
        """
        res = self.comm.exscan(sendobj=sendobj, **kwargs)
        if self.comm.rank == 0:
            self.assertEqual(res, None)
        else:
            self.assertEqual(res, expected)

    # -- reduce ----------------------------------------------------------

    def test_reduce(self):
        # Default op: the root should get 0 + 1 + ... + (size - 1).
        self._check_reduce(self.comm.rank, sum(range(self.comm.size)))

    def test_reduce_min(self):
        self._check_reduce(self.comm.rank, 0, op=MPI.MIN)

    def test_reduce_max(self):
        self._check_reduce(self.comm.rank, self.comm.size - 1, op=MPI.MAX)

    def test_reduce_minloc(self):
        rank = self.comm.rank
        self._check_reduce((rank, rank), (0, 0), op=MPI.MINLOC)

    def test_reduce_maxloc(self):
        rank, last = self.comm.rank, self.comm.size - 1
        self._check_reduce((rank, rank), (last, last), op=MPI.MAXLOC)

    # -- allreduce -------------------------------------------------------

    def test_allreduce(self):
        res = self.comm.allreduce(sendobj=self.comm.rank)
        self.assertEqual(res, sum(range(self.comm.size)))

    def test_allreduce_min(self):
        res = self.comm.allreduce(sendobj=self.comm.rank, op=MPI.MIN)
        self.assertEqual(res, 0)

    def test_allreduce_max(self):
        res = self.comm.allreduce(sendobj=self.comm.rank, op=MPI.MAX)
        self.assertEqual(res, self.comm.size - 1)

    def test_allreduce_minloc(self):
        rank = self.comm.rank
        res = self.comm.allreduce(sendobj=(rank, rank), op=MPI.MINLOC)
        self.assertEqual(res, (0, 0))

    def test_allreduce_maxloc(self):
        rank, last = self.comm.rank, self.comm.size - 1
        res = self.comm.allreduce(sendobj=(rank, rank), op=MPI.MAXLOC)
        self.assertEqual(res, (last, last))

    # -- scan (inclusive prefix over ranks 0..rank) ----------------------

    def test_scan(self):
        rank = self.comm.rank
        res = self.comm.scan(sendobj=rank)
        self.assertEqual(res, sum(range(rank + 1)))

    def test_scan_min(self):
        res = self.comm.scan(sendobj=self.comm.rank, op=MPI.MIN)
        self.assertEqual(res, 0)

    def test_scan_max(self):
        rank = self.comm.rank
        res = self.comm.scan(sendobj=rank, op=MPI.MAX)
        self.assertEqual(res, rank)

    def test_scan_minloc(self):
        rank = self.comm.rank
        res = self.comm.scan(sendobj=(rank, rank), op=MPI.MINLOC)
        self.assertEqual(res, (0, 0))

    def test_scan_maxloc(self):
        rank = self.comm.rank
        res = self.comm.scan(sendobj=(rank, rank), op=MPI.MAXLOC)
        self.assertEqual(res, (rank, rank))

    # -- exscan (exclusive prefix over ranks 0..rank-1) ------------------

    def test_exscan(self):
        rank = self.comm.rank
        self._check_exscan(rank, sum(range(rank)))

    def test_exscan_min(self):
        self._check_exscan(self.comm.rank, 0, op=MPI.MIN)

    def test_exscan_max(self):
        rank = self.comm.rank
        self._check_exscan(rank, rank - 1, op=MPI.MAX)

    def test_exscan_minloc(self):
        rank = self.comm.rank
        self._check_exscan((rank, rank), (0, 0), op=MPI.MINLOC)

    def test_exscan_maxloc(self):
        rank = self.comm.rank
        self._check_exscan((rank, rank), (rank - 1, rank - 1),
                           op=MPI.MAXLOC)
188
class TestS(BaseTest, unittest.TestCase):
    """Run the shared reduction checks on the predefined COMM_SELF."""

    def setUp(self):
        base = MPI.COMM_SELF
        self.comm = Intracomm(base)
192
class TestW(BaseTest, unittest.TestCase):
    """Run the shared reduction checks on the predefined COMM_WORLD."""

    def setUp(self):
        base = MPI.COMM_WORLD
        self.comm = Intracomm(base)
196
class TestSD(BaseTest, unittest.TestCase):
    """Run the shared reduction checks on a duplicate of COMM_SELF."""

    def setUp(self):
        # Each test gets a freshly duplicated communicator that it owns.
        dup = MPI.COMM_SELF.Dup()
        self.comm = Intracomm(dup)

    def tearDown(self):
        # Release the duplicate created in setUp.
        self.comm.Free()
202
class TestWD(BaseTest, unittest.TestCase):
    """Run the shared reduction checks on a duplicate of COMM_WORLD."""

    def setUp(self):
        # Each test gets a freshly duplicated communicator that it owns.
        dup = MPI.COMM_WORLD.Dup()
        self.comm = Intracomm(dup)

    def tearDown(self):
        # Release the duplicate created in setUp.
        self.comm.Free()
208
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase in this module.
    # Presumably launched under mpiexec for multi-rank TestW/TestWD runs.
    runner_main = unittest.main
    runner_main()
211