from __future__ import absolute_import, print_function, division
from copy import copy, deepcopy
from functools import wraps
import logging
import sys
import unittest
from parameterized import parameterized
from nose.tools import assert_raises

from six import integer_types
from six.moves import StringIO

try:
    from nose.plugins.attrib import attr
except ImportError:
    # This is an old version of nose
    def attr(tag):
        def func(f):
            return f
        return func
import numpy as np

import theano
import theano.tensor as T
from theano import config
try:
    from nose.plugins.skip import SkipTest
except ImportError:
    class SkipTest(Exception):
        """
        Skip this test
        """
_logger = logging.getLogger("theano.tests.unittest_tools")


def custom_name_func(testcase_func, param_num, param):
    return "%s_%s" % (
        testcase_func.__name__,
        parameterized.to_safe_name("_".join(str(x) for x in param.args)),
    )


def fetch_seed(pseed=None):
    """
    Returns the seed to use for running the unit tests.
    If an explicit seed is given, it will be used for seeding numpy's rng.
    If not, it will use config.unittests.rseed (its default value is 666).
    If config.unittests.rseed is set to "random", it will seed the rng with
    None, which is equivalent to seeding with a random seed.

    Useful for seeding RandomState objects.
    >>> rng = np.random.RandomState(unittest_tools.fetch_seed())
    """

    seed = pseed or config.unittests.rseed
    if seed == 'random':
        seed = None

    try:
        if seed:
            seed = int(seed)
        else:
            seed = None
    except ValueError:
        print('Error: config.unittests.rseed contains an invalid seed,'
              ' using None instead', file=sys.stderr)
        seed = None

    return seed


def seed_rng(pseed=None):
    """
    Seeds numpy's random number generator with the value returned by
    fetch_seed.

    Usage: unittest_tools.seed_rng()
    """

    seed = fetch_seed(pseed)
    if pseed and pseed != seed:
        print('Warning: using seed given by config.unittests.rseed=%i'
              ' instead of seed %i given as parameter' % (seed, pseed),
              file=sys.stderr)
    np.random.seed(seed)
    return seed


def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
    """
    Wrapper for gradient.py:verify_grad.
    Takes care of seeding the random number generator if no rng is given.
    """
    if rng is None:
        seed_rng()
        rng = np.random
    T.verify_grad(op, pt, n_tests, rng, *args, **kwargs)

#
# This supports the following syntax:
#
# try:
#     verify_grad(...)
# except verify_grad.E_grad as e:
#     print(e.num_grad.gf)
#     print(e.analytic_grad)
#     raise
#
verify_grad.E_grad = T.verify_grad.E_grad
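
# Example (illustrative sketch, not part of the API): checking the gradient
# of an elementwise function.  `pt` is a list of numpy arrays, one per input
# of the function under test.
#
#   x_val = np.random.rand(3, 4).astype(config.floatX)
#   verify_grad(T.tanh, [x_val], rng=np.random.RandomState(42))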


# A helpful class to check random values close to the boundaries
# when designing new tests
class MockRandomState:
    def __init__(self, val):
        self.val = val

    def rand(self, *shape):
        return np.zeros(shape, dtype='float64') + self.val

    def randint(self, minval, maxval=None, size=1):
        if maxval is None:
            minval, maxval = 0, minval
        out = np.zeros(size, dtype='int64')
        if self.val == 0:
            return out + minval
        else:
            return out + maxval - 1

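# Example (illustrative sketch): pin "random" draws to one end of their
# range to probe boundary behaviour when designing a test.
#
#   rng_low = MockRandomState(0)       # rand() returns all zeros
#   rng_high = MockRandomState(1)      # rand() returns all ones
#   rng_low.rand(2, 3)                 # -> zeros of shape (2, 3)
#   rng_high.randint(10, size=(4,))    # -> array filled with 9 (maxval - 1)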

class TestOptimizationMixin(object):

    def assertFunctionContains(self, f, op, min=1, max=sys.maxsize):
        toposort = f.maker.fgraph.toposort()
        matches = [node for node in toposort if node.op == op]
        assert (min <= len(matches) <= max), (toposort, matches,
                                              str(op), len(matches), min, max)

    def assertFunctionContains0(self, f, op):
        return self.assertFunctionContains(f, op, min=0, max=0)

    def assertFunctionContains1(self, f, op):
        return self.assertFunctionContains(f, op, min=1, max=1)

    def assertFunctionContainsN(self, f, op, N):
        return self.assertFunctionContains(f, op, min=N, max=N)

    def assertFunctionContainsClass(self, f, op, min=1, max=sys.maxsize):
        toposort = f.maker.fgraph.toposort()
        matches = [node for node in toposort if isinstance(node.op, op)]
        assert (min <= len(matches) <= max), (toposort, matches,
                                              str(op), len(matches), min, max)

    def assertFunctionContainsClassN(self, f, op, N):
        return self.assertFunctionContainsClass(f, op, min=N, max=N)

    def SkipTest(self, msg='Skip this test'):
        raise SkipTest(msg)
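
# Example (illustrative sketch): TestOptimizationMixin is meant to be mixed
# into a unittest.TestCase so a test can assert how many nodes of a given op
# (or op class) survive compilation.
#
#   class TestExpStaysInGraph(TestOptimizationMixin, unittest.TestCase):
#       def test_exp_node_present(self):
#           x = T.vector('x')
#           f = theano.function([x], T.exp(x))
#           self.assertFunctionContainsClass(f, theano.tensor.elemwise.Elemwise)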


# This object name should not start with Test.
# Otherwise nosetests will execute it!
class T_OpContractMixin(object):
    # self.ops should be a list of instantiations of an Op class to test.
    # self.other_op should be an op that is different from every op in
    # self.ops.
    other_op = T.add

    def copy(self, x):
        return copy(x)

    def deepcopy(self, x):
        return deepcopy(x)

    def clone(self, op):
        raise NotImplementedError('return new instance like `op`')

    def test_eq(self):
        for i, op_i in enumerate(self.ops):
            assert op_i == op_i
            assert op_i == self.copy(op_i)
            assert op_i == self.deepcopy(op_i)
            assert op_i == self.clone(op_i)
            assert op_i != self.other_op
            for j, op_j in enumerate(self.ops):
                if i == j:
                    continue
                assert op_i != op_j

    def test_hash(self):
        for i, op_i in enumerate(self.ops):
            h_i = hash(op_i)
            assert h_i == hash(op_i)
            assert h_i == hash(self.copy(op_i))
            assert h_i == hash(self.deepcopy(op_i))
            assert h_i == hash(self.clone(op_i))
            assert h_i != hash(self.other_op)
            for j, op_j in enumerate(self.ops):
                if i == j:
                    continue
                assert h_i != hash(op_j)

    def test_name(self):
        for op in self.ops:
            s = str(op)    # show that str works
            assert s       # names should not be empty

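# Example (illustrative sketch): exercising the Op __eq__/__hash__ contract
# for a hypothetical `DoubleOp`.  A concrete subclass only needs to provide
# `self.ops` and a working `clone`.
#
#   class TestDoubleOpContract(T_OpContractMixin, unittest.TestCase):
#       def setUp(self):
#           self.ops = [DoubleOp()]
#
#       def clone(self, op):
#           return DoubleOp()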

class InferShapeTester(unittest.TestCase):

    def setUp(self):
        seed_rng()
        # Take into account any mode that may be defined in a child class
        # (it may also be None).
        mode = getattr(self, 'mode', None)
        if mode is None:
            mode = theano.compile.get_default_mode()
        # This mode seems to be the minimal one including the shape_i
        # optimizations, if we don't want to enumerate them explicitly.
        self.mode = mode.including("canonicalize")

    def _compile_and_check(self, inputs, outputs, numeric_inputs, cls,
                           excluding=None, warn=True, check_topo=True):
        """This tests the infer_shape method only.

        When testing with input values whose shapes take the same
        value over different dimensions (for instance, a square
        matrix, or a tensor3 with shape (n, n, n), or (m, n, m)), it
        is not possible to detect whether the output shape was computed
        correctly, or whether some shapes with the same value have been
        mixed up. For instance, if infer_shape uses the width of a
        matrix instead of its height, then testing with only square
        matrices will not detect the problem. If warn=True, we emit a
        warning when testing with such values.

        :param check_topo: If True, check that the Op has been removed
            from the shape-computation graph. False is useful when the
            Op's infer_shape is not implemented.

        """
        mode = self.mode
        if excluding:
            mode = mode.excluding(*excluding)
        if warn:
            for var, inp in zip(inputs, numeric_inputs):
                if isinstance(inp, (integer_types, float, list, tuple)):
                    inp = var.type.filter(inp)
                if not hasattr(inp, "shape"):
                    continue
                # Ignore broadcastable dims: their length is fixed at 1,
                # so they cannot cause the repeated-dimension problem.
                if hasattr(var.type, "broadcastable"):
                    shp = [inp.shape[i] for i in range(inp.ndim)
                           if not var.type.broadcastable[i]]
                else:
                    shp = inp.shape
                if len(set(shp)) != len(shp):
                    _logger.warning(
                        "While testing shape inference for %r, we received an"
                        " input with a shape that has some repeated values: %r"
                        ", like a square matrix. This makes it impossible to"
                        " check if the values for these dimensions have been"
                        " correctly used, or if they have been mixed up.",
                        cls, inp.shape)
                    break

        outputs_function = theano.function(inputs, outputs, mode=mode)
        shapes_function = theano.function(inputs, [o.shape for o in outputs],
                                          mode=mode)
        # theano.printing.debugprint(shapes_function)
        # Check that the Op is removed from the shape-computation graph.
        if check_topo:
            topo_shape = shapes_function.maker.fgraph.toposort()
            assert not any(isinstance(t.op, cls) for t in topo_shape)
        topo_out = outputs_function.maker.fgraph.toposort()
        assert any(isinstance(t.op, cls) for t in topo_out)
        # Check that the shapes produced agree with the actual output shapes.
        numeric_outputs = outputs_function(*numeric_inputs)
        numeric_shapes = shapes_function(*numeric_inputs)
        for out, shape in zip(numeric_outputs, numeric_shapes):
            assert np.all(out.shape == shape), (out.shape, shape)

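# Example (illustrative sketch): a typical infer_shape test built on
# InferShapeTester.  `MyOp` is a hypothetical Op whose infer_shape is being
# checked; note the non-square input so that mixed-up dimensions can be
# detected.
#
#   class TestMyOpInferShape(InferShapeTester):
#       def test_infer_shape(self):
#           x = T.matrix('x')
#           x_val = np.random.rand(4, 5).astype(config.floatX)
#           self._compile_and_check([x], [MyOp()(x)], [x_val], MyOp)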

def str_diagnostic(expected, value, rtol, atol):
    """Return a pretty multiline string representing the cause
    of the exception"""
    sio = StringIO()

    try:
        ssio = StringIO()
        print("           : shape, dtype, strides, min, max, n_inf, n_nan:", file=ssio)
        print("  Expected :", end=' ', file=ssio)
        print(expected.shape, end=' ', file=ssio)
        print(expected.dtype, end=' ', file=ssio)
        print(expected.strides, end=' ', file=ssio)
        print(expected.min(), end=' ', file=ssio)
        print(expected.max(), end=' ', file=ssio)
        print(np.isinf(expected).sum(), end=' ', file=ssio)
        print(np.isnan(expected).sum(), end=' ', file=ssio)
        # only if everything succeeds do we add anything to sio
        print(ssio.getvalue(), file=sio)
    except Exception:
        pass
    try:
        ssio = StringIO()
        print("  Value    :", end=' ', file=ssio)
        print(value.shape, end=' ', file=ssio)
        print(value.dtype, end=' ', file=ssio)
        print(value.strides, end=' ', file=ssio)
        print(value.min(), end=' ', file=ssio)
        print(value.max(), end=' ', file=ssio)
        print(np.isinf(value).sum(), end=' ', file=ssio)
        print(np.isnan(value).sum(), end=' ', file=ssio)
        # only if everything succeeds do we add anything to sio
        print(ssio.getvalue(), file=sio)
    except Exception:
        pass

    print("  expected :", expected, file=sio)
    print("  value    :", value, file=sio)

    try:
        ov = np.asarray(expected)
        nv = np.asarray(value)
        ssio = StringIO()
        absdiff = np.absolute(nv - ov)
        print("  Max Abs Diff: ", np.max(absdiff), file=ssio)
        print("  Mean Abs Diff: ", np.mean(absdiff), file=ssio)
        print("  Median Abs Diff: ", np.median(absdiff), file=ssio)
        print("  Std Abs Diff: ", np.std(absdiff), file=ssio)
        reldiff = np.absolute(nv - ov) / np.absolute(ov)
        print("  Max Rel Diff: ", np.max(reldiff), file=ssio)
        print("  Mean Rel Diff: ", np.mean(reldiff), file=ssio)
        print("  Median Rel Diff: ", np.median(reldiff), file=ssio)
        print("  Std Rel Diff: ", np.std(reldiff), file=ssio)
        # only if everything succeeds do we add anything to sio
        print(ssio.getvalue(), file=sio)
    except Exception:
        pass
    atol_, rtol_ = T.basic._get_atol_rtol(expected, value)
    if rtol is not None:
        rtol_ = rtol
    if atol is not None:
        atol_ = atol
    print("  rtol, atol:", rtol_, atol_, file=sio)
    return sio.getvalue()


class WrongValue(Exception):

    def __init__(self, expected_val, val, rtol, atol):
        Exception.__init__(self)  # to be compatible with python2.4
        self.val1 = expected_val
        self.val2 = val
        self.rtol = rtol
        self.atol = atol

    def __str__(self):
        s = "WrongValue\n"
        return s + str_diagnostic(self.val1, self.val2, self.rtol, self.atol)


def assert_allclose(expected, value, rtol=None, atol=None):
    if not T.basic._allclose(expected, value, rtol, atol):
        raise WrongValue(expected, value, rtol, atol)

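# Example (illustrative sketch): assert_allclose raises WrongValue, whose
# message includes the diagnostic from str_diagnostic, instead of a bare
# AssertionError.  `computed` stands for the output of the function under
# test.
#
#   expected = np.ones((2, 3), dtype=config.floatX)
#   assert_allclose(expected, computed, rtol=1e-5)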

class AttemptManyTimes:
    """Decorator for unit tests that forces a unit test to be attempted
    multiple times. The test needs to pass a certain number of times for it
    to be considered to have succeeded. If it doesn't pass enough times, it
    is considered to have failed.

    Warning: care should be exercised when using this decorator. For some
    tests, the fact that they fail randomly could point to important issues
    such as race conditions or the use of uninitialized memory, and using
    this decorator could hide these problems.

    Usage:
        @AttemptManyTimes(n_attempts=5, n_req_successes=3)
        def fct(args):
            ...
    """

    def __init__(self, n_attempts, n_req_successes=1):
        assert n_attempts >= n_req_successes
        self.n_attempts = n_attempts
        self.n_req_successes = n_req_successes

    def __call__(self, fct):

        # Wrap fct in a function that will attempt to run it multiple
        # times and return the result if the test passes enough times,
        # or propagate the raised exception if it doesn't.
        @wraps(fct)
        def attempt_multiple_times(*args, **kwargs):

            # Keep a copy of the current seed for unittests so that we can
            # use a different seed for every run of the decorated test and
            # restore the original afterwards.
            original_seed = config.unittests.rseed
            current_seed = original_seed

            # If the decorator has received only one, unnamed, argument
            # and that argument has an attribute _testMethodName, it means
            # that the unit test on which the decorator is used is in a test
            # class. This means that the setUp() method of that class will
            # need to be called before any attempt to execute the test, in
            # case the test relies on data randomly generated in the class'
            # setUp() method.
            if (len(args) == 1 and hasattr(args[0], "_testMethodName")):
                test_in_class = True
                class_instance = args[0]
            else:
                test_in_class = False

            n_fail = 0
            n_success = 0

            # Attempt to call the test function multiple times. If it runs
            # without raising for n_req_successes attempts, it passes. If
            # too many attempts fail for that requirement to still be met,
            # the exception from the failing attempt is propagated.
            for i in range(self.n_attempts):
                try:
                    # Attempt to make the test use the current seed
                    config.unittests.rseed = current_seed
                    if test_in_class and hasattr(class_instance, "setUp"):
                        class_instance.setUp()

                    fct(*args, **kwargs)

                    n_success += 1
                    if n_success == self.n_req_successes:
                        break

                except Exception:
                    n_fail += 1

                    # If there are not enough attempts remaining to achieve
                    # the required number of successes, propagate the
                    # original exception
                    if n_fail + self.n_req_successes > self.n_attempts:
                        raise

                finally:
                    # Clean up after the test
                    config.unittests.rseed = original_seed
                    if test_in_class and hasattr(class_instance, "tearDown"):
                        class_instance.tearDown()

                    # Update the current_seed
                    if current_seed not in [None, "random"]:
                        current_seed = str(int(current_seed) + 1)

        return attempt_multiple_times


def assertFailure_fast(f):
    """A decorator for test cases that are known to fail when
    THEANO_FLAGS=cycle_detection='fast'.
    """
    if theano.config.cycle_detection == 'fast':
        def test_with_assert(*args, **kwargs):
            with assert_raises(Exception):
                f(*args, **kwargs)
        return test_with_assert
    else:
        return f

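# Example (illustrative sketch): mark a test that is known to fail when fast
# cycle detection is enabled.  `run_inplace_graph_check` stands for whatever
# check the real test performs.
#
#   @assertFailure_fast
#   def test_inplace_graph():
#       run_inplace_graph_check()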