1# -*- coding: utf-8 -*-
2#
3# test_connect_one_to_one.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22import numpy as np
23import unittest
24import connect_test_base
25import nest
26
27
class TestOneToOne(connect_test_base.ConnectTestBase):
    """Tests for the ``one_to_one`` connection rule.

    Neuron ``i`` of ``pop1`` is expected to connect to neuron ``i`` of
    ``pop2``. Connectivity matrices are therefore compared against the
    identity, and per-connection parameter arrays against a diagonal matrix.
    """

    # specify connection pattern
    rule = 'one_to_one'
    conn_dict = {'rule': rule}
    # sizes of populations; both use N so that connectivity can be compared
    # with the N x N identity matrix
    N = 6
    N1 = N
    N2 = N
    # population size used when passing per-connection parameter arrays
    N_array = 1000

    def testConnectivity(self):
        """Connections go pop1 -> pop2 index by index and never in reverse."""
        self.setUpNetwork(self.conn_dict)
        # make sure all connections do exist
        M = connect_test_base.get_connectivity_matrix(self.pop1, self.pop2)
        connect_test_base.mpi_assert(M, np.identity(self.N), self)
        # make sure no connections were drawn from the target to the source
        # population
        M = connect_test_base.get_connectivity_matrix(self.pop2, self.pop1)
        connect_test_base.mpi_assert(M, np.zeros((self.N, self.N)), self)

    def testSymmetricFlag(self):
        """With ``make_symmetric`` the reverse connections mirror the forward ones."""
        conn_dict_symmetric = self.conn_dict.copy()
        conn_dict_symmetric['make_symmetric'] = True
        self.setUpNetwork(conn_dict_symmetric)
        M1 = connect_test_base.get_connectivity_matrix(self.pop1, self.pop2)
        M2 = connect_test_base.get_connectivity_matrix(self.pop2, self.pop1)
        # test that connections were created in both directions; M2 is
        # gathered explicitly before transposing (presumably mpi_assert
        # gathers M1 itself -- TODO confirm in connect_test_base)
        connect_test_base.mpi_assert(
            M1, np.transpose(connect_test_base.gather_data(M2)), self)
        # test that no other connections were created
        connect_test_base.mpi_assert(M1, np.identity(self.N), self)

    def testInputArray(self):
        """Per-connection ``weight`` and ``delay`` arrays land on the diagonal."""
        param_arrays = {
            'weight': np.arange(self.N_array, dtype=float),
            'delay': np.arange(1, self.N_array + 1) * 0.1,
        }
        for label, param_array in param_arrays.items():
            self.param_array = param_array
            # Build a fresh dict per label so each run sets only the
            # parameter under test (previously settings leaked between runs).
            syn_params = {label: param_array}
            nest.ResetKernel()
            self.setUpNetwork(self.conn_dict, syn_params,
                              N1=self.N_array, N2=self.N_array)
            M_nest = connect_test_base.get_weighted_connectivity_matrix(
                self.pop1, self.pop2, label)
            connect_test_base.mpi_assert(M_nest, np.diag(param_array), self)

    def testInputArrayRPort(self):
        """A per-connection ``receptor_type`` array sets port i+1 on connection i.

        This test creates the populations itself instead of using
        ``setUpNetwork`` because it needs a multisynapse neuron model.
        """
        neuron_model = 'iaf_psc_exp_multisynapse'
        # N1 tau_syn entries; presumably one receptor port per entry, so that
        # ports 1..N1 exist on every neuron (TODO confirm)
        neuron_dict = {'tau_syn': [0.1 + i for i in range(self.N1)]}
        self.pop1 = nest.Create(neuron_model, self.N1, neuron_dict)
        self.pop2 = nest.Create(neuron_model, self.N2, neuron_dict)
        self.param_array = np.arange(1, self.N1 + 1, dtype=int)
        syn_params = {'receptor_type': self.param_array}
        nest.Connect(self.pop1, self.pop2, self.conn_dict, syn_params)
        M = connect_test_base.get_weighted_connectivity_matrix(
            self.pop1, self.pop2, 'receptor')
        connect_test_base.mpi_assert(M, np.diag(self.param_array), self)

    def testInputArrayToStdpSynapse(self):
        """Per-connection arrays for ``stdp_synapse`` parameters are stored per connection."""
        params = ['Wmax', 'alpha', 'lambda', 'mu_minus', 'mu_plus', 'tau_plus']
        # Offset each parameter's array by its index so every parameter gets
        # distinct values: identical arrays would hide a mix-up between
        # parameters. All values remain non-negative, as before.
        values = [np.arange(self.N1, dtype=float) + i
                  for i in range(len(params))]
        syn_params = {'synapse_model': 'stdp_synapse'}
        syn_params.update(zip(params, values))
        self.setUpNetwork(self.conn_dict, syn_params)
        for param, expected in zip(params, values):
            a = connect_test_base.get_weighted_connectivity_matrix(
                self.pop1, self.pop2, param)
            connect_test_base.mpi_assert(np.diag(a), expected, self)
99
100
def suite():
    """Return a ``unittest`` suite containing every test of ``TestOneToOne``."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(TestOneToOne)
104
105
def run():
    """Run the one_to_one test suite verbosely.

    Returns:
        The ``unittest.TestResult`` of the run. Callers may ignore it.
    """
    runner = unittest.TextTestRunner(verbosity=2)
    return runner.run(suite())


if __name__ == '__main__':
    # Exit with a non-zero status on failure so that scripts or CI jobs
    # invoking this file directly can detect failing tests.
    raise SystemExit(0 if run().wasSuccessful() else 1)
113