1from mpi4py import MPI
2import mpiunittest as unittest
3
class BaseTestTopo(object):
    """Shared checks for Cartesian, graph and distributed-graph topologies.

    Concrete subclasses mix this into ``unittest.TestCase`` and set ``COMM``
    to the communicator from which every topology under test is created.
    """

    COMM = MPI.COMM_NULL

    @staticmethod
    def _ringGraph(size):
        """Return ``(index, edges)`` describing a ring over *size* nodes.

        Node ``i`` is connected to ``(i-1) % size`` and ``(i+1) % size``.
        ``index`` starts with a 0 followed by the cumulative edge count of
        each node, so ``edges[index[r]:index[r+1]]`` is node ``r``'s
        neighbor list and ``index[1:]`` is the plain cumulative-degree form.
        """
        index, edges = [0], []
        for i in range(size):
            index.append(index[-1] + 2)
            edges.append((i - 1) % size)
            edges.append((i + 1) % size)
        return index, edges

    def checkFortran(self, oldcomm):
        """Check that *oldcomm* survives a Fortran handle round-trip.

        The communicator recovered through ``MPI.Comm.f2py`` must compare
        equal to the original and be of exactly the same Python class.
        """
        fint = oldcomm.py2f()
        newcomm = MPI.Comm.f2py(fint)
        self.assertEqual(newcomm, oldcomm)
        self.assertIs(type(newcomm), type(oldcomm))

    def testCartcomm(self):
        """Check periodic Cartesian topologies of 1 to 5 dimensions.

        Covers attributes, coordinate/rank mapping, ``Shift``, the neighbor
        lists and the sub-grids obtained with ``Sub``.
        """
        comm = self.COMM
        size = comm.Get_size()
        for ndim in (1, 2, 3, 4, 5):
            dims = MPI.Compute_dims(size, [0]*ndim)
            periods = [True] * len(dims)
            topo = comm.Create_cart(dims, periods=periods)
            self.assertTrue(topo.is_topo)
            # assertEqual, not assertTrue: assertTrue's second argument is
            # only a failure message, so it would never check the kind.
            self.assertEqual(topo.topology, MPI.CART)
            self.checkFortran(topo)
            self.assertEqual(topo.dim, len(dims))
            self.assertEqual(topo.ndim, len(dims))
            coordinates = topo.coords
            self.assertEqual(coordinates, topo.Get_coords(topo.rank))
            # Neighbors are collected in (dimension, -1 then +1) order, which
            # is the order the in/out edge lists are expected to follow.
            neighbors = []
            for i in range(ndim):
                for d in (-1, +1):
                    coord = list(coordinates)
                    coord[i] = (coord[i] + d) % dims[i]  # periodic wrap
                    neigh = topo.Get_cart_rank(coord)
                    self.assertEqual(coord, topo.Get_coords(neigh))
                    _, dest = topo.Shift(i, d)
                    self.assertEqual(neigh, dest)
                    neighbors.append(neigh)
            self.assertEqual(topo.indegree, len(neighbors))
            self.assertEqual(topo.outdegree, len(neighbors))
            self.assertEqual(topo.inedges, neighbors)
            self.assertEqual(topo.outedges, neighbors)
            inedges, outedges = topo.inoutedges
            self.assertEqual(inedges, neighbors)
            self.assertEqual(outedges, neighbors)
            if ndim == 1:
                # Dropping the only dimension leaves nothing to check.
                topo.Free()
                continue
            for i in range(ndim):
                # Keep every dimension except dimension i.
                rem_dims = [1]*ndim
                rem_dims[i] = 0
                sub = topo.Sub(rem_dims)
                if sub != MPI.COMM_NULL:
                    self.assertEqual(sub.dim, ndim-1)
                    sub_dims = topo.dims
                    del sub_dims[i]
                    self.assertEqual(sub.dims, sub_dims)
                    sub.Free()
            topo.Free()

    @unittest.skipMPI('MPI(<2.0)')
    def testCartcommZeroDim(self):
        """A zero-dimensional Cartesian topology, if created, is empty.

        It must report empty dims, periods and coords, and map the empty
        coordinate list to rank 0.
        """
        comm = self.COMM
        topo = comm.Create_cart([])
        if topo == MPI.COMM_NULL:
            return
        self.assertEqual(topo.dim, 0)
        self.assertEqual(topo.dims, [])
        self.assertEqual(topo.periods, [])
        self.assertEqual(topo.coords, [])
        rank = topo.Get_cart_rank([])
        self.assertEqual(rank, 0)
        inedges, outedges = topo.inoutedges
        self.assertEqual(inedges, [])
        self.assertEqual(outedges, [])
        topo.Free()

    def testGraphcomm(self):
        """Check a ring graph topology built from both index layouts."""
        comm = self.COMM
        size = comm.Get_size()
        rank = comm.Get_rank()
        index, edges = self._ringGraph(size)
        topo = comm.Create_graph(index[1:], edges)
        self.assertTrue(topo.is_topo)
        self.assertEqual(topo.topology, MPI.GRAPH)
        self.checkFortran(topo)
        topo.Free()
        topo = comm.Create_graph(index, edges)
        self.assertEqual(topo.dims, (len(index)-1, len(edges)))
        self.assertEqual(topo.nnodes, len(index)-1)
        self.assertEqual(topo.nedges, len(edges))
        self.assertEqual(topo.index, index[1:])
        self.assertEqual(topo.edges, edges)
        neighbors = edges[index[rank]:index[rank+1]]
        self.assertEqual(neighbors, topo.neighbors)
        # Use a distinct name so the caller's own rank is not clobbered.
        for node in range(size):
            neighs = topo.Get_neighbors(node)
            self.assertEqual(neighs, [(node-1) % size, (node+1) % size])
        self.assertEqual(topo.indegree, len(neighbors))
        self.assertEqual(topo.outdegree, len(neighbors))
        self.assertEqual(topo.inedges, neighbors)
        self.assertEqual(topo.outedges, neighbors)
        inedges, outedges = topo.inoutedges
        self.assertEqual(inedges, neighbors)
        self.assertEqual(outedges, neighbors)
        topo.Free()

    @unittest.skipMPI('msmpi')
    def testDistgraphcommAdjacent(self):
        """Check adjacent distributed graphs.

        Covers unweighted, weighted and one-sided graphs, plus the empty
        weighted graph when MPI >= 3.
        """
        comm = self.COMM
        size = comm.Get_size()
        rank = comm.Get_rank()
        try:
            topo = comm.Create_dist_graph_adjacent(None, None)
            topo.Free()
        except NotImplementedError:
            self.skipTest('mpi-comm-create_dist_graph_adjacent')
        #
        sources = [(rank-2) % size, (rank-1) % size]
        destinations = [(rank+1) % size, (rank+2) % size]
        topo = comm.Create_dist_graph_adjacent(sources, destinations)
        self.assertTrue(topo.is_topo)
        self.assertEqual(topo.topology, MPI.DIST_GRAPH)
        self.checkFortran(topo)
        self.assertEqual(topo.Get_dist_neighbors_count(), (2, 2, False))
        self.assertEqual(topo.Get_dist_neighbors(), (sources, destinations, None))
        self.assertEqual(topo.indegree, len(sources))
        self.assertEqual(topo.outdegree, len(destinations))
        self.assertEqual(topo.inedges, sources)
        self.assertEqual(topo.outedges, destinations)
        inedges, outedges = topo.inoutedges
        self.assertEqual(inedges, sources)
        self.assertEqual(outedges, destinations)
        topo.Free()
        #
        sourceweights = [1, 2]
        destweights   = [3, 4]
        weights = (sourceweights, destweights)
        topo = comm.Create_dist_graph_adjacent(sources, destinations,
                                               sourceweights, destweights)
        self.assertEqual(topo.Get_dist_neighbors_count(), (2, 2, True))
        self.assertEqual(topo.Get_dist_neighbors(), (sources, destinations, weights))
        topo.Free()
        #
        topo = comm.Create_dist_graph_adjacent(sources, None, MPI.UNWEIGHTED, None)
        self.assertEqual(topo.Get_dist_neighbors_count(), (2, 0, False))
        self.assertEqual(topo.Get_dist_neighbors(), (sources, [], None))
        topo.Free()
        topo = comm.Create_dist_graph_adjacent(None, destinations, None, MPI.UNWEIGHTED)
        self.assertEqual(topo.Get_dist_neighbors_count(), (0, 2, False))
        self.assertEqual(topo.Get_dist_neighbors(), ([], destinations, None))
        topo.Free()
        if MPI.VERSION < 3:
            return  # MPI.WEIGHTS_EMPTY is only exercised on MPI >= 3
        topo = comm.Create_dist_graph_adjacent([], [], MPI.WEIGHTS_EMPTY, MPI.WEIGHTS_EMPTY)
        self.assertEqual(topo.Get_dist_neighbors_count(), (0, 0, True))
        self.assertEqual(topo.Get_dist_neighbors(), ([], [], ([], [])))
        topo.Free()

    @unittest.skipMPI('msmpi')
    @unittest.skipMPI('PlatformMPI')
    def testDistgraphcomm(self):
        """Check ``Create_dist_graph``.

        Each rank contributes its own three out-edges; the graph is built
        unweighted and then weighted.
        """
        comm = self.COMM
        size = comm.Get_size()
        rank = comm.Get_rank()
        #
        try:
            topo = comm.Create_dist_graph([], [], [], MPI.UNWEIGHTED)
            topo.Free()
        except NotImplementedError:
            self.skipTest('mpi-comm-create_dist_graph')
        #
        sources = [rank]
        degrees = [3]
        destinations = [(rank-1) % size, rank, (rank+1) % size]
        topo = comm.Create_dist_graph(sources, degrees, destinations, MPI.UNWEIGHTED)
        self.assertTrue(topo.is_topo)
        self.assertEqual(topo.topology, MPI.DIST_GRAPH)
        self.checkFortran(topo)
        self.assertEqual(topo.Get_dist_neighbors_count(), (3, 3, False))
        topo.Free()
        weights = list(range(1, 4))
        topo = comm.Create_dist_graph(sources, degrees, destinations, weights)
        self.assertEqual(topo.Get_dist_neighbors_count(), (3, 3, True))
        topo.Free()

    def testCartMap(self):
        """``Cart_map`` must predict the rank a reordering ``Create_cart`` assigns."""
        comm = self.COMM
        size = comm.Get_size()
        for ndim in (1, 2, 3, 4, 5):
            for periods in (None, True, False):
                dims = MPI.Compute_dims(size, [0]*ndim)
                topo = comm.Create_cart(dims, periods, reorder=True)
                rank = comm.Cart_map(dims, periods)
                self.assertEqual(topo.Get_rank(), rank)
                topo.Free()

    def testGraphMap(self):
        """``Graph_map`` must predict the rank a reordering ``Create_graph`` assigns."""
        comm = self.COMM
        size = comm.Get_size()
        index, edges = self._ringGraph(size)
        # Version 1: index with a leading 0
        topo = comm.Create_graph(index, edges, reorder=True)
        rank = comm.Graph_map(index, edges)
        self.assertEqual(topo.Get_rank(), rank)
        topo.Free()
        # Version 2: plain cumulative-degree index
        topo = comm.Create_graph(index[1:], edges, reorder=True)
        rank = comm.Graph_map(index[1:], edges)
        self.assertEqual(topo.Get_rank(), rank)
        topo.Free()
