1# -*- coding: utf-8 -*-
2#
3# connect_test_base.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22import numpy as np
23import scipy.stats
24import nest
25import unittest
26
27try:
28    from mpi4py import MPI
29    haveMPI4Py = True
30except ImportError:
31    haveMPI4Py = False
32
33
class ConnectTestBase(unittest.TestCase):
    """
    Base class for connection tests.

    This class provides overall setup methods and a range of tests
    that apply to all connection rules. The tests are not to be
    run from this class, but from derived classes, one per rule.
    Derived classes are expected to set ``rule`` and ``conn_dict``.
    """

    # Default parameters. These may be overridden by the classes testing
    # one specific rule. ``rule`` is set to None here to force subclassing;
    # the other parameters get usable default values.
    rule = None
    conn_dict = {'rule': rule}
    # sizes of populations
    N1 = 6
    N2 = 6
    # time step
    dt = 0.1
    # synapse parameters expected when connecting without a syn_spec
    # (checked in testDefaultParams)
    w0 = 1.0
    d0 = 1.0
    r0 = 0
    syn0 = 'static_synapse'
    # parameter for tests of distributed parameters: minimum p-value to
    # pass the Kolmogorov-Smirnov test
    pval = 0.05
    # number of threads
    nr_threads = 2

    # For now this only tests whether a multi-threaded connect succeeds,
    # not whether the threading is actually used.
    def setUp(self):
        nest.ResetKernel()
        nest.local_num_threads = self.nr_threads

    def setUpNetwork(self, conn_dict=None, syn_dict=None, N1=None, N2=None):
        """
        Create two iaf_psc_alpha populations (self.pop1, self.pop2) and
        connect pop1 -> pop2 with the given connection and synapse specs.
        N1 and N2 default to the class attributes of the same name.
        """
        if N1 is None:
            N1 = self.N1
        if N2 is None:
            N2 = self.N2
        self.pop1 = nest.Create('iaf_psc_alpha', N1)
        self.pop2 = nest.Create('iaf_psc_alpha', N2)
        nest.set_verbosity('M_FATAL')
        nest.Connect(self.pop1, self.pop2, conn_dict, syn_dict)

    def setUpNetworkOnePop(self, conn_dict=None, syn_dict=None, N=None):
        """
        Create one iaf_psc_alpha population (self.pop) of size N (default
        self.N1) and connect it to itself.
        """
        if N is None:
            N = self.N1
        self.pop = nest.Create('iaf_psc_alpha', N)
        nest.set_verbosity('M_FATAL')
        nest.Connect(self.pop, self.pop, conn_dict, syn_dict)

    def _create_dopamine_volume_transmitter(self):
        """
        Create a volume_transmitter and register its global ID as the 'vt'
        default of stdp_dopamine_synapse, as done before every
        stdp_dopamine_synapse connection in this class.
        """
        vol = nest.Create('volume_transmitter')
        nest.SetDefaults('stdp_dopamine_synapse', {'vt': vol.get('global_id')})

    def testWeightSetting(self):
        # test if weights are set correctly: one weight for all connections
        w0 = 0.351
        syn_params = {'weight': w0}
        check_synapse(['weight'], [w0], syn_params, self)

    def testDelaySetting(self):
        # test if delays are set correctly: one delay for all connections
        d0 = 0.275
        syn_params = {'delay': d0}
        self.setUpNetwork(self.conn_dict, syn_params)
        connections = nest.GetConnections(self.pop1, self.pop2)
        nest_delays = connections.get('delay')
        # all delays need to be equal
        self.assertTrue(all_equal(nest_delays))
        # the delay is rounded to the resolution, so it must lie within one
        # time step of the requested delay
        self.assertLess(abs(d0 - nest_delays[0]), self.dt)

    def testRPortSetting(self):
        neuron_model = 'iaf_psc_exp_multisynapse'
        neuron_dict = {'tau_syn': [0.5, 0.7]}
        rtype = 2
        self.pop1 = nest.Create(neuron_model, self.N1, neuron_dict)
        self.pop2 = nest.Create(neuron_model, self.N2, neuron_dict)
        syn_params = {'synapse_model': 'static_synapse',
                      'receptor_type': rtype}
        nest.Connect(self.pop1, self.pop2, self.conn_dict, syn_params)
        conns = nest.GetConnections(self.pop1, self.pop2)
        ports = conns.get('receptor')
        self.assertTrue(all_equal(ports))
        self.assertEqual(ports[0], rtype)

    def testSynapseSetting(self):
        nest.CopyModel("static_synapse", 'test_syn', {'receptor_type': 0})
        syn_params = {'synapse_model': 'test_syn'}
        self.setUpNetwork(self.conn_dict, syn_params)
        conns = nest.GetConnections(self.pop1, self.pop2)
        syns = conns.get('synapse_model')
        self.assertTrue(all_equal(syns))
        self.assertEqual(syns[0], syn_params['synapse_model'])

    # tested on each MPI process separately
    def testDefaultParams(self):
        self.setUpNetwork(self.conn_dict)
        conns = nest.GetConnections(self.pop1, self.pop2)
        self.assertTrue(all(x == self.w0 for x in conns.get('weight')))
        self.assertTrue(all(x == self.d0 for x in conns.get('delay')))
        self.assertTrue(all(x == self.r0 for x in conns.get('receptor')))
        self.assertTrue(all(x == self.syn0 for
                            x in conns.get('synapse_model')))

    def testAutapsesTrue(self):
        conn_params = self.conn_dict.copy()

        # test that autapses exist
        conn_params['allow_autapses'] = True
        self.pop1 = nest.Create('iaf_psc_alpha', self.N1)
        nest.Connect(self.pop1, self.pop1, conn_params)
        # make sure every neuron has exactly one connection to itself
        M = get_connectivity_matrix(self.pop1, self.pop1)
        mpi_assert(np.diag(M), np.ones(self.N1), self)

    def testAutapsesFalse(self):
        conn_params = self.conn_dict.copy()

        # test that autapses were excluded
        conn_params['allow_autapses'] = False
        self.pop1 = nest.Create('iaf_psc_alpha', self.N1)
        nest.Connect(self.pop1, self.pop1, conn_params)
        # make sure no neuron is connected to itself
        M = get_connectivity_matrix(self.pop1, self.pop1)
        mpi_assert(np.diag(M), np.zeros(self.N1), self)

    def testHtSynapse(self):
        params = ['P', 'delta_P']
        values = [0.987, 0.362]
        syn_params = {'synapse_model': 'ht_synapse'}
        check_synapse(params, values, syn_params, self)

    def testQuantalStpSynapse(self):
        params = ['U', 'tau_fac', 'tau_rec', 'u', 'a', 'n']
        values = [0.679, 8.45, 746.2, 0.498, 10, 5]
        syn_params = {'synapse_model': 'quantal_stp_synapse'}
        check_synapse(params, values, syn_params, self)

    def testStdpFacetshwSynapseHom(self):
        params = ['a_acausal', 'a_causal', 'a_thresh_th', 'a_thresh_tl',
                  'next_readout_time'
                  ]
        values = [0.162, 0.263, 20.46, 19.83, 0.1]
        syn_params = {'synapse_model': 'stdp_facetshw_synapse_hom'}
        check_synapse(params, values, syn_params, self)

    def testStdpPlSynapseHom(self):
        params = ['Kplus']
        values = [0.173]
        syn_params = {'synapse_model': 'stdp_pl_synapse_hom'}
        check_synapse(params, values, syn_params, self)

    def testStdpSynapseHom(self):
        params = ['Kplus']
        values = [0.382]
        syn_params = {'synapse_model': 'stdp_synapse_hom'}
        check_synapse(params, values, syn_params, self)

    def testStdpSynapse(self):
        params = ['Wmax', 'alpha', 'lambda', 'mu_minus', 'mu_plus', 'tau_plus']
        values = [98.34, 0.945, 0.02, 0.945, 1.26, 19.73]
        syn_params = {'synapse_model': 'stdp_synapse'}
        check_synapse(params, values, syn_params, self)

    def testTsodyks2Synapse(self):
        params = ['U', 'tau_fac', 'tau_rec', 'u', 'x']
        values = [0.362, 0.152, 789.2, 0.683, 0.945]
        syn_params = {'synapse_model': 'tsodyks2_synapse'}
        check_synapse(params, values, syn_params, self)

    def testTsodyksSynapse(self):
        params = ['U', 'tau_fac', 'tau_psc', 'tau_rec', 'x', 'y', 'u']
        values = [0.452, 0.263, 2.56, 801.34, 0.567, 0.376, 0.102]
        syn_params = {'synapse_model': 'tsodyks_synapse'}
        check_synapse(params, values, syn_params, self)

    def testStdpDopamineSynapse(self):
        # ResetKernel() so this test does not run with the nr_threads set in
        # setUp: setting parameters of this synapse type is not thread safe.
        nest.ResetKernel()
        self._create_dopamine_volume_transmitter()
        params = ['c', 'n']
        values = [0.153, 0.365]
        syn_params = {'synapse_model': 'stdp_dopamine_synapse'}
        check_synapse(params, values, syn_params, self)

    def testRPortAllSynapses(self):
        syns = ['cont_delay_synapse', 'ht_synapse', 'quantal_stp_synapse',
                'static_synapse_hom_w', 'stdp_dopamine_synapse',
                'stdp_facetshw_synapse_hom', 'stdp_pl_synapse_hom',
                'stdp_synapse_hom', 'stdp_synapse', 'tsodyks2_synapse',
                'tsodyks_synapse'
                ]
        syn_params = {'receptor_type': 1}

        for syn in syns:
            if syn == 'stdp_dopamine_synapse':
                self._create_dopamine_volume_transmitter()
            syn_params['synapse_model'] = syn
            self.pop1 = nest.Create('iaf_psc_exp_multisynapse', self.N1, {
                                       'tau_syn': [0.2, 0.5]})
            self.pop2 = nest.Create('iaf_psc_exp_multisynapse', self.N2, {
                                       'tau_syn': [0.2, 0.5]})
            nest.Connect(self.pop1, self.pop2, self.conn_dict, syn_params)
            conns = nest.GetConnections(self.pop1, self.pop2)
            conn_params = conns.get('receptor')
            self.assertTrue(all_equal(conn_params))
            self.assertEqual(conn_params[0], syn_params['receptor_type'])
            # fresh kernel for the next synapse model
            self.setUp()

    def testWeightAllSynapses(self):
        # test all synapses apart from static_synapse_hom_w where weight is not
        # settable
        syns = ['cont_delay_synapse', 'ht_synapse', 'quantal_stp_synapse',
                'stdp_dopamine_synapse',
                'stdp_facetshw_synapse_hom',
                'stdp_pl_synapse_hom',
                'stdp_synapse_hom', 'stdp_synapse', 'tsodyks2_synapse',
                'tsodyks_synapse'
                ]
        syn_params = {'weight': 0.372}

        for syn in syns:
            if syn == 'stdp_dopamine_synapse':
                self._create_dopamine_volume_transmitter()
            syn_params['synapse_model'] = syn
            check_synapse(
                ['weight'], [syn_params['weight']], syn_params, self)
            # fresh kernel for the next synapse model
            self.setUp()

    def testDelayAllSynapses(self):
        syns = ['cont_delay_synapse',
                'ht_synapse', 'quantal_stp_synapse',
                'static_synapse_hom_w',
                'stdp_dopamine_synapse',
                'stdp_facetshw_synapse_hom', 'stdp_pl_synapse_hom',
                'stdp_synapse_hom', 'stdp_synapse', 'tsodyks2_synapse',
                'tsodyks_synapse'
                ]
        syn_params = {'delay': 0.4}

        for syn in syns:
            if syn == 'stdp_dopamine_synapse':
                self._create_dopamine_volume_transmitter()
            syn_params['synapse_model'] = syn
            check_synapse(
                ['delay'], [syn_params['delay']], syn_params, self)
            # fresh kernel for the next synapse model
            self.setUp()
