# -*- coding: utf-8 -*-
#
# test_connect_one_to_one.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

"""Unit tests for the ``one_to_one`` connection rule."""

import numpy as np
import unittest
import connect_test_base
import nest


class TestOneToOne(connect_test_base.ConnectTestBase):
    """Tests for connecting two populations with ``rule='one_to_one'``.

    Source and target populations have the same size, so the expected
    source->target connectivity matrix is the identity matrix.
    """

    # specify connection pattern
    rule = 'one_to_one'
    conn_dict = {'rule': rule}
    # sizes of populations; one_to_one pairs neuron i with neuron i, so the
    # source (N1) and target (N2) populations are kept equally sized
    N = 6
    N1 = N
    N2 = N
    # population size used by the array-valued synapse parameter test
    N_array = 1000

    def testConnectivity(self):
        """Source neuron i connects to target neuron i and nothing else."""
        self.setUpNetwork(self.conn_dict)
        # make sure all connections do exist
        M = connect_test_base.get_connectivity_matrix(self.pop1, self.pop2)
        connect_test_base.mpi_assert(M, np.identity(self.N), self)
        # make sure no connections were drawn from the target to the source
        # population
        M = connect_test_base.get_connectivity_matrix(self.pop2, self.pop1)
        connect_test_base.mpi_assert(M, np.zeros((self.N, self.N)), self)

    def testSymmetricFlag(self):
        """With ``make_symmetric`` the reverse connections exist as well."""
        conn_dict_symmetric = self.conn_dict.copy()
        conn_dict_symmetric['make_symmetric'] = True
        self.setUpNetwork(conn_dict_symmetric)
        M1 = connect_test_base.get_connectivity_matrix(self.pop1, self.pop2)
        M2 = connect_test_base.get_connectivity_matrix(self.pop2, self.pop1)
        # test that connections were created in both directions
        connect_test_base.mpi_assert(
            M1, np.transpose(connect_test_base.gather_data(M2)), self)
        # test that no other connections were created
        connect_test_base.mpi_assert(
            M1, np.zeros_like(M1) + np.identity(self.N), self)

    def testInputArray(self):
        """Array-valued ``weight`` and ``delay`` are applied per connection."""
        syn_params = {}
        for label in ['weight', 'delay']:
            if label == 'weight':
                self.param_array = np.arange(self.N_array, dtype=float)
            elif label == 'delay':
                self.param_array = np.arange(1, self.N_array + 1) * 0.1
            # NOTE(review): syn_params is not cleared between iterations, so
            # the 'delay' run also passes the weight array from the first
            # iteration; confirm this is intended.
            syn_params[label] = self.param_array
            # start each iteration from a fresh kernel before rebuilding
            nest.ResetKernel()
            self.setUpNetwork(self.conn_dict, syn_params,
                              N1=self.N_array, N2=self.N_array)
            M_nest = connect_test_base.get_weighted_connectivity_matrix(
                self.pop1, self.pop2, label)
            # connection i -> i must carry the i-th array element
            connect_test_base.mpi_assert(
                M_nest, np.diag(self.param_array), self)

    def testInputArrayRPort(self):
        """An array-valued ``receptor_type`` is applied per connection."""
        syn_params = {}
        neuron_model = 'iaf_psc_exp_multisynapse'
        # tau_syn has N1 entries; receptor types 1..N1 are used below, one
        # per connection
        neuron_dict = {'tau_syn': [0.1 + i for i in range(self.N1)]}
        self.pop1 = nest.Create(neuron_model, self.N1, neuron_dict)
        # target population sized by N2 for consistency with the class
        # attributes (N1 == N2 for one_to_one)
        self.pop2 = nest.Create(neuron_model, self.N2, neuron_dict)
        self.param_array = np.arange(1, self.N1 + 1, dtype=int)
        syn_params['receptor_type'] = self.param_array
        nest.Connect(self.pop1, self.pop2, self.conn_dict, syn_params)
        M = connect_test_base.get_weighted_connectivity_matrix(
            self.pop1, self.pop2, 'receptor')
        connect_test_base.mpi_assert(M, np.diag(self.param_array), self)

    def testInputArrayToStdpSynapse(self):
        """Array-valued ``stdp_synapse`` parameters are applied per connection."""
        params = ['Wmax', 'alpha', 'lambda', 'mu_minus', 'mu_plus', 'tau_plus']
        syn_params = {'synapse_model': 'stdp_synapse'}
        # one value array per parameter, derived from ``params`` so the two
        # lists can never get out of sync
        values = [np.arange(self.N1, dtype=float) for _ in params]
        syn_params.update(zip(params, values))
        self.setUpNetwork(self.conn_dict, syn_params)
        for param, expected in zip(params, values):
            a = connect_test_base.get_weighted_connectivity_matrix(
                self.pop1, self.pop2, param)
            # only the diagonal (i -> i) holds connections under one_to_one
            connect_test_base.mpi_assert(np.diag(a), expected, self)


def suite():
    """Return a test suite containing all ``TestOneToOne`` tests."""
    return unittest.TestLoader().loadTestsFromTestCase(TestOneToOne)


def run():
    """Run the test suite verbosely and return the ``unittest.TestResult``."""
    runner = unittest.TextTestRunner(verbosity=2)
    return runner.run(suite())


if __name__ == '__main__':
    run()