220
221
class TestTopoSelf(BaseTestTopo, unittest.TestCase):
    """Run the shared topology tests on ``MPI.COMM_SELF``."""

    # Communicator the BaseTestTopo checks derive their topologies from.
    COMM = MPI.COMM_SELF
224
class TestTopoWorld(BaseTestTopo, unittest.TestCase):
    """Run the shared topology tests on ``MPI.COMM_WORLD``."""

    # Communicator the BaseTestTopo checks derive their topologies from.
    COMM = MPI.COMM_WORLD
227
class TestTopoSelfDup(TestTopoSelf):
    """Run the topology tests on a fresh duplicate of ``MPI.COMM_SELF``."""

    def setUp(self):
        # The instance attribute shadows the class-level COMM for this test.
        duplicate = MPI.COMM_SELF.Dup()
        self.COMM = duplicate

    def tearDown(self):
        # Release the duplicate created in setUp.
        duplicate = self.COMM
        duplicate.Free()
233
class TestTopoWorldDup(TestTopoWorld):
    """Run the topology tests on a fresh duplicate of ``MPI.COMM_WORLD``."""

    def setUp(self):
        # The instance attribute shadows the class-level COMM for this test.
        duplicate = MPI.COMM_WORLD.Dup()
        self.COMM = duplicate

    def tearDown(self):
        # Release the duplicate created in setUp.
        duplicate = self.COMM
        duplicate.Free()
239
240
if __name__ == '__main__':
    # Run this module's test cases when the file is executed directly.
    unittest.main()
243