291
292
def gather_data(data_array):
    '''
    Collect data_array from every MPI rank onto rank 0.

    If data_array is a list, the per-rank lists are concatenated into one
    flat list; any other type (typically a numpy array) is combined with
    the built-in sum. Without mpi4py the input is returned unchanged.

    Returns the gathered data on rank 0 and None on all other ranks.
    '''
    if not haveMPI4Py:
        return data_array

    comm = MPI.COMM_WORLD
    per_rank = comm.gather(data_array, root=0)
    if comm.Get_rank() != 0:
        return None

    if not isinstance(data_array, list):
        return sum(per_rank)

    flat = []
    for chunk in per_rank:
        flat.extend(chunk)
    return flat
314
315
def bcast_data(data):
    """
    Share ``data`` from MPI rank 0 with all ranks and return it.

    Without mpi4py the argument is returned as is.
    """
    if not haveMPI4Py:
        return data
    return MPI.COMM_WORLD.bcast(data, root=0)
323
324
def is_array(data):
    '''
    Return True if data is a list, a numpy array or a numpy scalar
    (np.generic), and False otherwise (tuples and strings included).
    '''
    array_like_types = (list, np.ndarray, np.generic)
    return isinstance(data, array_like_types)
330
331
def mpi_barrier():
    """Block until all MPI ranks reach this point; no-op without mpi4py."""
    if not haveMPI4Py:
        return
    MPI.COMM_WORLD.Barrier()
