1# -*- coding: utf-8 -*-
2#
3# test_connect_fixed_total_number.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22import numpy as np
23import unittest
24import scipy.stats
25import connect_test_base
26import nest
27
28
class TestFixedTotalNumber(connect_test_base.ConnectTestBase):
    """Tests for the ``fixed_total_number`` connection rule.

    The rule creates a fixed total number ``N`` of connections between a
    source and a target population. The tests cover the error raised for an
    impossible ``N``, the exact number of created connections, a two-level
    statistical test of the in- and out-degree distributions, and the
    handling of autapses.
    """

    # specify connection pattern and specific params
    rule = 'fixed_total_number'
    conn_dict = {'rule': rule}
    # sizes of source-, target-population and total number of connections
    # for the connection test
    N1 = 50
    N2 = 70
    Nconn = 100
    conn_dict['N'] = Nconn
    # sizes of source-, target-population and total number of connections for
    # statistical test
    N_s = 20
    N_t = 20
    N = 100
    # Critical values and number of iterations of two level test
    stat_dict = {'alpha2': 0.05, 'n_runs': 200}

    # tested on each mpi process separately
    def testErrorMessages(self):
        """Requesting more connections than distinct pairs must fail.

        With multapses disallowed there are at most ``N1 * N2`` distinct
        source-target pairs (assuming ``setUpNetwork`` builds populations of
        sizes ``N1`` and ``N2`` by default), so asking for
        ``N1 * N2 + 1`` connections is expected to raise a ``NESTError``.
        """
        conn_params = self.conn_dict.copy()
        conn_params['allow_autapses'] = True
        conn_params['allow_multapses'] = False
        conn_params['N'] = self.N1 * self.N2 + 1
        with self.assertRaises(nest.kernel.NESTError):
            self.setUpNetwork(conn_params)

    def testTotalNumberOfConnections(self):
        """Exactly ``Nconn`` connections go from pop1 to pop2, none back."""
        conn_params = self.conn_dict.copy()
        self.setUpNetwork(conn_params)
        total_conn = len(nest.GetConnections(self.pop1, self.pop2))
        connect_test_base.mpi_assert(total_conn, self.Nconn, self)
        # make sure no connections were drawn from the target to the source
        # population
        M = connect_test_base.get_connectivity_matrix(self.pop2, self.pop1)
        M_none = np.zeros((len(self.pop1), len(self.pop2)))
        connect_test_base.mpi_assert(M, M_none, self)

    def testStatistics(self):
        """Two-level statistical test of the in- and out-degrees.

        For each of ``n_runs`` seeds a network is built and a chi-squared
        check compares the observed degrees with the expected ones. The
        collected p-values are then tested for uniformity with a
        Kolmogorov-Smirnov test, whose p-value must exceed ``alpha2``.
        """
        conn_params = self.conn_dict.copy()
        conn_params['allow_autapses'] = True
        conn_params['allow_multapses'] = True
        conn_params['N'] = self.N
        n_runs = self.stat_dict['n_runs']
        for fan in ['in', 'out']:
            expected = connect_test_base.get_expected_degrees_totalNumber(
                self.N, fan, self.N_s, self.N_t)
            pvalues = []
            for seed in range(1, n_runs + 1):
                connect_test_base.reset_seed(seed, self.nr_threads)
                self.setUpNetwork(conn_dict=conn_params,
                                  N1=self.N_s, N2=self.N_t)
                degrees = connect_test_base.get_degrees(fan, self.pop1,
                                                        self.pop2)
                # gather_data may return None (presumably on non-root MPI
                # ranks); only processes that receive data compute p-values.
                degrees = connect_test_base.gather_data(degrees)
                if degrees is not None:
                    _, p_run = connect_test_base.chi_squared_check(degrees,
                                                                   expected)
                    pvalues.append(p_run)
                connect_test_base.mpi_barrier()
            # Only the process that collected p-values runs the second-level
            # test; bcast_data then shares the result with every process.
            # Checking ``pvalues`` (rather than the last loop value of
            # ``degrees``) does not depend on the loop having run at all.
            p = None
            if pvalues:
                _, p = scipy.stats.kstest(pvalues, 'uniform')
            p = connect_test_base.bcast_data(p)
            self.assertGreater(p, self.stat_dict['alpha2'])

    def testAutapsesTrue(self):
        """With autapses allowed, self-connections must be created."""
        conn_params = self.conn_dict.copy()
        N = 3

        # test that autapses exist: N**3 connections are drawn among the
        # N * N possible pairs, so several are expected on the diagonal
        # (assumes multapses are allowed by default -- not set here)
        conn_params['N'] = N * N * N
        conn_params['allow_autapses'] = True
        pop = nest.Create('iaf_psc_alpha', N)
        nest.Connect(pop, pop, conn_params)
        # count self-connections, i.e. the diagonal of the matrix
        M = connect_test_base.get_connectivity_matrix(pop, pop)
        M = connect_test_base.gather_data(M)
        if M is not None:
            self.assertGreater(np.sum(np.diag(M)), N)

    def testAutapsesFalse(self):
        """With autapses disallowed, no self-connections may be created."""
        conn_params = self.conn_dict.copy()
        N = 3

        # test that autapses were excluded
        conn_params['N'] = N * (N - 1)
        conn_params['allow_autapses'] = False
        pop = nest.Create('iaf_psc_alpha', N)
        nest.Connect(pop, pop, conn_params)
        # the diagonal (self-connections) must be empty
        M = connect_test_base.get_connectivity_matrix(pop, pop)
        connect_test_base.mpi_assert(np.diag(M), np.zeros(N), self)
123
124
def suite():
    """Return a suite with every test of ``TestFixedTotalNumber``."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(TestFixedTotalNumber)
128
129
def run():
    """Run the test suite with a verbose text runner.

    Returns the ``unittest.TestResult`` produced by the runner so callers
    can inspect the outcome; callers that ignore the return value keep
    working unchanged.
    """
    runner = unittest.TextTestRunner(verbosity=2)
    return runner.run(suite())


if __name__ == '__main__':
    import sys

    # Propagate test failures to the process exit status so that a direct
    # script invocation (shell, CI job) notices failing tests.
    sys.exit(0 if run().wasSuccessful() else 1)
137