1from __future__ import absolute_import
2import atexit
3import functools
4import os
5import random
6import types
7
8import numpy
9
10from chainer.backends import cuda
11from chainer.testing import _bundle
12
13
# Random states captured by _numpy_do_setup() and restored (then reset back to
# None) by _numpy_do_teardown(). Both are None while no state is saved.
_old_python_random_state = _old_numpy_random_state = None
16
17
def _numpy_do_setup(deterministic=True):
    """Save and reseed the Python and NumPy global random states.

    The previous states are stored in module globals so that
    ``_numpy_do_teardown`` can restore them afterwards.

    Args:
        deterministic (bool): If ``True``, seed both generators with fixed
            values so random numbers are reproducible. Otherwise, reseed
            them from fresh entropy.
    """
    global _old_python_random_state
    global _old_numpy_random_state
    _old_python_random_state = random.getstate()
    _old_numpy_random_state = numpy.random.get_state()
    # Python's ``random`` state is saved and restored, so reseed it too.
    # Otherwise tests that use ``random`` would not be reproducible when
    # this NumPy-only path is taken.
    if deterministic:
        random.seed(99)
        numpy.random.seed(100)
    else:
        random.seed()
        numpy.random.seed()
27
28
def _numpy_do_teardown():
    """Restore the random states saved by ``_numpy_do_setup``.

    The Python ``random`` state is restored first, then NumPy's. Afterwards
    the saved-state globals are cleared back to ``None``.
    """
    global _old_python_random_state, _old_numpy_random_state
    saved_python, saved_numpy = (
        _old_python_random_state, _old_numpy_random_state)
    random.setstate(saved_python)
    numpy.random.set_state(saved_numpy)
    _old_python_random_state = _old_numpy_random_state = None
36
37
def _cupy_testing_random():
    """Return the random-state helper module from ``cupy.testing``.

    The helper is looked up as ``cupy.testing.random`` when that attribute
    exists, and as ``cupy.testing._random`` otherwise.
    """
    cupy_testing = cuda.cupy.testing
    try:
        return cupy_testing.random
    except AttributeError:
        return cupy_testing._random
43
44
def do_setup(deterministic=True):
    """Set up the global random states for a test.

    When CUDA is available, this delegates to the ``do_setup`` of CuPy's
    testing random helper. Otherwise, it uses ``_numpy_do_setup``.

    Args:
        deterministic (bool): Forwarded to the backend; ``True`` requests
            fixed seeds.
    """
    if not cuda.available:
        _numpy_do_setup(deterministic)
        return
    _cupy_testing_random().do_setup(deterministic)
50
51
def do_teardown():
    """Tear down the random states prepared by ``do_setup``.

    When CUDA is available, this delegates to the ``do_teardown`` of CuPy's
    testing random helper. Otherwise, it uses ``_numpy_do_teardown``.
    """
    if not cuda.available:
        _numpy_do_teardown()
        return
    _cupy_testing_random().do_teardown()
57
58
# Depth of _setup_random() calls not yet matched by _teardown_random().
# setUp/tearDown can be nested in some tests (those using condition.repeat
# or condition.retry). Therefore only the outermost setUp actually sets up
# the random states, and only its matching tearDown tears them down.
_nest_count = 0
63
64
@atexit.register
def _check_teardown():
    """At interpreter exit, assert that every setup was torn down."""
    assert not _nest_count, ('_setup_random() and _teardown_random() '
                             'must be called in pairs.')
69
70
def _setup_random():
    """Sets up the deterministic random states of ``numpy`` and ``cupy``.

    Only the outermost call does real work. It reads the
    ``CHAINER_TEST_RANDOM_NONDETERMINISTIC`` environment variable (an integer,
    default ``'0'``) and asks for fixed seeds when that value is zero. Every
    call increments the nesting counter.
    """
    global _nest_count
    outermost = _nest_count == 0
    if outermost:
        flag = os.environ.get('CHAINER_TEST_RANDOM_NONDETERMINISTIC', '0')
        do_setup(int(flag) == 0)
    _nest_count += 1
81
82
def _teardown_random():
    """Tears down the deterministic random states set up by ``_setup_random``.

    Decrements the nesting counter. The actual teardown runs only when the
    counter returns to zero, i.e. when the outermost setup is being undone.
    """
    global _nest_count
    assert _nest_count > 0, '_setup_random has not been called'
    remaining = _nest_count - 1
    _nest_count = remaining
    if remaining == 0:
        do_teardown()
