"""Unit tests for pyNN.standardmodels (StandardModelType, synapse types)."""
from pyNN.standardmodels import build_translations, StandardModelType, \
    STDPWeightDependence, STDPTimingDependence
from pyNN.standardmodels.synapses import StaticSynapse, STDPMechanism
from pyNN import errors
from pyNN.parameters import ParameterSpace
from nose.tools import assert_equal, assert_raises
try:
    from unittest.mock import Mock, patch
except ImportError:
    from mock import Mock, patch
import numpy as np


def _model_type(default_parameters, translations=None):
    """Return a throwaway subclass of StandardModelType configured for one test.

    Assigning ``default_parameters``/``translations`` directly on
    ``StandardModelType`` would leak that state into every later test (and
    into any class that inherits those attributes), so each test configures
    its own fresh subclass instead.

    :param default_parameters: dict mapping standard parameter names to defaults.
    :param translations: optional dict as returned by ``build_translations``;
        when omitted the subclass inherits ``StandardModelType.translations``.
    """
    attrs = {'default_parameters': default_parameters}
    if translations is not None:
        attrs['translations'] = translations
    return type('M', (StandardModelType,), attrs)


def _abc_model_type():
    """Model type with three parameters: 'a' copied as-is to 'A', 'b' scaled
    by 1000 to 'B', and 'c' computed as 'C' from the expressions 'c + a' /
    'C - A'."""
    return _model_type(
        {'a': 22.2, 'b': 33.3, 'c': 44.4},
        build_translations(
            ('a', 'A'),
            ('b', 'B', 1000.0),
            ('c', 'C', 'c + a', 'C - A'),
        ))


def test_build_translations():
    t = build_translations(
        ('a', 'A'),
        ('b', 'B', 1000.0),
        ('c', 'C', 'c + a', 'C - A')
    )
    assert_equal(set(t.keys()), {'a', 'b', 'c'})
    assert_equal(set(t['a'].keys()),
                 {'translated_name', 'forward_transform', 'reverse_transform'})
    assert_equal(t['a']['translated_name'], 'A')
    assert_equal(t['a']['forward_transform'], 'a')
    assert_equal(t['a']['reverse_transform'], 'A')
    assert_equal(t['b']['translated_name'], 'B')
    assert_equal(t['b']['forward_transform'], 'float(1000)*b')
    assert_equal(t['b']['reverse_transform'], 'B/float(1000)')
    assert_equal(t['c']['translated_name'], 'C')
    assert_equal(t['c']['forward_transform'], 'c + a')
    assert_equal(t['c']['reverse_transform'], 'C - A')


def test_has_parameter():
    M = _model_type({'a': 22.2, 'b': 33.3})
    assert M.has_parameter('a')
    assert M.has_parameter('b')
    assert not M.has_parameter('z')


def test_get_parameter_names():
    M = _model_type({'a': 22.2, 'b': 33.3})
    assert_equal(set(M.get_parameter_names()), {'a', 'b'})


def test_instantiate():
    """
    Instantiating a StandardModelType should set self.parameter_space to a
    ParameterSpace object containing the provided parameters.
    """
    M = _model_type({'a': 0.0, 'b': 0.0})
    P1 = {'a': 22.2, 'b': 33.3}
    m = M(**P1)
    assert_equal(m.parameter_space._parameters,
                 ParameterSpace(P1, None, None)._parameters)


def _parameter_space_to_dict(parameter_space, size):
    """Evaluate *parameter_space* for *size* elements and return it as a dict.

    Note: sets ``parameter_space.shape`` and evaluates it in place.
    """
    parameter_space.shape = (size,)
    parameter_space.evaluate(simplify=True)
    return parameter_space.as_dict()


def test_translate():
    m = _abc_model_type()()
    native_parameters = m.translate(ParameterSpace(
        {'a': 23.4, 'b': 34.5, 'c': 45.6}, m.get_schema(), None))
    assert_equal(_parameter_space_to_dict(native_parameters, 77),
                 {'A': 23.4, 'B': 34500.0, 'C': 69.0})


