try:
    from unittest.mock import Mock, patch
except ImportError:  # fall back to the standalone "mock" package
    from mock import Mock, patch

import numpy as np
from nose.tools import assert_equal, assert_raises

from pyNN import errors
from pyNN.parameters import ParameterSpace
from pyNN.standardmodels import build_translations, StandardModelType, \
    STDPWeightDependence, STDPTimingDependence
from pyNN.standardmodels.synapses import StaticSynapse, STDPMechanism
12
13
def test_build_translations():
    """build_translations() expands each translation tuple into a dict.

    As checked below, a 2-tuple is a plain rename, a 3-tuple adds a scale
    factor, and a 4-tuple supplies explicit forward/reverse expressions.
    """
    translations = build_translations(
        ('a', 'A'),
        ('b', 'B', 1000.0),
        ('c', 'C', 'c + a', 'C - A')
    )
    fields = ('translated_name', 'forward_transform', 'reverse_transform')
    # name -> expected values, in the same order as ``fields``
    expected = {
        'a': ('A', 'a', 'A'),
        'b': ('B', 'float(1000)*b', 'B/float(1000)'),
        'c': ('C', 'c + a', 'C - A'),
    }
    assert_equal(set(translations.keys()), set(expected))
    assert_equal(set(translations['a'].keys()), set(fields))
    for name in sorted(expected):
        for field, value in zip(fields, expected[name]):
            assert_equal(translations[name][field], value)
32
33
def test_has_parameter():
    """has_parameter() is True for names in default_parameters, else False."""
    # A throwaway subclass keeps the class-attribute assignment local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3}

    assert M.has_parameter('a')
    assert M.has_parameter('b')
    assert not M.has_parameter('z')
40
41
def test_get_parameter_names():
    """get_parameter_names() returns the keys of default_parameters."""
    # A throwaway subclass keeps the class-attribute assignment local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3}

    assert_equal(set(M.get_parameter_names()), set(['a', 'b']))
46
47
def test_instantiate():
    """
    Instantiating a StandardModelType should set self.parameter_space to a
    ParameterSpace object containing the provided parameters.
    """
    # A throwaway subclass means no cleanup is needed afterwards, so a
    # failing assertion cannot leave StandardModelType modified.
    class M(StandardModelType):
        default_parameters = {'a': 0.0, 'b': 0.0}

    P1 = {'a': 22.2, 'b': 33.3}
    m = M(**P1)
    assert_equal(m.parameter_space._parameters, ParameterSpace(P1, None, None)._parameters)
59
60
def _parameter_space_to_dict(parameter_space, size):
    """Evaluate *parameter_space* as a 1-D space of *size* elements; return a dict.

    The object passed in is modified in place: its ``shape`` is overwritten
    and it is evaluated with ``simplify=True`` before being converted.
    """
    space = parameter_space
    space.shape = tuple([size])
    space.evaluate(simplify=True)
    result = space.as_dict()
    return result
65
66
def test_translate():
    """translate() applies each forward transform to give native parameters."""
    # A throwaway subclass keeps default_parameters/translations local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3, 'c': 44.4}
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 1000.0),
            ('c', 'C', 'c + a', 'C - A'),
        )

    m = M()
    native_parameters = m.translate(ParameterSpace(
        {'a': 23.4, 'b': 34.5, 'c': 45.6}, m.get_schema(), None))
    assert_equal(_parameter_space_to_dict(native_parameters, 77),
                 {'A': 23.4, 'B': 34500.0, 'C': 69.0})
80
81
def test_translate_with_invalid_transformation():
    """translate() raises NameError for a transform using an unknown name ('z')."""
    # A throwaway subclass keeps this deliberately broken translation table
    # from being left on StandardModelType after the test.
    class M(StandardModelType):
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 'b + z', 'B-Z'),
        )
        default_parameters = {'a': 22.2, 'b': 33.3}

    # really we should trap such errors in build_translations(), not in translate()
    m = M()
    assert_raises(NameError,
                  m.translate,
                  ParameterSpace({'a': 23.4, 'b': 34.5}, m.get_schema(), None))
94
95
def test_translate_with_divide_by_zero_error():
    """A forward transform dividing by zero fails when the result is evaluated.

    translate() itself is expected to succeed; the ZeroDivisionError should
    only surface from ``evaluate()`` on the returned ParameterSpace.
    """
    # A throwaway subclass keeps this deliberately broken translation table
    # from being left on StandardModelType after the test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3}
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 'b/0', 'B*0'),
        )

    m = M()
    native_parameters = m.translate(ParameterSpace({'a': 23.4, 'b': 34.5}, m.get_schema(), 77))
    assert_raises(ZeroDivisionError,
                  native_parameters.evaluate,
                  simplify=True)
108
109
def test_reverse_translate():
    """reverse_translate() applies each reverse transform to native parameters."""
    # A throwaway subclass keeps default_parameters/translations local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3, 'c': 44.4}
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 1000.0),
            ('c', 'C', 'c + a', 'C - A'),
        )

    native = ParameterSpace({'A': 23.4, 'B': 34500.0, 'C': 69.0})
    assert_equal(_parameter_space_to_dict(M().reverse_translate(native), 88),
                 {'a': 23.4, 'b': 34.5, 'c': 45.6})
