1from __future__ import absolute_import, print_function, division 2from collections import OrderedDict 3from theano.tests.record import RecordMode, Record 4from theano.tests import disturb_mem 5import numpy as np 6import theano 7from theano.printing import var_descriptor 8 9from theano import config, shared 10from six import StringIO 11from six.moves import xrange 12 13 14def sharedX(x, name=None): 15 x = np.cast[config.floatX](x) 16 return shared(x, name) 17 18 19def test_determinism_1(): 20 21 # Tests that repeatedly running a script that compiles and 22 # runs a function does exactly the same thing every time it 23 # is run, even when the memory addresses of the objects involved 24 # change. 25 # This specific script is capable of catching a bug where 26 # FunctionGraph.toposort was non-deterministic. 27 28 def run(replay, log=None): 29 30 if not replay: 31 log = StringIO() 32 else: 33 log = StringIO(log) 34 record = Record(replay=replay, file_object=log) 35 36 disturb_mem.disturb_mem() 37 38 mode = RecordMode(record=record) 39 40 b = sharedX(np.zeros((2,)), name='b') 41 channels = OrderedDict() 42 43 disturb_mem.disturb_mem() 44 45 v_max = b.max(axis=0) 46 v_min = b.min(axis=0) 47 v_range = v_max - v_min 48 49 updates = [] 50 for i, val in enumerate([ 51 v_max.max(), 52 v_max.min(), 53 v_range.max(), 54 ]): 55 disturb_mem.disturb_mem() 56 s = sharedX(0., name='s_' + str(i)) 57 updates.append((s, val)) 58 59 for var in theano.gof.graph.ancestors(update for _, update in updates): 60 if var.name is not None and var.name != 'b': 61 if var.name[0] != 's' or len(var.name) != 2: 62 var.name = None 63 64 for key in channels: 65 updates.append((s, channels[key])) 66 f = theano.function([], mode=mode, updates=updates, 67 on_unused_input='ignore', name='f') 68 for output in f.maker.fgraph.outputs: 69 mode.record.handle_line(var_descriptor(output) + '\n') 70 disturb_mem.disturb_mem() 71 f() 72 73 mode.record.f.flush() 74 75 if not replay: 76 return log.getvalue() 77 78 log = run(0) 79 # Do several trials, since failure doesn't always occur 80 # (Sometimes you sample the same outcome twice in a row) 81 for i in xrange(10): 82 run(1, log) 83 84if __name__ == '__main__': 85 test_determinism_1() 86