def test_translate_with_invalid_transformation():
    M = _model_type(
        {'a': 22.2, 'b': 33.3},
        build_translations(
            ('a', 'A'),
            ('b', 'B', 'b + z', 'B-Z'),
        ))
    # really we should trap such errors in build_translations(), not in translate()
    m = M()
    assert_raises(NameError,
                  m.translate,
                  ParameterSpace({'a': 23.4, 'b': 34.5}, m.get_schema(), None))


def test_translate_with_divide_by_zero_error():
    M = _model_type(
        {'a': 22.2, 'b': 33.3},
        build_translations(
            ('a', 'A'),
            ('b', 'B', 'b/0', 'B*0'),
        ))
    m = M()
    native_parameters = m.translate(
        ParameterSpace({'a': 23.4, 'b': 34.5}, m.get_schema(), 77))
    # translation is lazy, so the division only fails on evaluation
    assert_raises(ZeroDivisionError,
                  native_parameters.evaluate,
                  simplify=True)


def test_reverse_translate():
    M = _abc_model_type()
    standard_parameters = M().reverse_translate(
        ParameterSpace({'A': 23.4, 'B': 34500.0, 'C': 69.0}))
    assert_equal(_parameter_space_to_dict(standard_parameters, 88),
                 {'a': 23.4, 'b': 34.5, 'c': 45.6})


def test_reverse_translate_with_invalid_transformation():
    M = _model_type(
        {'a': 22.2, 'b': 33.3},
        build_translations(
            ('a', 'A'),
            ('b', 'B', 'b + z', 'B-Z'),
        ))
    # really we should trap such errors in build_translations(), not in reverse_translate()
    assert_raises(NameError,
                  M().reverse_translate,
                  {'A': 23.4, 'B': 34.5})


def test_simple_parameters():
    assert_equal(_abc_model_type()().simple_parameters(), ['a'])


def test_scaled_parameters():
    assert_equal(_abc_model_type()().scaled_parameters(), ['b'])


def test_computed_parameters():
    assert_equal(_abc_model_type()().computed_parameters(), ['c'])


def test_describe():
    assert isinstance(_abc_model_type()().describe(), str)

# test StandardCellType

# test ComposedSynapseType

# test create


def test_describe_synapse_type():
    # The minimum delay normally comes from a simulator backend. patch.object
    # restores the class afterwards even if an assertion below fails.
    with patch.object(StaticSynapse, '_get_minimum_delay',
                      lambda self: 0.1, create=True):
        sd = StaticSynapse()
        assert isinstance(sd.describe(), str)
        assert isinstance(sd.describe(template=None), dict)


def test_STDPMechanism_create():
    # All class-level patches are undone on exit. The replaced __init__
    # methods of the dependence classes must not leak into other tests.
    with patch.object(STDPMechanism, '_get_minimum_delay',
                      lambda self: 0.1, create=True), \
            patch.object(STDPMechanism, 'base_translations', {}, create=True):
        with patch.object(STDPTimingDependence, '__init__',
                          Mock(return_value=None)), \
                patch.object(STDPWeightDependence, '__init__',
                             Mock(return_value=None)):
            td = STDPTimingDependence()
            wd = STDPWeightDependence()
            stdp = STDPMechanism(td, wd, None, 0.5)
            assert_equal(stdp.timing_dependence, td)
            assert_equal(stdp.weight_dependence, wd)
            assert_equal(stdp.voltage_dependence, None)
            assert_equal(stdp.dendritic_delay_fraction, 0.5)


def test_STDPMechanism_create_invalid_types():
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, timing_dependence="abc")
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, weight_dependence="abc")
    # wrong type
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, dendritic_delay_fraction="abc")
    # right type but outside the allowed [0, 1] range (a string here would
    # only re-test the type check above)
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, dendritic_delay_fraction=1.1)


# test STDPWeightDependence

# test STDPTimingDependence