1"""
2Tests of the common implementation of the PopulationView class, using the
3pyNN.mock backend.
4
5:copyright: Copyright 2006-2021 by the PyNN team, see AUTHORS.
6:license: CeCILL, see LICENSE for details.
7"""
8
9import unittest
10import numpy as np
11import sys
12from numpy.testing import assert_array_equal, assert_array_almost_equal
13import quantities as pq
14try:
15    from unittest.mock import Mock, patch
16except ImportError:
17    from mock import Mock, patch
18from .mocks import MockRNG
19import pyNN.mock as sim
20from pyNN import random, errors, space
21from pyNN.parameters import Sequence
22
23
def setUp():
    """Module-level set-up hook; intentionally does nothing.

    NOTE(review): presumably kept for nose-style runners that look for a
    module-level ``setUp`` function -- confirm whether it is still needed.
    """
26
27
def tearDown():
    """Module-level tear-down hook; intentionally does nothing.

    NOTE(review): presumably kept for nose-style runners that look for a
    module-level ``tearDown`` function -- confirm whether it is still needed.
    """
30
31
class PopulationViewTest(unittest.TestCase):
    """Tests of the common PopulationView implementation.

    Each method takes the simulator module as its ``sim`` default argument
    (``pyNN.mock`` here), which presumably lets backend-specific subclasses
    reuse these tests by overriding that default.

    Every test method must have a unique name: a repeated ``def`` inside a
    class body silently replaces the earlier one, so the earlier test would
    never run.
    """

    def setUp(self, sim=sim, **extra):
        sim.setup(**extra)

    def tearDown(self, sim=sim):
        sim.end()

    # test create with population parent and mask selector

    def test_create_with_slice_selector(self, sim=sim):
        p = sim.Population(11, sim.IF_cond_exp())
        mask = slice(3, 9, 2)
        pv = sim.PopulationView(parent=p, selector=mask)
        self.assertEqual(pv.parent, p)
        self.assertEqual(pv.size, 3)
        self.assertEqual(pv.mask, mask)
        assert_array_equal(pv.all_cells, np.array(
            [p.all_cells[3], p.all_cells[5], p.all_cells[7]]))
        #assert_array_equal(pv.local_cells, np.array([p.all_cells[3]]))
        #assert_array_equal(pv._mask_local, np.array([1,0,0], dtype=bool))
        self.assertEqual(pv.celltype, p.celltype)
        self.assertEqual(pv.first_id, p.all_cells[3])
        self.assertEqual(pv.last_id, p.all_cells[7])

    def test_create_with_boolean_array_selector(self, sim=sim):
        p = sim.Population(11, sim.IF_cond_exp())
        mask = np.array([0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0], dtype=bool)
        pv = sim.PopulationView(parent=p, selector=mask)
        assert_array_equal(pv.all_cells, np.array(
            [p.all_cells[3], p.all_cells[5], p.all_cells[7]]))
        #assert_array_equal(pv.mask, mask)

    def test_create_with_index_array_selector(self, sim=sim):
        p = sim.Population(11, sim.IF_cond_alpha())
        mask = np.array([3, 5, 7])
        pv = sim.PopulationView(parent=p, selector=mask)
        assert_array_equal(pv.all_cells, np.array(
            [p.all_cells[3], p.all_cells[5], p.all_cells[7]]))
        assert_array_equal(pv.mask, mask)

    # test create with populationview parent and mask selector

    def test_create_from_view_with_slice_selector(self, sim=sim):
        p = sim.Population(11, sim.HH_cond_exp())
        mask1 = slice(0, 9, 1)
        pv1 = sim.PopulationView(parent=p, selector=mask1)
        assert_array_equal(pv1.all_cells, p.all_cells[0:9])
        mask2 = slice(3, 9, 2)
        pv2 = sim.PopulationView(parent=pv1, selector=mask2)
        # or would it be better to resolve the parent chain up to an actual Population?
        self.assertEqual(pv2.parent, pv1)
        assert_array_equal(pv2.all_cells, np.array(
            [p.all_cells[3], p.all_cells[5], p.all_cells[7]]))
        #assert_array_equal(pv2._mask_local, np.array([1,0,0], dtype=bool))

    # test structure property

    def test_structure_property(self, sim=sim):
        p = sim.Population(11, sim.SpikeSourcePoisson())
        mask = slice(3, 9, 2)
        pv = sim.PopulationView(parent=p, selector=mask)
        self.assertEqual(pv.structure, p.structure)

    # test positions property

    def test_get_positions_with_slice_selector(self, sim=sim):
        p = sim.Population(11, sim.IF_curr_exp())
        ppos = np.random.uniform(size=(3, 11))
        # positions are injected directly into the parent's private storage
        p._positions = ppos
        pv = sim.PopulationView(parent=p, selector=slice(3, 9, 2))
        assert_array_equal(pv.positions, np.array([ppos[:, 3], ppos[:, 5], ppos[:, 7]]).T)

    def test_id_to_index(self, sim=sim):
        p = sim.Population(11, sim.IF_curr_alpha())
        pv = p[2, 5, 7, 8]
        self.assertEqual(pv.id_to_index(pv[0]), 0)
        self.assertEqual(pv.id_to_index(pv[3]), 3)
        self.assertEqual(pv.id_to_index(p[2]), 0)
        self.assertEqual(pv.id_to_index(p[8]), 3)

    def test_id_to_index_with_array(self, sim=sim):
        p = sim.Population(121, sim.IF_curr_alpha())
        pv = p[2, 5, 7, 8, 19, 37, 49, 82, 83, 99]
        assert_array_equal(pv.id_to_index(pv.all_cells[3:9:2]), np.arange(3, 9, 2))

    def test_id_to_index_with_invalid_id(self, sim=sim):
        p = sim.Population(11, sim.IF_curr_alpha())
        pv = p[2, 5, 7, 8]
        self.assertRaises(IndexError, pv.id_to_index, p[0])
        self.assertRaises(IndexError, pv.id_to_index, p[9])

    @unittest.skip("currently failing")
    def test_id_to_index_with_invalid_ids(self, sim=sim):
        # Kept as a skipped test (rather than commented out) so that it
        # shows up in test reports until the underlying issue is fixed.
        p = sim.Population(11, sim.IF_curr_alpha())
        pv = p[2, 5, 7, 8]
        self.assertRaises(IndexError, pv.id_to_index, p.all_cells[[2, 5, 6]])

    # def test_id_to_local_index():

    # test structure property

    def test_set_structure(self, sim=sim):
        p = sim.Population(11, sim.IF_cond_exp(), structure=space.Grid2D())
        pv = p[2, 5, 7, 8]
        new_struct = space.Line()
        # the structure of a view is derived from its parent and is read-only
        with self.assertRaises(AttributeError):
            pv.structure = new_struct

    # test positions property

    def test_get_positions(self, sim=sim):
        p = sim.Population(11, sim.IF_cond_exp())
        pos = np.arange(33).reshape(3, 11)
        p.positions = pos
        pv = p[2, 5, 7, 8]
        assert_array_equal(pv.positions, pos[:, [2, 5, 7, 8]])

    def test_position_generator(self, sim=sim):
        p = sim.Population(11, sim.IF_cond_exp())
        pv = p[2, 5, 7, 8]
        assert_array_equal(pv.position_generator(0), p.positions[:, 2])
        assert_array_equal(pv.position_generator(3), p.positions[:, 8])
        assert_array_equal(pv.position_generator(-1), p.positions[:, 8])
        assert_array_equal(pv.position_generator(-4), p.positions[:, 2])
        self.assertRaises(IndexError, pv.position_generator, 4)
        self.assertRaises(IndexError, pv.position_generator, -5)

    def test__getitem__int(self, sim=sim):
        # Should return the correct ID object
        p = sim.Population(12, sim.IF_cond_exp())
        pv = p[1, 5, 6, 8, 11]

        self.assertEqual(pv[0], p[1])
        self.assertEqual(pv[4], p[11])
        self.assertRaises(IndexError, pv.__getitem__, 6)
        self.assertEqual(pv[-1], p[11])

    def test__getitem__slice(self, sim=sim):
        # Should return a PopulationView with the correct parent and value
        # of all_cells
        p = sim.Population(17, sim.HH_cond_exp())
        pv1 = p[1, 5, 6, 8, 11, 12, 15, 16]

        pv2 = pv1[2:6]
        self.assertEqual(pv2.parent, pv1)
        self.assertEqual(pv2.grandparent, p)
        assert_array_equal(pv2.all_cells, pv1.all_cells[[2, 3, 4, 5]])
        assert_array_equal(pv2.all_cells, p.all_cells[[6, 8, 11, 12]])

    def test__getitem__list(self, sim=sim):
        p = sim.Population(23, sim.HH_cond_exp())
        pv1 = p[1, 5, 6, 8, 11, 12, 15, 16, 19, 20]

        pv2 = pv1[list(range(3, 8))]
        self.assertEqual(pv2.parent, pv1)
        assert_array_almost_equal(pv2.all_cells, p.all_cells[[8, 11, 12, 15, 16]])

    def test__getitem__tuple(self, sim=sim):
        p = sim.Population(23, sim.HH_cond_exp())
        pv1 = p[1, 5, 6, 8, 11, 12, 15, 16, 19, 20]

        pv2 = pv1[(3, 5, 7)]
        self.assertEqual(pv2.parent, pv1)
        assert_array_almost_equal(pv2.all_cells, p.all_cells[[8, 12, 16]])

    def test__getitem__invalid(self, sim=sim):
        p = sim.Population(23, sim.IF_curr_alpha())
        pv = p[1, 5, 6, 8, 11, 12, 15, 16, 19, 20]
        self.assertRaises(TypeError, pv.__getitem__, "foo")

    def test__len__(self, sim=sim):
        # len(p) should give the global size (all MPI nodes)
        p = sim.Population(77, sim.IF_cond_exp())
        pv = p[1, 5, 6, 8, 11, 12, 15, 16, 19, 20]
        self.assertEqual(len(pv), pv.size)
        self.assertEqual(len(pv), 10)

    def test_iter(self, sim=sim):
        p = sim.Population(33, sim.IF_curr_exp())
        pv = p[1, 5, 6, 8, 11, 12]
        itr = pv.__iter__()
        assert hasattr(itr, "next") or hasattr(itr, "__next__")
        self.assertEqual(len(list(itr)), 6)

    def test___add__two(self, sim=sim):
        # adding two population views should give an Assembly
        pv1 = sim.Population(6, sim.IF_curr_exp())[2, 3, 5]
        pv2 = sim.Population(17, sim.IF_cond_exp())[4, 2, 16]
        assembly = pv1 + pv2
        self.assertIsInstance(assembly, sim.Assembly)
        self.assertEqual(assembly.populations, [pv1, pv2])

    def test___add__three(self, sim=sim):
        # adding three population views should give an Assembly
        pv1 = sim.Population(6, sim.IF_curr_exp())[0:3]
        pv2 = sim.Population(17, sim.IF_cond_exp())[1, 5, 14]
        pv3 = sim.Population(9, sim.HH_cond_exp())[3:8]
        assembly = pv1 + pv2 + pv3
        self.assertIsInstance(assembly, sim.Assembly)
        self.assertEqual(assembly.populations, [pv1, pv2, pv3])

    def test_nearest(self, sim=sim):
        p = sim.Population(13, sim.IF_cond_exp())
        # cell i is placed at (3i, 3i + 1, 3i + 2)
        p.positions = np.arange(39).reshape((13, 3)).T
        pv = p[0, 2, 5, 11]
        self.assertEqual(pv.nearest((0.0, 1.0, 2.0)), pv[0])
        # (3, 4, 5) is equidistant from pv[0] and pv[1]; the lower index is expected
        self.assertEqual(pv.nearest((3.0, 4.0, 5.0)), pv[0])
        self.assertEqual(pv.nearest((36.0, 37.0, 38.0)), pv[3])
        self.assertEqual(pv.nearest((1.49, 2.49, 3.49)), pv[0])
        self.assertEqual(pv.nearest((1.51, 2.51, 3.51)), pv[0])

    def test_sample(self, sim=sim):
        p = sim.Population(13, sim.IF_cond_exp())
        pv1 = p[0, 3, 7, 10, 12]

        rng = Mock()
        rng.permutation = Mock(return_value=np.array([3, 1, 0, 2, 4]))
        pv2 = pv1.sample(3, rng=rng)
        # indices 3, 1, 0 of pv1 correspond to parent cells 10, 3, 0
        assert_array_equal(pv2.all_cells,
                           sorted(p.all_cells[[10, 3, 0]]))

    def test_get_multiple_homogeneous_params_with_gather(self, sim=sim):
        p = sim.Population(10, sim.IF_cond_exp, {
                           'tau_m': 12.3, 'tau_syn_E': 0.987, 'tau_syn_I': 0.7})
        pv = p[3:7]
        tau_syn_E, tau_m = pv.get(('tau_syn_E', 'tau_m'), gather=True)
        self.assertEqual(tau_syn_E, 0.987)
        self.assertAlmostEqual(tau_m, 12.3)

    def test_get_single_homogeneous_param_with_gather(self, sim=sim):
        p = sim.Population(4, sim.IF_cond_exp, {
                           'tau_m': 12.3, 'tau_syn_E': 0.987, 'tau_syn_I': 0.7})
        pv = p[:]
        tau_syn_E = pv.get('tau_syn_E', gather=True)
        self.assertEqual(tau_syn_E, 0.987)

    def test_get_multiple_inhomogeneous_params_with_gather(self, sim=sim):
        p = sim.Population(4, sim.IF_cond_exp(tau_m=12.3,
                                              tau_syn_E=[0.987, 0.988, 0.989, 0.990],
                                              tau_syn_I=lambda i: 0.5 + 0.1 * i))
        pv = p[0, 1, 3]
        tau_syn_E, tau_m, tau_syn_I = pv.get(('tau_syn_E', 'tau_m', 'tau_syn_I'), gather=True)
        self.assertIsInstance(tau_m, float)
        self.assertIsInstance(tau_syn_E, np.ndarray)
        assert_array_equal(tau_syn_E, np.array([0.987, 0.988, 0.990]))
        self.assertAlmostEqual(tau_m, 12.3)
        assert_array_almost_equal(tau_syn_I, np.array([0.5, 0.6, 0.8]), decimal=12)

    # def test_get_multiple_params_no_gather(self, sim=sim):

    def test_get_sequence_param(self, sim=sim):
        p = sim.Population(3, sim.SpikeSourceArray,
                           {'spike_times': [Sequence([1, 2, 3, 4]),
                                            Sequence([2, 3, 4, 5]),
                                            Sequence([3, 4, 5, 6])]})
        pv = p[1:]
        spike_times = pv.get('spike_times')
        self.assertEqual(spike_times.size, 2)
        assert_array_equal(spike_times[1], Sequence([3, 4, 5, 6]))

    def test_set(self, sim=sim):
        p = sim.Population(4, sim.IF_cond_exp, {
                           'tau_m': 12.3, 'tau_syn_E': 0.987, 'tau_syn_I': 0.7})
        pv = p[:3]
        rng = MockRNG(start=1.21, delta=0.01, parallel_safe=True)
        pv.set(tau_syn_E=random.RandomDistribution('uniform', (0.8, 1.2), rng=rng), tau_m=9.87)
        tau_m, tau_syn_E, tau_syn_I = p.get(('tau_m', 'tau_syn_E', 'tau_syn_I'), gather=True)
        assert_array_equal(tau_syn_E, np.array([1.21, 1.22, 1.23, 0.987]))
        assert_array_almost_equal(tau_m, np.array([9.87, 9.87, 9.87, 12.3]))
        assert_array_equal(tau_syn_I, 0.7 * np.ones((4,)))

        tau_m, tau_syn_E, tau_syn_I = pv.get(('tau_m', 'tau_syn_E', 'tau_syn_I'), gather=True)
        assert_array_equal(tau_syn_E, np.array([1.21, 1.22, 1.23]))
        assert_array_almost_equal(tau_m, np.array([9.87, 9.87, 9.87]))
        assert_array_equal(tau_syn_I, 0.7 * np.ones((3,)))

    def test_set_invalid_name(self, sim=sim):
        p = sim.Population(9, sim.HH_cond_exp())
        pv = p[3:5]
        self.assertRaises(errors.NonExistentParameterError, pv.set, foo=13.2)

    def test_set_invalid_type(self, sim=sim):
        p = sim.Population(9, sim.IF_cond_exp())
        pv = p[::3]
        self.assertRaises(errors.InvalidParameterValueError, pv.set, tau_m={})
        self.assertRaises(errors.InvalidParameterValueError, pv.set, v_reset='bar')

    def test_set_sequence(self, sim=sim):
        p = sim.Population(5, sim.SpikeSourceArray())
        pv = p[0, 2, 4]
        pv.set(spike_times=[Sequence([1, 2, 3, 4]),
                            Sequence([2, 3, 4, 5]),
                            Sequence([3, 4, 5, 6])])
        spike_times = p.get('spike_times', gather=True)
        self.assertEqual(spike_times.size, 5)
        assert_array_equal(spike_times[1], Sequence([]))
        assert_array_equal(spike_times[2], Sequence([2, 3, 4, 5]))

    def test_set_array(self, sim=sim):
        p = sim.Population(5, sim.IF_cond_exp, {'v_thresh': -54.3})
        pv = p[2:]
        pv.set(v_thresh=-50.0 + np.arange(3))
        assert_array_equal(p.get('v_thresh', gather=True),
                           np.array([-54.3, -54.3, -50.0, -49.0, -48.0]))

    def test_tset(self, sim=sim):
        p = sim.Population(17, sim.IF_cond_alpha())
        pv = p[::4]
        pv.set = Mock()
        tau_m = np.linspace(10.0, 20.0, num=pv.size)
        pv.tset("tau_m", tau_m)
        pv.set.assert_called_with(tau_m=tau_m)

    def test_rset(self, sim=sim):
        p = sim.Population(17, sim.IF_cond_alpha())
        pv = p[::4]
        pv.set = Mock()
        v_rest = random.RandomDistribution('uniform', low=-70.0, high=-60.0)
        pv.rset("v_rest", v_rest)
        pv.set.assert_called_with(v_rest=v_rest)

    # def test_set_with_native_rng():

    # def test_initialize(self, sim=sim):
    #    p = sim.Population(7, sim.EIF_cond_exp_isfa_ista,
    #                       initial_values={'v': -65.4, 'w': 0.0})
    #    pv = p[::2]
    #
    #    v_init = np.linspace(-70.0, -67.0, num=pv.size)
    #    w_init = 0.1
    #    pv.initialize(v=v_init, w=w_init)
    #    assert_array_equal(p.initial_values['v'].evaluate(simplify=True),
    #                       np.array([-70.0, -65.4, -69.0, -65.4, -68.0, -65.4, -67.0]))
    #    assert_array_equal(p.initial_values['w'].evaluate(simplify=True),
    #                       np.array([0.1, 0.0, 0.1, 0.0, 0.1, 0.0, 0.1]))
    #    # should call p.record(('v', 'w')) and check that the recorded data starts with the initial value

    def test_can_record(self, sim=sim):
        pv = sim.Population(17, sim.EIF_cond_exp_isfa_ista())[::2]
        assert pv.can_record('v')
        assert pv.can_record('w')
        assert pv.can_record('gsyn_inh')
        assert pv.can_record('spikes')
        assert not pv.can_record('foo')

    def test_record_with_single_variable(self, sim=sim):
        p = sim.Population(14, sim.EIF_cond_exp_isfa_ista())
        pv = p[0, 4, 6, 13]
        pv.record('v')
        sim.run(12.3)
        data = p.get_data(gather=True).segments[0]
        self.assertEqual(len(data.analogsignals), 1)
        n_values = int(round(12.3 / sim.get_time_step())) + 1
        self.assertEqual(data.analogsignals[0].name, 'v')
        self.assertEqual(data.analogsignals[0].shape, (n_values, pv.size))

    def test_record_with_multiple_variables(self, sim=sim):
        p = sim.Population(4, sim.EIF_cond_exp_isfa_ista())
        pv = p[0, 3]
        pv.record(('v', 'w', 'gsyn_exc'))
        sim.run(10.0)
        data = p.get_data(gather=True).segments[0]
        self.assertEqual(len(data.analogsignals), 3)
        n_values = int(round(10.0 / sim.get_time_step())) + 1
        names = set(arr.name for arr in data.analogsignals)
        self.assertEqual(names, set(('v', 'w', 'gsyn_exc')))
        for arr in data.analogsignals:
            self.assertEqual(arr.shape, (n_values, pv.size))

    def test_record_with_v_spikes(self, sim=sim):
        p = sim.Population(4, sim.EIF_cond_exp_isfa_ista())
        pv = p[0, 3]
        pv.record(('v', 'spikes'))
        sim.run(10.0)
        data = p.get_data(gather=True).segments[0]
        self.assertEqual(len(data.analogsignals), 1)
        n_values = int(round(10.0 / sim.get_time_step())) + 1
        names = set(arr.name for arr in data.analogsignals)
        # spikes are returned as spiketrains, so 'v' is the only analog signal
        self.assertEqual(names, {'v'})
        for arr in data.analogsignals:
            self.assertEqual(arr.shape, (n_values, pv.size))

    def test_record_v(self, sim=sim):
        pv = sim.Population(2, sim.EIF_cond_exp_isfa_ista())[0:1]
        pv.record = Mock()
        pv.record_v("arg1")
        pv.record.assert_called_with('v', "arg1")

    def test_record_gsyn(self, sim=sim):
        pv = sim.Population(2, sim.EIF_cond_exp_isfa_ista())[1:]
        pv.record = Mock()
        pv.record_gsyn("arg1")
        pv.record.assert_called_with(['gsyn_exc', 'gsyn_inh'], "arg1")

    def test_record_invalid_variable(self, sim=sim):
        pv = sim.Population(14, sim.IF_curr_alpha())[::3]
        self.assertRaises(errors.RecordingError,
                          pv.record, ('v', 'gsyn_exc'))  # can't record gsyn_exc from this celltype

    # def test_write_data(self, sim=sim):
    #    self.fail()
    #

    def test_get_data_with_gather_signal_properties(self, sim=sim):
        t1 = 12.3
        t2 = 13.4
        t3 = 14.5
        p = sim.Population(14, sim.EIF_cond_exp_isfa_ista())
        pv = p[::3]
        pv.record('v')
        sim.run(t1)
        # what if we call p.record between two run statements?
        # would be nice to get an AnalogSignal with a non-zero t_start
        # but then need to make sure we get the right initial value
        sim.run(t2)
        sim.reset()
        pv.record('spikes')
        pv.record('w')
        sim.run(t3)
        data = p.get_data(gather=True)
        self.assertEqual(len(data.segments), 2)

        seg0 = data.segments[0]
        self.assertEqual(len(seg0.analogsignals), 1)
        v = seg0.analogsignals[0]
        self.assertEqual(v.name, 'v')
        num_points = int(round((t1 + t2) / sim.get_time_step())) + 1
        self.assertEqual(v.shape, (num_points, pv.size))
        self.assertEqual(v.t_start, 0.0 * pq.ms)
        self.assertEqual(v.units, pq.mV)
        self.assertEqual(v.sampling_period, 0.1 * pq.ms)
        self.assertEqual(len(seg0.spiketrains), 0)

        seg1 = data.segments[1]
        self.assertEqual(len(seg1.analogsignals), 2)
        w = seg1.filter(name='w')[0]
        self.assertEqual(w.name, 'w')
        num_points = int(round(t3 / sim.get_time_step())) + 1
        self.assertEqual(w.shape, (num_points, pv.size))
        # after sim.reset() the new segment should also start at t = 0
        self.assertEqual(w.t_start, 0.0 * pq.ms)
        self.assertEqual(len(seg1.spiketrains), pv.size)

    def test_get_data_with_gather(self, sim=sim):
        t1 = 12.3
        t2 = 13.4
        t3 = 14.5
        p = sim.Population(14, sim.EIF_cond_exp_isfa_ista())
        pv = p[::3]
        pv.record('v')
        sim.run(t1)
        # what if we call p.record between two run statements?
        # would be nice to get an AnalogSignal with a non-zero t_start
        # but then need to make sure we get the right initial value
        sim.run(t2)
        sim.reset()
        pv.record('spikes')
        pv.record('w')
        sim.run(t3)
        data = p.get_data(gather=True)
        self.assertEqual(len(data.segments), 2)

        seg0 = data.segments[0]
        self.assertEqual(len(seg0.analogsignals), 1)
        self.assertEqual(len(seg0.spiketrains), 0)

        seg1 = data.segments[1]
        self.assertEqual(len(seg1.analogsignals), 2)
        self.assertEqual(len(seg1.spiketrains), pv.size)
        assert_array_equal(seg1.spiketrains[2],
                           np.array([p.first_id + 6, p.first_id + 6 + 5]) % t3)

    # def test_get_data_no_gather(self, sim=sim):
    #    self.fail()

    def test_get_spike_counts(self, sim=sim):
        p = sim.Population(5, sim.EIF_cond_exp_isfa_ista())
        pv = p[0, 1, 4]
        pv.record('spikes')
        sim.run(100.0)
        self.assertEqual(p.get_spike_counts(),
                         {p.all_cells[0]: 2,
                          p.all_cells[1]: 2,
                          p.all_cells[4]: 2})

    def test_mean_spike_count(self, sim=sim):
        p = sim.Population(14, sim.EIF_cond_exp_isfa_ista())
        pv = p[2::3]
        pv.record('spikes')
        sim.run(100.0)
        self.assertEqual(p.mean_spike_count(), 2.0)

    # def test_mean_spike_count_on_slave_node():

    def test_inject(self, sim=sim):
        pv = sim.Population(3, sim.IF_curr_alpha())[1, 2]
        cs = Mock()
        pv.inject(cs)
        meth, args, kwargs = cs.method_calls[0]
        self.assertEqual(meth, "inject_into")
        self.assertEqual(args, (pv,))

    def test_inject_into_invalid_celltype(self, sim=sim):
        pv = sim.Population(3, sim.SpikeSourceArray())[:2]
        self.assertRaises(TypeError, pv.inject, Mock())

    # def test_save_positions(self, sim=sim):
    #    self.fail()

    # test describe method

    def test_describe(self, sim=sim):
        pv = sim.Population(11, sim.IF_cond_exp())[::4]
        self.assertIsInstance(pv.describe(), str)
        self.assertIsInstance(pv.describe(template=None), dict)

    def test_index_in_grandparent(self, sim=sim):
        pv1 = sim.Population(11, sim.IF_cond_exp())[0, 1, 3, 4, 6, 7, 9]
        pv2 = pv1[2, 3, 5, 6]
        assert_array_equal(pv1.index_in_grandparent([2, 4, 6]), np.array([3, 6, 9]))
        assert_array_equal(pv2.index_in_grandparent([0, 1, 3]), np.array([3, 4, 9]))

    def test_index_from_parent_index(self, sim=sim):
        parent = sim.Population(20, sim.IF_cond_exp())

        # test with slice mask
        pv1 = parent[2:16:3]
        assert_array_equal(
            pv1.index_from_parent_index(np.array([2, 5, 8, 11, 14])),
            np.array([0, 1, 2, 3, 4])
        )
        self.assertEqual(pv1.index_from_parent_index(11), 3)

        # test with array mask
        pv2 = parent[np.array([1, 2, 3, 5, 8, 13])]
        assert_array_equal(
            pv2.index_from_parent_index(np.array([2, 5, 13])),
            np.array([1, 3, 5])
        )

    def test_save_positions(self, sim=sim):
        p = sim.Population(7, sim.IF_cond_exp())
        p.positions = np.arange(15, 36).reshape((7, 3)).T
        pv = p[2, 4, 5]
        output_file = Mock()
        pv.save_positions(output_file)
        # each row is (index within the view, x, y, z)
        assert_array_equal(output_file.write.call_args[0][0],
                           np.array([[0, 21, 22, 23],
                                     [1, 27, 28, 29],
                                     [2, 30, 31, 32]]))
        self.assertEqual(output_file.write.call_args[0][1], {'population': pv.label})
585
586
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
589