335
336
def mpi_assert(data_original, data_test, TestCase):
    '''
    Gather data_original from all MPI ranks and compare it with data_test.

    The check only runs where gather_data returns a value (rank 0 under
    MPI, the single process otherwise). If either operand is a numpy
    array or numpy scalar, the operands are compared with np.allclose
    (numpy's default tolerances, with broadcasting); otherwise
    TestCase.assertEqual is used.

    Parameters
    ----------
        data_original: local data of this rank, combined across ranks by
                       gather_data (lists concatenated, arrays summed).
        data_test: reference value to compare against.
        TestCase: unittest.TestCase instance used for the assertion.
    '''
    data_original = gather_data(data_original)
    # gather_data returns None on every rank except 0; only rank 0 checks.
    if data_original is None:
        return

    numpy_types = (np.ndarray, np.generic)
    if (isinstance(data_original, numpy_types)
            or isinstance(data_test, numpy_types)):
        # Comparing an array with == yields an element-wise array whose truth
        # value is ambiguous, so any numpy operand goes through np.allclose.
        TestCase.assertTrue(np.allclose(data_original, data_test))
    else:
        TestCase.assertEqual(data_original, data_test)
350
351
def all_equal(x):
    '''
    Test whether all elements of a sequence are equal.

    Works for lists, tuples and numpy arrays. An empty sequence counts as
    all-equal. Elements identical to the first one count as equal (the
    same shortcut list.count uses).

    Returns True or False.
    '''
    items = list(x)
    if not items:
        return True
    first = items[0]
    return all(item is first or item == first for item in items)
