1# Copyright 2020, 2021 PaGMO development team
2#
3# This file is part of the pygmo library.
4#
5# This Source Code Form is subject to the terms of the Mozilla
6# Public License v. 2.0. If a copy of the MPL was not distributed
7# with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9import unittest as _ut
10
11
class _topo(object):
    """Bare-bones user-defined topology defined at module scope.

    It implements only ``get_connections()`` and ``push_back()`` and
    never reports any connection. It is used by the pickle tests of
    ``topology_test_case``.
    """

    def push_back(self):
        """Do nothing when a new vertex is added."""
        return None

    def get_connections(self, n):
        """Return a pair of empty lists, whatever the vertex index ``n`` is."""
        return [[] for _ in range(2)]
19
20
21class topology_test_case(_ut.TestCase):
22    """Test case for the :class:`~pygmo.topology` class.
23
24    """
25
26    def runTest(self):
27        self.run_basic_tests()
28        self.run_extract_tests()
29        self.run_name_info_tests()
30        self.run_pickle_tests()
31        self.run_to_networkx_tests()
32
33    def run_basic_tests(self):
34        # Tests for minimal topology, and mandatory methods.
35        from numpy import ndarray, dtype
36        from .core import topology, ring, unconnected
37        # Def construction.
38        t = topology()
39        self.assertTrue(t.extract(unconnected) is not None)
40        self.assertTrue(t.extract(ring) is None)
41
42        # First a few non-topos.
43        self.assertRaises(NotImplementedError, lambda: topology(1))
44        self.assertRaises(NotImplementedError,
45                          lambda: topology("hello world"))
46        self.assertRaises(NotImplementedError, lambda: topology([]))
47        self.assertRaises(TypeError, lambda: topology(int))
48        # Some topologies missing methods, wrong arity, etc.
49
50        class nt0(object):
51            pass
52        self.assertRaises(NotImplementedError, lambda: topology(nt0()))
53
54        class nt1(object):
55
56            get_connections = 45
57            push_back = 45
58        self.assertRaises(NotImplementedError, lambda: topology(nt1()))
59
60        # The minimal good citizen.
61        glob = []
62
63        class t(object):
64
65            def __init__(self, g):
66                self.g = g
67
68            def push_back(self):
69                self.g.append(1)
70                return 1
71
72            def get_connections(self, n):
73                self.g.append(2)
74                return [[], []]
75
76        t_inst = t(glob)
77        topo = topology(t_inst)
78
79        with self.assertRaises(TypeError) as cm:
80            topo.push_back(n=-1)
81
82        # Test the keyword arg.
83        topo = topology(udt=ring())
84        topo = topology(udt=t_inst)
85
86        # Check a few topo properties.
87        self.assertEqual(topo.get_extra_info(), "")
88        self.assertTrue(topo.extract(int) is None)
89        self.assertTrue(topo.extract(ring) is None)
90        self.assertFalse(topo.extract(t) is None)
91        self.assertTrue(topo.is_(t))
92        self.assertTrue(isinstance(topo.get_connections(0), tuple))
93        self.assertTrue(isinstance(topo.get_connections(0)[0], ndarray))
94        self.assertTrue(isinstance(topo.get_connections(0)[1], ndarray))
95        self.assertTrue(topo.get_connections(n=0)[1].dtype == dtype(float))
96        # Assert that t_inst was deep-copied into topo:
97        # the instance in topo will have its own copy of glob
98        # and it will not be a reference the outside object.
99        self.assertEqual(len(glob), 0)
100        self.assertEqual(len(topo.extract(t).g), 4)
101        self.assertEqual(topo.extract(t).g, [2]*4)
102        self.assertTrue(topo.push_back() is None)
103        self.assertEqual(topo.extract(t).g, [2]*4 + [1])
104
105        topo = topology(ring())
106        self.assertTrue(topo.get_extra_info() != "")
107        self.assertTrue(topo.extract(int) is None)
108        self.assertTrue(topo.extract(t) is None)
109        self.assertFalse(topo.extract(ring) is None)
110        self.assertTrue(topo.is_(ring))
111        self.assertTrue(isinstance(topo.push_back(), type(None)))
112
113        # Wrong retval for get_connections().
114
115        class t(object):
116
117            def push_back(self):
118                pass
119
120            def get_connections(self, n):
121                return []
122        topo = topology(t())
123        self.assertRaises(RuntimeError, lambda: topo.get_connections(0))
124
125        class t(object):
126
127            def push_back(self):
128                pass
129
130            def get_connections(self, n):
131                return [1]
132        topo = topology(t())
133        self.assertRaises(ValueError, lambda: topo.get_connections(0))
134
135        class t(object):
136
137            def push_back(self):
138                pass
139
140            def get_connections(self, n):
141                return [1, 2, 3]
142        topo = topology(t())
143        self.assertRaises(ValueError, lambda: topo.get_connections(0))
144
145        class t(object):
146
147            def push_back(self):
148                pass
149
150            def get_connections(self, n):
151                return [[1, 2, 3], [.5]]
152        topo = topology(t())
153        with self.assertRaises(ValueError) as cm:
154            topo.get_connections(0)
155        err = cm.exception
156        self.assertTrue(
157            "while the vector of migration probabilities has a size of" in str(err))
158
159        class t(object):
160
161            def push_back(self):
162                pass
163
164            def get_connections(self, n):
165                return [[1, 2, 3], [.5, .6, 1.4]]
166        topo = topology(t())
167        with self.assertRaises(ValueError) as cm:
168            topo.get_connections(0)
169        err = cm.exception
170        self.assertTrue(
171            "An invalid migration probability of " in str(err))
172
173        class t(object):
174
175            def push_back(self):
176                pass
177
178            def get_connections(self, n):
179                return [[1, 2, 3], [.5, .6, float("inf")]]
180        topo = topology(t())
181        with self.assertRaises(ValueError) as cm:
182            topo.get_connections(0)
183        err = cm.exception
184        self.assertTrue(
185            "An invalid non-finite migration probability of " in str(err))
186
187        # Test that construction from another pygmo.topology fails.
188        with self.assertRaises(TypeError) as cm:
189            topology(topo)
190        err = cm.exception
191        self.assertTrue(
192            "a pygmo.topology cannot be used as a UDT for another pygmo.topology (if you need to copy a topology please use the standard Python copy()/deepcopy() functions)" in str(err))
193
194    def run_extract_tests(self):
195        from .core import topology, _test_topology, ring
196        import sys
197
198        # First we try with a C++ test topo.
199        t = topology(_test_topology())
200        # Verify the refcount of p is increased after extract().
201        rc = sys.getrefcount(t)
202        ttopo = t.extract(_test_topology)
203        self.assertEqual(sys.getrefcount(t), rc + 1)
204        del ttopo
205        self.assertEqual(sys.getrefcount(t), rc)
206        # Verify we are modifying the inner object.
207        t.extract(_test_topology).set_n(5)
208        self.assertEqual(t.extract(_test_topology).get_n(), 5)
209
210        class ttopology(object):
211
212            def __init__(self):
213                self._n = 1
214
215            def get_n(self):
216                return self._n
217
218            def set_n(self, n):
219                self._n = n
220
221            def get_connections(self, n):
222                return [[], []]
223
224            def push_back(self):
225                pass
226
227        # Test with Python topology.
228        t = topology(ttopology())
229        rc = sys.getrefcount(t)
230        ttopo = t.extract(ttopology)
231        # Reference count does not increase because
232        # ttopology is stored as a proper Python object
233        # with its own refcount.
234        self.assertTrue(sys.getrefcount(t) == rc)
235        self.assertTrue(ttopo.get_n() == 1)
236        ttopo.set_n(12)
237        self.assert_(t.extract(ttopology).get_n() == 12)
238
239        # Check that we can extract Python UDTs also via Python's object type.
240        t = topology(ttopology())
241        self.assertTrue(not t.extract(object) is None)
242        # Check we are referring to the same object.
243        self.assertEqual(id(t.extract(object)), id(t.extract(ttopology)))
244        # Check that it will not work with exposed C++ topologies.
245        t = topology(ring())
246        self.assertTrue(t.extract(object) is None)
247        self.assertTrue(not t.extract(ring) is None)
248
249    def run_name_info_tests(self):
250        from .core import topology
251
252        class t(object):
253
254            def get_connections(self, n):
255                return [[], []]
256
257            def push_back(self):
258                pass
259
260        topo = topology(t())
261        self.assertTrue(topo.get_name() != '')
262        self.assertTrue(topo.get_extra_info() == '')
263
264        class t(object):
265
266            def get_connections(self, n):
267                return [[], []]
268
269            def push_back(self):
270                pass
271
272            def get_name(self):
273                return 'pippo'
274
275        topo = topology(t())
276        self.assertTrue(topo.get_name() == 'pippo')
277        self.assertTrue(topo.get_extra_info() == '')
278
279        class t(object):
280
281            def get_connections(self, n):
282                return [[], []]
283
284            def push_back(self):
285                pass
286
287            def get_extra_info(self):
288                return 'pluto'
289
290        topo = topology(t())
291        self.assertTrue(topo.get_name() != '')
292        self.assertTrue(topo.get_extra_info() == 'pluto')
293
294        class t(object):
295
296            def get_connections(self, n):
297                return [[], []]
298
299            def push_back(self):
300                pass
301
302            def get_name(self):
303                return 'pippo'
304
305            def get_extra_info(self):
306                return 'pluto'
307
308        topo = topology(t())
309        self.assertTrue(topo.get_name() == 'pippo')
310        self.assertTrue(topo.get_extra_info() == 'pluto')
311
312    def run_pickle_tests(self):
313        from .core import topology, ring
314        from pickle import dumps, loads
315        t_ = topology(ring())
316        t = loads(dumps(t_))
317        self.assertEqual(repr(t), repr(t_))
318        self.assertTrue(t.is_(ring))
319
320        t_ = topology(_topo())
321        t = loads(dumps(t_))
322        self.assertEqual(repr(t), repr(t_))
323        self.assertTrue(t.is_(_topo))
324
325    def run_to_networkx_tests(self):
326        from .core import topology
327
328        try:
329            import networkx as nx
330        except ImportError:
331            return
332
333        g = nx.DiGraph()
334        g.add_weighted_edges_from([(0, 1, .5), (1, 2, 1.)])
335
336        # Good implementation.
337        class t:
338
339            def get_connections(self, n):
340                return [[], []]
341
342            def push_back(self):
343                pass
344
345            def to_networkx(self):
346                ret = nx.DiGraph()
347                ret.add_weighted_edges_from([(0, 1, .5), (1, 2, 1.)])
348                return ret
349
350        self.assertTrue(nx.is_isomorphic(
351            topology(t()).to_networkx(), g))
352
353        # Graph with isolated nodes, and nodes not numbered
354        # sequentially.
355        g = nx.DiGraph()
356        g.add_weighted_edges_from([(0, 1, .5), (1, 2, 1.)])
357        g.add_node(3)
358        g.add_node(4)
359
360        class t:
361
362            def get_connections(self, n):
363                return [[], []]
364
365            def push_back(self):
366                pass
367
368            def to_networkx(self):
369                ret = nx.DiGraph()
370                ret.add_weighted_edges_from([(0, 1, .5), (1, 2, 1.)])
371                ret.add_node(7)
372                ret.add_node(8)
373                return ret
374
375        self.assertTrue(nx.is_isomorphic(
376            topology(t()).to_networkx(), g))
377        self.assertEqual(
378            list(topology(t()).to_networkx().nodes), [0, 1, 2, 3, 4])
379
380        # Nodes attributes stripped away.
381        class t:
382
383            def get_connections(self, n):
384                return [[], []]
385
386            def push_back(self):
387                pass
388
389            def to_networkx(self):
390                ret = nx.DiGraph()
391                ret.add_weighted_edges_from([(0, 1, .5), (1, 2, 1.)])
392                ret.add_node(7, size=10)
393                ret.add_node(8, weight=20)
394                return ret
395
396        tmp = topology(t()).to_networkx()
397        self.assertTrue(nx.is_isomorphic(tmp, g))
398        self.assertEqual(list(tmp.nodes), [0, 1, 2, 3, 4])
399        self.assertEqual(tmp[3], {})
400        self.assertEqual(tmp[4], {})
401
402        # Edge attributes other than weight stripped away.
403        class t:
404
405            def get_connections(self, n):
406                return [[], []]
407
408            def push_back(self):
409                pass
410
411            def to_networkx(self):
412                ret = nx.DiGraph()
413                ret.add_edge(0, 1, size=56, weight=.5)
414                ret.add_edge(1, 2, color='blue', weight=1)
415                ret.add_node(7, size=10)
416                ret.add_node(8, weight=20)
417                return ret
418
419        tmp = topology(t()).to_networkx()
420        self.assertTrue(nx.is_isomorphic(tmp, g))
421        self.assertEqual(list(tmp.nodes), [0, 1, 2, 3, 4])
422        self.assertEqual(tmp[3], {})
423        self.assertEqual(tmp[4], {})
424        self.assertEqual(tmp.edges[0, 1], {'weight': .5})
425        self.assertEqual(tmp.edges[1, 2], {'weight': 1.})
426
427        # Error handling.
428        # No method.
429        class t:
430
431            def get_connections(self, n):
432                return [[], []]
433
434            def push_back(self):
435                pass
436
437        with self.assertRaises(NotImplementedError) as cm:
438            topology(t()).to_networkx()
439        err = cm.exception
440        self.assertTrue(
441            "the to_networkx() conversion method has been invoked in the user-defined Python topology" in str(err))
442
443        # Wrong return type.
444        class t:
445
446            def get_connections(self, n):
447                return [[], []]
448
449            def push_back(self):
450                pass
451
452            def to_networkx(self):
453                return 1
454
455        with self.assertRaises(TypeError) as cm:
456            topology(t()).to_networkx()
457        err = cm.exception
458        self.assertTrue(
459            "in order to construct a pagmo::bgl_graph_t object a NetworX DiGraph is needed, but an" in str(err))
460
461        # Weightless edges.
462        class t:
463
464            def get_connections(self, n):
465                return [[], []]
466
467            def push_back(self):
468                pass
469
470            def to_networkx(self):
471                ret = nx.DiGraph()
472                ret.add_edges_from([(0, 1), (1, 2)])
473                return ret
474
475        with self.assertRaises(ValueError) as cm:
476            topology(t()).to_networkx()
477        err = cm.exception
478        self.assertTrue(
479            "without a 'weight' attribute was encountered" in str(err))
480