# -*- coding: utf-8 -*-
#
# test_connect_fixed_total_number.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import unittest
import scipy.stats
import connect_test_base
import nest


class TestFixedTotalNumber(connect_test_base.ConnectTestBase):
    """Tests for the ``fixed_total_number`` connection rule.

    The rule is configured through ``conn_dict`` with a total number of
    connections ``N``. Network construction and the MPI-aware helpers come
    from ``connect_test_base``.
    """

    # specify connection pattern and specific params
    rule = 'fixed_total_number'
    conn_dict = {'rule': rule}
    # sizes of source-, target-population and total number of connections
    # for the connection test
    N1 = 50
    N2 = 70
    Nconn = 100
    conn_dict['N'] = Nconn
    # sizes of source-, target-population and total number of connections for
    # statistical test
    N_s = 20
    N_t = 20
    N = 100
    # Critical values and number of iterations of two level test
    stat_dict = {'alpha2': 0.05, 'n_runs': 200}

    # tested on each mpi process separately
    def testErrorMessages(self):
        """Requesting more connections than distinct pairs must raise.

        With multapses disallowed, asking for ``N1 * N2 + 1`` connections is
        expected to raise ``nest.kernel.NESTError``. This assumes
        ``setUpNetwork`` builds source/target populations of sizes ``N1`` and
        ``N2`` by default (TODO confirm against ConnectTestBase).
        """
        conn_params = self.conn_dict.copy()
        conn_params['allow_autapses'] = True
        conn_params['allow_multapses'] = False
        conn_params['N'] = self.N1 * self.N2 + 1
        with self.assertRaises(nest.kernel.NESTError):
            self.setUpNetwork(conn_params)

    def testTotalNumberOfConnections(self):
        """Exactly ``Nconn`` connections run from pop1 to pop2, none back."""
        conn_params = self.conn_dict.copy()
        self.setUpNetwork(conn_params)
        total_conn = len(nest.GetConnections(self.pop1, self.pop2))
        connect_test_base.mpi_assert(total_conn, self.Nconn, self)
        # make sure no connections were drawn from the target to the source
        # population
        M = connect_test_base.get_connectivity_matrix(self.pop2, self.pop1)
        M_none = np.zeros((len(self.pop1), len(self.pop2)))
        connect_test_base.mpi_assert(M, M_none, self)

    def testStatistics(self):
        """Two-level statistical test of the in- and out-degree distributions.

        For each fan direction, ``n_runs`` networks are built with different
        seeds. Each run's observed degrees are compared with the expected
        degrees via ``chi_squared_check``, and the resulting p-values are then
        tested for uniformity with a Kolmogorov-Smirnov test. The KS p-value
        must exceed ``alpha2``.
        """
        conn_params = self.conn_dict.copy()
        conn_params['allow_autapses'] = True
        conn_params['allow_multapses'] = True
        conn_params['N'] = self.N
        for fan in ['in', 'out']:
            expected = connect_test_base.get_expected_degrees_totalNumber(
                self.N, fan, self.N_s, self.N_t)
            pvalues = []
            for run_idx in range(self.stat_dict['n_runs']):
                connect_test_base.reset_seed(run_idx + 1, self.nr_threads)
                self.setUpNetwork(conn_dict=conn_params,
                                  N1=self.N_s, N2=self.N_t)
                degrees = connect_test_base.get_degrees(
                    fan, self.pop1, self.pop2)
                # gather_data presumably returns None on all but one MPI
                # rank; only the rank holding the data evaluates statistics.
                degrees = connect_test_base.gather_data(degrees)
                if degrees is not None:
                    _, p = connect_test_base.chi_squared_check(
                        degrees, expected)
                    pvalues.append(p)
                connect_test_base.mpi_barrier()
            p = None
            if degrees is not None:
                _, p = scipy.stats.kstest(pvalues, 'uniform')
            # presumably broadcasts the KS p-value so every rank asserts on
            # the same value
            p = connect_test_base.bcast_data(p)
            self.assertGreater(p, self.stat_dict['alpha2'])

    def testAutapsesTrue(self):
        """With ``allow_autapses=True``, self-connections are created."""
        conn_params = self.conn_dict.copy()
        N = 3

        # test that autapses exist: draw N**3 connections within a population
        # of N neurons and require more than N of them on the diagonal
        conn_params['N'] = N * N * N
        conn_params['allow_autapses'] = True
        pop = nest.Create('iaf_psc_alpha', N)
        nest.Connect(pop, pop, conn_params)
        # count autapses (diagonal entries of the connectivity matrix)
        M = connect_test_base.get_connectivity_matrix(pop, pop)
        M = connect_test_base.gather_data(M)
        if M is not None:
            self.assertGreater(np.sum(np.diag(M)), N)

    def testAutapsesFalse(self):
        """With ``allow_autapses=False``, no self-connections are created."""
        conn_params = self.conn_dict.copy()
        N = 3

        # test that autapses were excluded
        conn_params['N'] = N * (N - 1)
        conn_params['allow_autapses'] = False
        pop = nest.Create('iaf_psc_alpha', N)
        nest.Connect(pop, pop, conn_params)
        # make sure the diagonal (autapse) entries are all zero
        M = connect_test_base.get_connectivity_matrix(pop, pop)
        connect_test_base.mpi_assert(np.diag(M), np.zeros(N), self)


def suite():
    """Return a test suite containing all TestFixedTotalNumber tests."""
    test_suite = unittest.TestLoader().loadTestsFromTestCase(
        TestFixedTotalNumber)
    return test_suite


def run():
    """Run the suite with a verbose text runner."""
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite())


if __name__ == '__main__':
    run()