92
93
def generate_seed():
    """Draw a new random seed from NumPy's global random state.

    Must be called while random is set up (between ``_setup_random()`` and
    ``_teardown_random()``), so the seed comes from the state prepared by
    ``do_setup``.

    Returns:
        int: A seed in the half-open range ``[0, 0xffffffff)``.
    """
    assert _nest_count > 0, 'random is not set up'
    # 0xffffffff does not fit in a 32-bit default integer, which is the case
    # on Windows with NumPy < 2; there the call would raise "high is out of
    # bounds". An explicit 64-bit dtype avoids that. Where the default is
    # already 64-bit, it generates the same values as before.
    return int(numpy.random.randint(0xffffffff, dtype=numpy.int64))
97
98
def _wrap_test_method(impl):
    # Wrap a single ``test_*`` function so it runs under fixed random states.
    @functools.wraps(impl)
    def test_func(self, *args, **kw):
        _setup_random()
        try:
            # Propagate the test's return value instead of discarding it.
            return impl(self, *args, **kw)
        finally:
            _teardown_random()
    return test_func


def _wrap_setup_teardown(klass, setup_method_name, teardown_method_name):
    # Patch klass's setup/teardown methods in place so they bracket each test
    # with _setup_random() / _teardown_random().
    prev_setup = getattr(klass, setup_method_name)
    prev_teardown = getattr(klass, teardown_method_name)

    @functools.wraps(prev_setup)
    def new_setup(self):
        _setup_random()
        try:
            prev_setup(self)
        except BaseException:
            # unittest (and pytest's xunit-style setup) do not call the
            # teardown method when setup raises, including via skipTest().
            # Undo _setup_random() here so _nest_count stays balanced.
            _teardown_random()
            raise

    @functools.wraps(prev_teardown)
    def new_teardown(self):
        try:
            prev_teardown(self)
        finally:
            _teardown_random()

    setattr(klass, setup_method_name, new_setup)
    setattr(klass, teardown_method_name, new_teardown)


def _fix_random(setup_method_name, teardown_method_name):
    """Return a decorator that runs tests under fixed random states.

    Args:
        setup_method_name (str): Name of the per-test setup method to wrap
            when the decorator is applied to a test case class.
        teardown_method_name (str): Name of the matching teardown method.

    Returns:
        callable: A decorator. Applied to a function whose name starts with
        ``test_``, it returns a wrapped function. Applied to a class or to a
        ``_bundle._ParameterizedTestCaseBundle``, it patches the setup and
        teardown methods of every class in place and returns the bundle.
    """
    # TODO(niboshi): Prevent this decorator from being applied within
    #    condition.repeat or condition.retry decorators. That would repeat
    #    tests with the same random seeds. It's okay to apply this outside
    #    these decorators.

    def decorator(impl):
        if (isinstance(impl, types.FunctionType) and
                impl.__name__.startswith('test_')):
            # Applied to test method
            return _wrap_test_method(impl)

        # Applied to a test case class, or to a bundle of parameterized ones.
        if isinstance(impl, _bundle._ParameterizedTestCaseBundle):
            cases = impl
        else:
            tup = _bundle._TestCaseTuple(impl, None, None)
            cases = _bundle._ParameterizedTestCaseBundle([tup])

        for klass, _, _ in cases.cases:
            _wrap_setup_teardown(
                klass, setup_method_name, teardown_method_name)

        return cases

    return decorator
155
156
def fix_random(*, setup_method='setUp', teardown_method='tearDown'):
    """Decorator that fixes random numbers in a test.

    The decorator accepts either a single test method (a function whose name
    starts with ``test_``) or a test case class. For a class, the methods
    named by ``setup_method`` and ``teardown_method`` are wrapped so each test
    runs under fixed random states. Setting the environment variable
    ``CHAINER_TEST_RANDOM_NONDETERMINISTIC`` to a non-zero integer switches
    to nondeterministic seeding.

    It should not be applied within ``condition.retry`` or
    ``condition.repeat``, since the repeated runs would then share the same
    random seeds.

    Args:
        setup_method (str): Name of the per-test setup method to wrap.
        teardown_method (str): Name of the per-test teardown method to wrap.
    """
    decorator = _fix_random(setup_method, teardown_method)
    return decorator
165