1# -*- coding: utf-8 -*-
2#
3# test_connect_all_to_all.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22
23import unittest
24import numpy as np
25import scipy.stats
26import connect_test_base
27import nest
28
29
@nest.ll_api.check_stack
class TestAllToAll(connect_test_base.ConnectTestBase):
    """Tests for the ``all_to_all`` connection rule.

    Relies on helpers inherited from ``connect_test_base.ConnectTestBase``
    (``setUpNetwork``, ``setUpNetworkOnePop``, ``pval``), which are defined
    outside this file.
    """

    # specify connection pattern
    rule = 'all_to_all'
    # Class attribute shared by every test instance: never mutate it in
    # place, copy it first (see testInputArrayWithoutAutapses).
    conn_dict = {'rule': rule}
    # sizes of populations
    N1 = 6
    N2 = 7
    # larger population sizes used by testInputArray
    N1_array = 500
    N2_array = 10

    def testConnectivity(self):
        """Every source connects to every target and nothing connects back."""
        self.setUpNetwork(self.conn_dict)
        # make sure all connections do exist
        M = connect_test_base.get_connectivity_matrix(self.pop1, self.pop2)
        M_all = np.ones((len(self.pop2), len(self.pop1)))
        connect_test_base.mpi_assert(M, M_all, self)
        # make sure no connections were drawn from the target to the source
        # population
        M = connect_test_base.get_connectivity_matrix(self.pop2, self.pop1)
        M_none = np.zeros((len(self.pop1), len(self.pop2)))
        connect_test_base.mpi_assert(M, M_none, self)

    def testInputArray(self):
        """Per-connection weight and delay arrays are applied element-wise.

        The expected arrays are shaped (N2_array, N1_array), i.e. one row per
        neuron of ``pop2`` and one column per neuron of ``pop1``, the same
        orientation as the matrices compared in testConnectivity.
        """
        for label in ['weight', 'delay']:
            syn_params = {}
            if label == 'weight':
                self.param_array = np.arange(
                    self.N1_array * self.N2_array, dtype=float
                ).reshape(self.N2_array, self.N1_array)
            elif label == 'delay':
                # Starts at 1 (-> 0.1) rather than 0, presumably because a
                # zero delay would be rejected.
                self.param_array = np.arange(
                    1, self.N1_array * self.N2_array + 1
                ).reshape(self.N2_array, self.N1_array) * 0.1
            syn_params[label] = self.param_array
            nest.ResetKernel()
            self.setUpNetwork(self.conn_dict, syn_params,
                              N1=self.N1_array, N2=self.N2_array)
            M_nest = connect_test_base.get_weighted_connectivity_matrix(
                self.pop1, self.pop2, label)
            connect_test_base.mpi_assert(M_nest, self.param_array, self)

    def testInputArrayWithoutAutapses(self):
        """Array parameters with autapses disallowed leave the diagonal empty."""
        # Work on a copy: assigning into self.conn_dict would mutate the
        # class-level dict and leak allow_autapses=False into every test
        # that runs afterwards.
        conn_dict = dict(self.conn_dict, allow_autapses=False)
        for label in ['weight', 'delay']:
            syn_params = {}
            if label == 'weight':
                self.param_array = np.arange(
                    self.N1 * self.N1, dtype=float).reshape(self.N1, self.N1)
            elif label == 'delay':
                self.param_array = np.arange(
                    1, self.N1 * self.N1 + 1).reshape(self.N1, self.N1) * 0.1
            syn_params[label] = self.param_array
            self.setUpNetworkOnePop(conn_dict, syn_params)
            M_nest = connect_test_base.get_weighted_connectivity_matrix(
                self.pop, self.pop, label)
            # No self-connections are created, so the expected diagonal is 0
            # (absent connections presumably read as 0 in the matrix).
            np.fill_diagonal(self.param_array, 0)
            connect_test_base.mpi_assert(M_nest, self.param_array, self)

    def testInputArrayRPort(self):
        """An array of receptor types is applied per connection."""
        syn_params = {}
        neuron_model = 'iaf_psc_exp_multisynapse'
        # one synaptic time constant per receptor port 1..N2
        neuron_dict = {'tau_syn': [0.1 + i for i in range(self.N2)]}
        self.pop1 = nest.Create(neuron_model, self.N1)
        self.pop2 = nest.Create(neuron_model, self.N2, neuron_dict)
        # (N2, N1) array whose row k holds receptor port k + 1
        self.param_array = np.transpose(np.asarray(
            [np.arange(1, self.N2 + 1) for i in range(self.N1)]))
        syn_params['receptor_type'] = self.param_array
        nest.Connect(self.pop1, self.pop2, self.conn_dict, syn_params)
        M = connect_test_base.get_weighted_connectivity_matrix(
            self.pop1, self.pop2, 'receptor')
        connect_test_base.mpi_assert(M, self.param_array, self)

    def testInputArrayToStdpSynapse(self):
        """Array values for stdp_synapse parameters are applied element-wise."""
        params = ['Wmax', 'alpha', 'lambda', 'mu_minus', 'mu_plus', 'tau_plus']
        syn_params = {'synapse_model': 'stdp_synapse'}
        # One (N2, N1) array per parameter, keyed by name so the number of
        # arrays always matches the parameter list.
        values = {
            param: np.arange(self.N1 * self.N2, dtype=float).reshape(
                self.N2, self.N1)
            for param in params
        }
        syn_params.update(values)
        self.setUpNetwork(self.conn_dict, syn_params)
        for param, expected in values.items():
            a = connect_test_base.get_weighted_connectivity_matrix(
                self.pop1, self.pop2, param)
            connect_test_base.mpi_assert(a, expected, self)

    # test single threaded for now
    def testRPortDistribution(self):
        """Randomly drawn receptor ports cover 1..n_rport uniformly."""
        n_rport = 10
        nr_neurons = 100
        nest.ResetKernel()  # To reset local_num_threads
        neuron_model = 'iaf_psc_exp_multisynapse'
        neuron_dict = {'tau_syn': [0.1 + i for i in range(n_rport)]}
        self.pop1 = nest.Create(neuron_model, nr_neurons, neuron_dict)
        self.pop2 = nest.Create(neuron_model, nr_neurons, neuron_dict)
        syn_params = {'synapse_model': 'static_synapse'}
        # uniform_int(n) draws from 0..n-1; shift to ports 1..n_rport
        syn_params['receptor_type'] = 1 + nest.random.uniform_int(n_rport)
        nest.Connect(self.pop1, self.pop2, self.conn_dict, syn_params)
        M = connect_test_base.get_weighted_connectivity_matrix(
            self.pop1, self.pop2, 'receptor')
        M = connect_test_base.gather_data(M)
        # gather_data may return None; only check when data is available
        # (presumably on the gathering MPI rank only).
        if M is not None:
            # scipy.stats.itemfreq was removed in SciPy 1.3; np.unique (which
            # flattens its input) yields the same sorted values and counts.
            rports, counts = np.unique(M, return_counts=True)
            self.assertTrue(np.array_equal(rports, np.arange(
                1, n_rport + 1)), 'Missing or invalid rports')
            _, p = scipy.stats.chisquare(counts)
            self.assertGreater(p, self.pval)
141
142
def suite():
    """Return a TestSuite holding every test defined on TestAllToAll."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(TestAllToAll)
146
147
def run():
    """Run this module's test suite with a verbose text runner.

    Returns
    -------
    unittest.TestResult
        The outcome of the run, so callers can check ``wasSuccessful()``
        (previously the result was discarded and ``None`` was returned).
    """
    runner = unittest.TextTestRunner(verbosity=2)
    return runner.run(suite())
151
152
if __name__ == '__main__':
    # Only when this file is executed directly: run the suite via run().
    run()
155