358
359
def get_connectivity_matrix(pop1, pop2):
    '''
    Return a connectivity matrix describing all connections from pop1 to
    pop2.

    M[i, j] is the number of connections from the j-th node of pop1 to the
    i-th node of pop2 (multapses are counted individually), so M has shape
    (len(pop2), len(pop1)). Only connections returned by
    nest.GetConnections(pop1, pop2) are counted; under MPI this is
    presumably the local part only.
    '''
    M = np.zeros((len(pop2), len(pop1)))
    connections = nest.GetConnections(pop1, pop2)
    # Separate index maps: with a single shared map, pop2 would overwrite
    # the positions of nodes that also occur (at another position) in pop1.
    source_index = {node.get('global_id'): i for i, node in enumerate(pop1)}
    target_index = {node.get('global_id'): i for i, node in enumerate(pop2)}
    for source, target in zip(connections.sources(), connections.targets()):
        M[target_index[target], source_index[source]] += 1
    return M
377
378
def get_weighted_connectivity_matrix(pop1, pop2, label):
    '''
    Return a matrix of the synapse parameter `label` for all connections
    from pop1 to pop2.

    M[i, j] holds the value of `label` (e.g. 'weight') of the connection
    from the j-th node of pop1 to the i-th node of pop2, and 0 where there
    is no connection; M has shape (len(pop2), len(pop1)). With multapses
    the values of all connections between the same pair are summed, so
    the result is only meaningful without multapses.
    '''
    M = np.zeros((len(pop2), len(pop1)))
    connections = nest.GetConnections(pop1, pop2)
    # np.atleast_1d: SynapseCollection.get may return a bare value instead
    # of a sequence when there is exactly one connection.
    sources = np.atleast_1d(connections.get('source'))
    targets = np.atleast_1d(connections.get('target'))
    values = np.atleast_1d(connections.get(label))
    # Separate index maps: with a single shared map, pop2 would overwrite
    # the positions of nodes that also occur (at another position) in pop1.
    source_index = {node.get('global_id'): i for i, node in enumerate(pop1)}
    target_index = {node.get('global_id'): i for i, node in enumerate(pop2)}
    for source_id, target_id, value in zip(sources, targets, values):
        M[target_index[target_id], source_index[source_id]] += value
    return M
