1from __future__ import absolute_import, print_function, division
2from collections import OrderedDict
3from theano.tests.record import RecordMode, Record
4from theano.tests import disturb_mem
5import numpy as np
6import theano
7from theano.printing import var_descriptor
8
9from theano import config, shared
10from six import StringIO
11from six.moves import xrange
12
13
def sharedX(x, name=None):
    """Return a Theano shared variable holding `x` cast to ``config.floatX``.

    Parameters
    ----------
    x : scalar or array-like
        Initial value; converted to a NumPy array of dtype
        ``config.floatX`` before being wrapped.
    name : str, optional
        Name given to the resulting shared variable.

    Returns
    -------
    The object returned by ``theano.shared`` for the converted value.
    """
    # np.asarray(..., dtype=...) is the documented replacement for
    # np.cast[dtype](...), which was removed in NumPy 2.0.
    x = np.asarray(x, dtype=config.floatX)
    return shared(x, name)
17
18
def test_determinism_1():

    # Tests that repeatedly running a script that compiles and
    # runs a function does exactly the same thing every time it
    # is run, even when the memory addresses of the objects involved
    # change.
    # This specific script is capable of catching a bug where
    # FunctionGraph.toposort was non-deterministic.
    #
    # Strategy: `run` is executed once with replay disabled to produce a
    # log, then ten more times with replay enabled against that log.
    # NOTE(review): the replay passes presumably fail inside Record /
    # RecordMode when a replayed line differs from the recorded one;
    # that behavior lives in theano.tests.record -- confirm there.

    def run(replay, log=None):
        # Build, compile and run one small graph under a RecordMode.
        #
        # replay: falsy to start from an empty log; truthy to wrap the
        #         text passed as `log` so it can be replayed.
        # log:    text returned by an earlier non-replay call; only
        #         read when `replay` is truthy.
        #
        # Returns the accumulated log text when `replay` is falsy and
        # None (implicitly) otherwise.

        if not replay:
            log = StringIO()
        else:
            log = StringIO(log)
        record = Record(replay=replay, file_object=log)

        # disturb_mem is called between construction steps throughout
        # this function. NOTE(review): presumably it perturbs the
        # allocator so object addresses differ from run to run -- see
        # theano.tests.disturb_mem.
        disturb_mem.disturb_mem()

        mode = RecordMode(record=record)

        b = sharedX(np.zeros((2,)), name='b')
        # Never populated in this function; it stays empty until it is
        # iterated below.
        channels = OrderedDict()

        disturb_mem.disturb_mem()

        v_max = b.max(axis=0)
        v_min = b.min(axis=0)
        v_range = v_max - v_min

        # One shared variable 's_<i>' per expression, each updated to
        # the value of that expression. All three expressions are built
        # (in list order) before the first shared variable is created.
        updates = []
        for i, val in enumerate([
                v_max.max(),
                v_max.min(),
                v_range.max(),
                ]):
            disturb_mem.disturb_mem()
            s = sharedX(0., name='s_' + str(i))
            updates.append((s, val))

        # Clear the name of every named ancestor of the update
        # expressions, except 'b' and any two-character name starting
        # with 's'. Only the update expressions (not the 's_<i>'
        # targets) are passed to ancestors().
        for var in theano.gof.graph.ancestors(update for _, update in updates):
            if var.name is not None and var.name != 'b':
                if var.name[0] != 's' or len(var.name) != 2:
                    var.name = None

        # `channels` is empty, so this loop body never executes here.
        # If it did, `s` would be the last shared variable created in
        # the loop above.
        for key in channels:
            updates.append((s, channels[key]))
        # Compile a function with no inputs and no explicit outputs; it
        # is given only the shared-variable updates.
        f = theano.function([], mode=mode, updates=updates,
                            on_unused_input='ignore', name='f')
        # Pass a textual descriptor of every output of the compiled
        # graph to the record, in the order fgraph.outputs lists them.
        for output in f.maker.fgraph.outputs:
            mode.record.handle_line(var_descriptor(output) + '\n')
        disturb_mem.disturb_mem()
        f()

        # NOTE(review): `record.f` is presumably the file object passed
        # to Record above (the StringIO) -- confirm in Record.
        mode.record.f.flush()

        if not replay:
            return log.getvalue()

    # Recording pass: replay=0 is falsy, so the log text is returned.
    log = run(0)
    # Do several trials, since failure doesn't always occur
    # (Sometimes you sample the same outcome twice in a row)
    for i in xrange(10):
        run(1, log)
83
# Allow running this test directly as a script, without a test runner.
if __name__ == '__main__':
    test_determinism_1()
86