120
121
def test_reverse_translate_with_invalid_transformation():
    """reverse_translate() raises NameError for a transform using an unknown name ('Z')."""
    # A throwaway subclass keeps this deliberately broken translation table
    # from being left on StandardModelType after the test.
    class M(StandardModelType):
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 'b + z', 'B-Z'),
        )
        default_parameters = {'a': 22.2, 'b': 33.3}

    # really we should trap such errors in build_translations(), not in reverse_translate()
    # NOTE(review): a plain dict is passed here, unlike test_reverse_translate,
    # which passes a ParameterSpace; confirm reverse_translate accepts both.
    assert_raises(NameError,
                  M().reverse_translate,
                  {'A': 23.4, 'B': 34.5})
133
134
def test_simple_parameters():
    """simple_parameters() lists only parameters that are renamed unchanged."""
    # A throwaway subclass keeps default_parameters/translations local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3, 'c': 44.4}
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 1000.0),
            ('c', 'C', 'c + a', 'C - A'),
        )

    assert_equal(M().simple_parameters(), ['a'])
144
145
def test_scaled_parameters():
    """scaled_parameters() lists only parameters translated by a scale factor."""
    # A throwaway subclass keeps default_parameters/translations local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3, 'c': 44.4}
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 1000.0),
            ('c', 'C', 'c + a', 'C - A'),
        )

    assert_equal(M().scaled_parameters(), ['b'])
155
156
def test_computed_parameters():
    """computed_parameters() lists only parameters with explicit transform expressions."""
    # A throwaway subclass keeps default_parameters/translations local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3, 'c': 44.4}
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 1000.0),
            ('c', 'C', 'c + a', 'C - A'),
        )

    assert_equal(M().computed_parameters(), ['c'])
166
167
def test_describe():
    """describe() on a standard model returns a string."""
    # A throwaway subclass keeps default_parameters/translations local to this
    # test instead of modifying StandardModelType for every later test.
    class M(StandardModelType):
        default_parameters = {'a': 22.2, 'b': 33.3, 'c': 44.4}
        translations = build_translations(
            ('a', 'A'),
            ('b', 'B', 1000.0),
            ('c', 'C', 'c + a', 'C - A'),
        )

    assert isinstance(M().describe(), str)
177
# TODO: add tests for StandardCellType

# TODO: add tests for ComposedSynapseType

# TODO: add tests for create
183
184
def test_describe_synapse_type():
    """StaticSynapse.describe() returns a str by default and a dict with template=None."""
    # _get_minimum_delay is stubbed to return 0.1 while the synapse is built
    # and described. patch.object restores StaticSynapse on exit, even if an
    # assertion fails, so the stub cannot leak into later tests.
    with patch.object(StaticSynapse, '_get_minimum_delay',
                      lambda self: 0.1, create=True):
        sd = StaticSynapse()
        assert isinstance(sd.describe(), str)
        assert isinstance(sd.describe(template=None), dict)
191
192
def test_STDPMechanism_create():
    """STDPMechanism stores the components and delay fraction it is given."""
    # The component classes' __init__ methods become no-op Mocks, and
    # STDPMechanism gets a stub _get_minimum_delay and an empty
    # base_translations, all for the duration of the test only.
    # patch.object undoes every change on exit, even if an assertion fails:
    # it restores attributes the class defined itself and removes ones that
    # were inherited or newly created.
    with patch.object(STDPMechanism, '_get_minimum_delay',
                      lambda self: 0.1, create=True), \
            patch.object(STDPMechanism, 'base_translations', {},
                         create=True), \
            patch.object(STDPTimingDependence, '__init__',
                         Mock(return_value=None)), \
            patch.object(STDPWeightDependence, '__init__',
                         Mock(return_value=None)):
        td = STDPTimingDependence()
        wd = STDPWeightDependence()
        stdp = STDPMechanism(td, wd, None, 0.5)
        assert_equal(stdp.timing_dependence, td)
        assert_equal(stdp.weight_dependence, wd)
        assert_equal(stdp.voltage_dependence, None)
        assert_equal(stdp.dendritic_delay_fraction, 0.5)
207
208
def test_STDPMechanism_create_invalid_types():
    """STDPMechanism rejects invalid components and dendritic_delay_fraction values."""
    # Components of the wrong type.
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, timing_dependence="abc")
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, weight_dependence="abc")
    # Non-numeric delay fractions, including a numeric-looking string.
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, dendritic_delay_fraction="abc")
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, dendritic_delay_fraction="1.1")
    # Numeric but outside [0, 1]. The "1.1" case above is a string, so it only
    # repeats the type check; these cover out-of-range numbers (presumably
    # validated by STDPMechanism as a [0, 1] range check; confirm).
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, dendritic_delay_fraction=1.1)
    assert_raises(AssertionError,  # probably want a more informative error
                  STDPMechanism, dendritic_delay_fraction=-0.1)
218
219
# TODO: add tests for STDPWeightDependence

# TODO: add tests for STDPTimingDependence
223