401
402
def check_synapse(params, values, syn_params, TestCase):
    '''
    Connect a network with the given synapse parameters and assert that
    every resulting connection carries them.

    Builds a synapse specification from syn_params plus
    params[i] = values[i] and calls TestCase.setUpNetwork with
    TestCase.conn_dict and that specification. Then, for each parameter,
    asserts that all connections from TestCase.pop1 to TestCase.pop2
    share one value equal to the requested one.

    Parameters
    ----------
        params: names of the synapse parameters to set and check.
        values: requested values, in the same order as params.
        syn_params: base synapse specification (e.g. with 'synapse_model');
                    it is not modified.
        TestCase: test case providing setUpNetwork, conn_dict, pop1, pop2
                  and the unittest assertion methods.

    Raises
    ------
        ValueError: if params and values differ in length.
    '''
    if len(params) != len(values):
        raise ValueError('params and values must have the same length')

    # Work on a copy so the caller's dictionary is left untouched.
    syn_spec = dict(syn_params)
    syn_spec.update(zip(params, values))
    TestCase.setUpNetwork(TestCase.conn_dict, syn_spec)

    # The connections do not change while checking, so fetch them once.
    conns = nest.GetConnections(TestCase.pop1, TestCase.pop2)
    for param, value in zip(params, values):
        conn_params = conns.get(param)
        # get() may return a bare value when there is a single connection.
        if np.ndim(conn_params) == 0:
            conn_params = [conn_params]
        TestCase.assertTrue(all_equal(conn_params))
        TestCase.assertEqual(conn_params[0], value)
412
# copied from the master's thesis of Daniel Hjertholm
414
415
def counter(x, fan, source_pop, target_pop):
    '''
    Count how often each pool node ID occurs in x.

    Parameters
    ----------
        x: sequence of integer node IDs (e.g. the sources or targets of
           a set of connections).
        fan: 'in' if the pool is source_pop, any other value if the pool
             is target_pop.
        source_pop, target_pop: populations; only their lengths are used.

    Return values
    -------------
        list of length len(pool) whose entry k is the number of
        occurrences of ID min(x) + k in x. An empty x gives all zeros.

    NOTE(review): IDs are offset by min(x), so this assumes the smallest
    pool ID actually occurs in x; otherwise the counts are shifted.
    '''
    N_p = len(source_pop) if fan == 'in' else len(target_pop)  # pool size
    counts = [0] * N_p
    ids = list(x)
    if not ids:
        return counts
    start = min(ids)
    for elem in ids:
        counts[elem - start] += 1
    return counts
