1"""
Mock random-number-generator classes for unit tests.
3
4:copyright: Copyright 2006-2021 by the PyNN team, see AUTHORS.
5:license: CeCILL, see LICENSE for details.
6"""
7
8from pyNN import random
9import numpy as np
10
11
class MockRNG(random.WrappedRNG):
    """
    Deterministic RNG that emits an arithmetic progression.

    Successive draws continue the sequence ``start, start + delta,
    start + 2*delta, ...``; after each draw of ``n`` values, ``start`` is
    advanced by ``n * delta``.  Integer and binomial draws are derived from
    the same progression by wrapping it into the requested range.
    """
    rng = None

    def __init__(self, start=0.0, delta=1, parallel_safe=True):
        super().__init__(parallel_safe=parallel_safe)
        self.start = start  # first value of the next draw
        self.delta = delta  # step between consecutive values

    def _next(self, distribution, n, parameters):
        """
        Return the next ``n`` values for ``distribution``.

        "uniform_int" and "binomial" are delegated to the dedicated helpers.
        Every other distribution name (its parameters are ignored) yields
        the raw progression starting at ``self.start``.
        """
        special_cases = {
            "uniform_int": self._next_int,
            "binomial": self._next_binomial,
        }
        handler = special_cases.get(distribution)
        if handler is not None:
            return handler(n, parameters)
        first = self.start
        self.start = first + n * self.delta
        return first + self.delta * np.arange(n)

    def _next_int(self, n, parameters):
        """
        Return ``n`` progression values wrapped into ``[low, high)``.

        The progression starts at ``int(self.start)``, and ``self.start`` is
        advanced by ``n * delta`` just like in ``_next``.
        """
        lower = parameters["low"]
        width = parameters["high"] - lower
        base = int(self.start)
        self.start += n * self.delta
        raw = base + self.delta * np.arange(n)
        return raw % width + lower

    def _next_binomial(self, n, parameters):
        """
        Fake binomial draws: integers wrapped into ``[0, parameters["n"])``.

        Note the upper bound ``parameters["n"]`` itself is never produced.
        """
        bounds = dict(low=0, high=parameters["n"])
        return self._next_int(n, bounds)

    def permutation(self, arr):
        """Return ``arr`` reversed (a fixed, predictable "permutation")."""
        reversed_arr = arr[::-1]
        return reversed_arr
41
42
class MockRNG2(random.WrappedRNG):
    """
    Deterministic RNG that replays a caller-supplied sequence of numbers.

    Each draw of ``n`` values returns the next ``n``-element slice of
    ``numbers``, ignoring the distribution name and its parameters.  If the
    sequence runs out, the slice is simply shorter (or empty).
    """
    rng = None

    def __init__(self, numbers, parallel_safe=True):
        super().__init__(parallel_safe=parallel_safe)
        self.numbers = numbers  # the values to hand out, in order
        self.i = 0  # index of the next value to hand out

    def _next(self, distribution, n, parameters):
        """Return the next ``n`` items of ``self.numbers`` and advance the cursor."""
        begin = self.i
        end = begin + n
        chunk = self.numbers[begin:end]
        self.i = end
        return chunk

    def permutation(self, arr):
        """Return ``arr`` reversed (a fixed, predictable "permutation")."""
        reversed_arr = arr[::-1]
        return reversed_arr
58
59
class MockRNG3(random.WrappedRNG):
    """
    Deterministic RNG whose every draw is ``[1, 0, 0, 0, ...]``.

    Each call to ``_next`` returns an integer array of length ``n`` whose
    first element is 1 and whose remaining elements are 0.  The requested
    distribution and its parameters are ignored.  ``permutation`` simply
    reverses its input.
    """
    rng = None

    def __init__(self, parallel_safe=True):
        random.WrappedRNG.__init__(self, parallel_safe=parallel_safe)

    def _next(self, distribution, n, parameters):
        """Return an int array of length ``n``: 1 followed by zeros (empty if n == 0)."""
        # Allocate with an integer dtype directly.  Reassigning ``x.dtype``
        # reinterprets the existing float64 buffer in place.  When the
        # default int is not 8 bytes wide (e.g. 4 bytes on Windows with
        # NumPy < 2), that changes the array's length.
        x = np.zeros(n, dtype=int)
        if n > 0:  # avoid IndexError on an empty draw
            x[0] = 1
        return x

    def permutation(self, arr):
        """Return ``arr`` reversed (a fixed, predictable "permutation")."""
        return arr[::-1]
77