436
437
def get_degrees(fan, pop1, pop2):
    '''
    Return the in- or out-degrees derived from the connectivity matrix of
    pop1 -> pop2 (rows = nodes of pop2, columns = nodes of pop1).

    Parameters
    ----------
        fan: 'in' for the in-degree of every node in pop2 (row sums),
             'out' for the out-degree of every node in pop1 (column sums).
        pop1: source population.
        pop2: target population.

    Raises
    ------
        ValueError: if fan is neither 'in' nor 'out'.
    '''
    if fan not in ('in', 'out'):
        raise ValueError("fan must be 'in' or 'out', got {!r}".format(fan))
    M = get_connectivity_matrix(pop1, pop2)
    return np.sum(M, axis=1 if fan == 'in' else 0)
445
# adapted from the master's thesis of Daniel Hjertholm
447
448
def get_expected_degrees_fixedDegrees(N, fan, len_source_pop, len_target_pop):
    '''
    Return the expected degree of every pool node, N_d * N / N_p.

    For fan == 'in' the driver nodes are the targets and the pool nodes
    the sources; for any other fan the roles are swapped. Presumably N is
    the fixed degree of each driver node, so the N_d * N connections are
    spread evenly over the N_p pool nodes.

    Returns a list with one (identical) float per pool node.
    '''
    if fan == 'in':
        n_driver, n_pool = len_target_pop, len_source_pop
    else:
        n_driver, n_pool = len_source_pop, len_target_pop
    return [n_driver * N / float(n_pool)] * n_pool
455
# adapted from the master's thesis of Daniel Hjertholm
457
458
def get_expected_degrees_totalNumber(N, fan, len_source_pop, len_target_pop):
    '''
    Return the expected degree of every node when N connections are spread
    evenly over a population.

    Parameters
    ----------
        N: total number of connections.
        fan: 'in' for the in-degrees of the target population,
             'out' for the out-degrees of the source population.
        len_source_pop, len_target_pop: population sizes.

    Return values
    -------------
        list with the value N / size for each node of the chosen population.

    Raises
    ------
        ValueError: if fan is neither 'in' nor 'out'.
    '''
    if fan == 'in':
        n_nodes = len_target_pop
    elif fan == 'out':
        n_nodes = len_source_pop
    else:
        raise ValueError("fan must be 'in' or 'out', got {!r}".format(fan))
    return [N / float(n_nodes)] * n_nodes
466
# copied from the master's thesis of Daniel Hjertholm
468
469
def get_expected_degrees_bernoulli(p, fan, len_source_pop, len_target_pop):
    '''
    Calculate the expected degree distribution for connections made
    independently with probability p.

    Degree k (0 <= k <= n) is expected to be observed
    scipy.stats.binom.pmf(k, n, p) * n_p times. Here n is the size of the
    pool population (len_source_pop for fan 'in', len_target_pop
    otherwise), and n_p is the number of nodes whose degree is measured
    (the other population).

    Degrees with expected number of observations below e_min (= 5) are
    combined into larger bins. Binning proceeds from degree 0 upwards to
    mid - 1 and from degree n downwards to mid, where mid = round(n * p).
    A remainder that stays below e_min at the end of either side is merged
    into the last bin closed on that side.

    Parameters
    ----------
        p: connection probability.
        fan: 'in' or 'out'. Other values are not supported: the binning
             would treat them like 'out' but the final sanity check like
             'in'.
        len_source_pop, len_target_pop: population sizes.

    Return values
    -------------
        2D numpy array with one row per bin, ordered by increasing degree.
        The four columns contain the first degree in the bin, the expected
        number of observations in the bin, the actual number of
        observations (initialized to 0 here), and the number of degrees
        combined into the bin.

    Raises
    ------
        RuntimeWarning('Not enough data'): if all degrees on one side of
            mid together have fewer than e_min expected observations, so
            there is no bin to merge the remainder into.
        AssertionError: if the bin widths do not add up to n + 1.
    '''

    n = len_source_pop if fan == 'in' else len_target_pop
    n_p = len_target_pop if fan == 'in' else len_source_pop
    mid = int(round(n * p))
    e_min = 5

    # Combine from front: degrees 0 .. mid - 1, in increasing order.
    data_front = []
    cumexp = 0.0
    bins_combined = 0
    for degree in range(mid):
        cumexp += scipy.stats.binom.pmf(degree, n, p) * n_p
        bins_combined += 1
        if cumexp < e_min:
            if degree == mid - 1:
                # Leftover at the end of the front side: merge it into the
                # previous front bin, keeping that bin's first degree.
                if len(data_front) == 0:
                    raise RuntimeWarning('Not enough data')
                deg, exp, obs, num = data_front[-1]
                data_front[-1] = (deg, exp + cumexp, obs,
                                  num + bins_combined)
            else:
                continue
        else:
            # Close a bin spanning degrees
            # degree - bins_combined + 1 .. degree.
            data_front.append((degree - bins_combined + 1, cumexp, 0,
                               bins_combined))
            cumexp = 0.0
            bins_combined = 0

    # Combine from back: degrees n .. mid, in decreasing order, so the
    # current degree is always the lowest degree of the open bin.
    data_back = []
    cumexp = 0.0
    bins_combined = 0
    for degree in reversed(range(mid, n + 1)):
        cumexp += scipy.stats.binom.pmf(degree, n, p) * n_p
        bins_combined += 1
        if cumexp < e_min:
            if degree == mid:
                # Leftover at mid: merge it into the previous back bin,
                # which now starts at mid.
                if len(data_back) == 0:
                    raise RuntimeWarning('Not enough data')
                deg, exp, obs, num = data_back[-1]
                data_back[-1] = (degree, exp + cumexp, obs,
                                 num + bins_combined)
            else:
                continue
        else:
            data_back.append((degree, cumexp, 0, bins_combined))
            cumexp = 0.0
            bins_combined = 0
    # restore increasing-degree order for the back bins
    data_back.reverse()

    expected = np.array(data_front + data_back)
    # Sanity check: the bins must cover every degree 0 .. n exactly once.
    if fan == 'out':
        assert (sum(expected[:, 3]) == len_target_pop + 1)
    else:  # fan == 'in'
        assert (sum(expected[:, 3]) == len_source_pop + 1)

    return expected
541
# adapted from the master's thesis of Daniel Hjertholm
543
544
def reset_seed(seed, nr_threads):
    '''
    Reset the NEST kernel, then configure threading and seeding.

    Parameters
    ----------
        seed: seed value assigned to nest.rng_seed.
        nr_threads: number of local threads assigned to
                    nest.local_num_threads.
    '''
    nest.ResetKernel()
    # The thread count is set before the seed.
    nest.local_num_threads = nr_threads
    nest.rng_seed = seed
557
# copied from the master's thesis of Daniel Hjertholm
559
560
def chi_squared_check(degrees, expected, distribution=None):
    '''
    Compare an observed degree distribution with the expected one using
    Pearson's chi-squared goodness-of-fit test (scipy.stats.chisquare).

    Parameters
    ----------
        degrees     : observed degree of every node.
        expected    : for distribution 'pairwise_bernoulli' or
                      'symmetric_pairwise_bernoulli', a binned 2D array as
                      returned by get_expected_degrees_bernoulli (columns:
                      first degree, expected count, observed count, number
                      of degrees in the bin). Its observed-count column is
                      overwritten in place. For any other distribution, the
                      expected value for every entry of degrees, in the
                      same order.
        distribution: name of the connection rule; selects the mode above.

    Return values
    -------------
        result of scipy.stats.chisquare: chi-squared statistic and p-value
        (default degrees of freedom, k - 1).
    '''
    if distribution in ('pairwise_bernoulli', 'symmetric_pairwise_bernoulli'):
        from collections import Counter

        # Histogram of the observed degrees. Integer and float keys compare
        # equal, so float degrees (e.g. from np.sum) match integer bins.
        observed = Counter(degrees)
        # Fill the observed-count column, summing all degrees in each bin.
        expected[:, 2] = 0.0
        for row in expected:
            first_degree = int(row[0])
            for offset in range(int(row[3])):
                row[2] += observed.get(first_degree + offset, 0)
        return scipy.stats.chisquare(np.array(expected[:, 2]),
                                     np.array(expected[:, 1]))

    # Unbinned comparison: one expected value per observed degree.
    return scipy.stats.chisquare(np.array(degrees), np.array(expected))
600