1from __future__ import absolute_import, print_function, division 2from copy import copy, deepcopy 3from functools import wraps 4import logging 5import sys 6import unittest 7from parameterized import parameterized 8from nose.tools import assert_raises 9 10from six import integer_types 11from six.moves import StringIO 12 13try: 14 from nose.plugins.attrib import attr 15except ImportError: 16 # This is an old version of nose 17 def attr(tag): 18 def func(f): 19 return f 20 return func 21import numpy as np 22 23import theano 24import theano.tensor as T 25from theano import config 26try: 27 from nose.plugins.skip import SkipTest 28except ImportError: 29 class SkipTest(Exception): 30 """ 31 Skip this test 32 """ 33_logger = logging.getLogger("theano.tests.unittest_tools") 34 35 36def custom_name_func(testcase_func, param_num, param): 37 return "%s_%s" % ( 38 testcase_func.__name__, 39 parameterized.to_safe_name("_".join(str(x) for x in param.args)), 40 ) 41 42 43def fetch_seed(pseed=None): 44 """ 45 Returns the seed to use for running the unit tests. 46 If an explicit seed is given, it will be used for seeding numpy's rng. 47 If not, it will use config.unittest.rseed (its default value is 666). 48 If config.unittest.rseed is set to "random", it will seed the rng with 49 None, which is equivalent to seeding with a random seed. 50 51 Useful for seeding RandomState objects. 52 >>> rng = np.random.RandomState(unittest_tools.fetch_seed()) 53 """ 54 55 seed = pseed or config.unittests.rseed 56 if seed == 'random': 57 seed = None 58 59 try: 60 if seed: 61 seed = int(seed) 62 else: 63 seed = None 64 except ValueError: 65 print(('Error: config.unittests.rseed contains ' 'invalid seed, using None instead'), file=sys.stderr) 66 seed = None 67 68 return seed 69 70 71def seed_rng(pseed=None): 72 """ 73 Seeds numpy's random number generator with the value returned by fetch_seed. 74 Usage: unittest_tools.seed_rng() 75 """ 76 77 seed = fetch_seed(pseed) 78 if pseed and pseed != seed: 79 print('Warning: using seed given by config.unittests.rseed=%i' 'instead of seed %i given as parameter' % (seed, pseed), file=sys.stderr) 80 np.random.seed(seed) 81 return seed 82 83 84def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs): 85 """ 86 Wrapper for gradient.py:verify_grad 87 Takes care of seeding the random number generator if None is given 88 """ 89 if rng is None: 90 seed_rng() 91 rng = np.random 92 T.verify_grad(op, pt, n_tests, rng, *args, **kwargs) 93 94# 95# This supports the following syntax: 96# 97# try: 98# verify_grad(...) 99# except verify_grad.E_grad, e: 100# print e.num_grad.gf 101# print e.analytic_grad 102# raise 103# 104verify_grad.E_grad = T.verify_grad.E_grad 105 106 107# A helpful class to check random values close to the boundaries 108# when designing new tests 109class MockRandomState: 110 def __init__(self, val): 111 self.val = val 112 113 def rand(self, *shape): 114 return np.zeros(shape, dtype='float64') + self.val 115 116 def randint(self, minval, maxval=None, size=1): 117 if maxval is None: 118 minval, maxval = 0, minval 119 out = np.zeros(size, dtype='int64') 120 if self.val == 0: 121 return out + minval 122 else: 123 return out + maxval - 1 124 125 126class TestOptimizationMixin(object): 127 128 def assertFunctionContains(self, f, op, min=1, max=sys.maxsize): 129 toposort = f.maker.fgraph.toposort() 130 matches = [node for node in toposort if node.op == op] 131 assert (min <= len(matches) <= max), (toposort, matches, 132 str(op), len(matches), min, max) 133 134 def assertFunctionContains0(self, f, op): 135 return self.assertFunctionContains(f, op, min=0, max=0) 136 137 def assertFunctionContains1(self, f, op): 138 return self.assertFunctionContains(f, op, min=1, max=1) 139 140 def assertFunctionContainsN(self, f, op, N): 141 return self.assertFunctionContains(f, op, min=N, max=N) 142 143 def assertFunctionContainsClass(self, f, op, min=1, max=sys.maxsize): 144 toposort = f.maker.fgraph.toposort() 145 matches = [node for node in toposort if isinstance(node.op, op)] 146 assert (min <= len(matches) <= max), (toposort, matches, 147 str(op), len(matches), min, max) 148 149 def assertFunctionContainsClassN(self, f, op, N): 150 return self.assertFunctionContainsClass(f, op, min=N, max=N) 151 152 def SkipTest(self, msg='Skip this test'): 153 raise SkipTest(msg) 154 155 156# This object name should not start with Test. 157# Otherwise nosetests will execute it! 158class T_OpContractMixin(object): 159 # self.ops should be a list of instantiations of an Op class to test. 160 # self.other_op should be an op which is different from every op 161 other_op = T.add 162 163 def copy(self, x): 164 return copy(x) 165 166 def deepcopy(self, x): 167 return deepcopy(x) 168 169 def clone(self, op): 170 raise NotImplementedError('return new instance like `op`') 171 172 def test_eq(self): 173 for i, op_i in enumerate(self.ops): 174 assert op_i == op_i 175 assert op_i == self.copy(op_i) 176 assert op_i == self.deepcopy(op_i) 177 assert op_i == self.clone(op_i) 178 assert op_i != self.other_op 179 for j, op_j in enumerate(self.ops): 180 if i == j: 181 continue 182 assert op_i != op_j 183 184 def test_hash(self): 185 for i, op_i in enumerate(self.ops): 186 h_i = hash(op_i) 187 assert h_i == hash(op_i) 188 assert h_i == hash(self.copy(op_i)) 189 assert h_i == hash(self.deepcopy(op_i)) 190 assert h_i == hash(self.clone(op_i)) 191 assert h_i != hash(self.other_op) 192 for j, op_j in enumerate(self.ops): 193 if i == j: 194 continue 195 assert op_i != hash(op_j) 196 197 def test_name(self): 198 for op in self.ops: 199 s = str(op) # show that str works 200 assert s # names should not be empty 201 202 203class InferShapeTester(unittest.TestCase): 204 205 def setUp(self): 206 seed_rng() 207 # Take into account any mode that may be defined in a child class 208 # and it can be None 209 mode = getattr(self, 'mode', None) 210 if mode is None: 211 mode = theano.compile.get_default_mode() 212 # This mode seems to be the minimal one including the shape_i 213 # optimizations, if we don't want to enumerate them explicitly. 214 self.mode = mode.including("canonicalize") 215 216 def _compile_and_check(self, inputs, outputs, numeric_inputs, cls, 217 excluding=None, warn=True, check_topo=True): 218 """This tests the infer_shape method only 219 220 When testing with input values with shapes that take the same 221 value over different dimensions (for instance, a square 222 matrix, or a tensor3 with shape (n, n, n), or (m, n, m)), it 223 is not possible to detect if the output shape was computed 224 correctly, or if some shapes with the same value have been 225 mixed up. For instance, if the infer_shape uses the width of a 226 matrix instead of its height, then testing with only square 227 matrices will not detect the problem. If warn=True, we emit a 228 warning when testing with such values. 229 230 :param check_topo: If True, we check that the Op where removed 231 from the graph. False is useful to test not implemented case. 232 233 """ 234 mode = self.mode 235 if excluding: 236 mode = mode.excluding(*excluding) 237 if warn: 238 for var, inp in zip(inputs, numeric_inputs): 239 if isinstance(inp, (integer_types, float, list, tuple)): 240 inp = var.type.filter(inp) 241 if not hasattr(inp, "shape"): 242 continue 243 # remove broadcasted dims as it is sure they can't be 244 # changed to prevent the same dim problem. 245 if hasattr(var.type, "broadcastable"): 246 shp = [inp.shape[i] for i in range(inp.ndim) 247 if not var.type.broadcastable[i]] 248 else: 249 shp = inp.shape 250 if len(set(shp)) != len(shp): 251 _logger.warn( 252 "While testing shape inference for %r, we received an" 253 " input with a shape that has some repeated values: %r" 254 ", like a square matrix. This makes it impossible to" 255 " check if the values for these dimensions have been" 256 " correctly used, or if they have been mixed up.", 257 cls, inp.shape) 258 break 259 260 outputs_function = theano.function(inputs, outputs, mode=mode) 261 shapes_function = theano.function(inputs, [o.shape for o in outputs], 262 mode=mode) 263 # theano.printing.debugprint(shapes_function) 264 # Check that the Op is removed from the compiled function. 265 if check_topo: 266 topo_shape = shapes_function.maker.fgraph.toposort() 267 assert not any(isinstance(t.op, cls) for t in topo_shape) 268 topo_out = outputs_function.maker.fgraph.toposort() 269 assert any(isinstance(t.op, cls) for t in topo_out) 270 # Check that the shape produced agrees with the actual shape. 271 numeric_outputs = outputs_function(*numeric_inputs) 272 numeric_shapes = shapes_function(*numeric_inputs) 273 for out, shape in zip(numeric_outputs, numeric_shapes): 274 assert np.all(out.shape == shape), (out.shape, shape) 275 276 277def str_diagnostic(expected, value, rtol, atol): 278 """Return a pretty multiline string representating the cause 279 of the exception""" 280 sio = StringIO() 281 282 try: 283 ssio = StringIO() 284 print(" : shape, dtype, strides, min, max, n_inf, n_nan:", file=ssio) 285 print(" Expected :", end=' ', file=ssio) 286 print(expected.shape, end=' ', file=ssio) 287 print(expected.dtype, end=' ', file=ssio) 288 print(expected.strides, end=' ', file=ssio) 289 print(expected.min(), end=' ', file=ssio) 290 print(expected.max(), end=' ', file=ssio) 291 print(np.isinf(expected).sum(), end=' ', file=ssio) 292 print(np.isnan(expected).sum(), end=' ', file=ssio) 293 # only if all succeeds to we add anything to sio 294 print(ssio.getvalue(), file=sio) 295 except Exception: 296 pass 297 try: 298 ssio = StringIO() 299 print(" Value :", end=' ', file=ssio) 300 print(value.shape, end=' ', file=ssio) 301 print(value.dtype, end=' ', file=ssio) 302 print(value.strides, end=' ', file=ssio) 303 print(value.min(), end=' ', file=ssio) 304 print(value.max(), end=' ', file=ssio) 305 print(np.isinf(value).sum(), end=' ', file=ssio) 306 print(np.isnan(value).sum(), end=' ', file=ssio) 307 # only if all succeeds to we add anything to sio 308 print(ssio.getvalue(), file=sio) 309 except Exception: 310 pass 311 312 print(" expected :", expected, file=sio) 313 print(" value :", value, file=sio) 314 315 try: 316 ov = np.asarray(expected) 317 nv = np.asarray(value) 318 ssio = StringIO() 319 absdiff = np.absolute(nv - ov) 320 print(" Max Abs Diff: ", np.max(absdiff), file=ssio) 321 print(" Mean Abs Diff: ", np.mean(absdiff), file=ssio) 322 print(" Median Abs Diff: ", np.median(absdiff), file=ssio) 323 print(" Std Abs Diff: ", np.std(absdiff), file=ssio) 324 reldiff = np.absolute(nv - ov) / np.absolute(ov) 325 print(" Max Rel Diff: ", np.max(reldiff), file=ssio) 326 print(" Mean Rel Diff: ", np.mean(reldiff), file=ssio) 327 print(" Median Rel Diff: ", np.median(reldiff), file=ssio) 328 print(" Std Rel Diff: ", np.std(reldiff), file=ssio) 329 # only if all succeeds to we add anything to sio 330 print(ssio.getvalue(), file=sio) 331 except Exception: 332 pass 333 atol_, rtol_ = T.basic._get_atol_rtol(expected, value) 334 if rtol is not None: 335 rtol_ = rtol 336 if atol is not None: 337 atol_ = atol 338 print(" rtol, atol:", rtol_, atol_, file=sio) 339 return sio.getvalue() 340 341 342class WrongValue(Exception): 343 344 def __init__(self, expected_val, val, rtol, atol): 345 Exception.__init__(self) # to be compatible with python2.4 346 self.val1 = expected_val 347 self.val2 = val 348 self.rtol = rtol 349 self.atol = atol 350 351 def __str__(self): 352 s = "WrongValue\n" 353 return s + str_diagnostic(self.val1, self.val2, self.rtol, self.atol) 354 355 356def assert_allclose(expected, value, rtol=None, atol=None): 357 if not T.basic._allclose(expected, value, rtol, atol): 358 raise WrongValue(expected, value, rtol, atol) 359 360 361class AttemptManyTimes: 362 """Decorator for unit tests that forces a unit test to be attempted 363 multiple times. The test needs to pass a certain number of times for it to 364 be considered to have succeeded. If it doesn't pass enough times, it is 365 considered to have failed. 366 367 Warning : care should be exercised when using this decorator. For some 368 tests, the fact that they fail randomly could point to important issues 369 such as race conditions, usage of uninitialized memory region, etc. and 370 using this decorator could hide these problems. 371 372 Usage: 373 @AttemptManyTimes(n_attempts=5, n_req_successes=3) 374 def fct(args): 375 ... 376 """ 377 378 def __init__(self, n_attempts, n_req_successes=1): 379 assert n_attempts >= n_req_successes 380 self.n_attempts = n_attempts 381 self.n_req_successes = n_req_successes 382 383 def __call__(self, fct): 384 385 # Wrap fct in a function that will attempt to run it multiple 386 # times and return the result if the test passes enough times 387 # of propagate the raised exception if it doesn't. 388 @wraps(fct) 389 def attempt_multiple_times(*args, **kwargs): 390 391 # Keep a copy of the current seed for unittests so that we can use 392 # a different seed for every run of the decorated test and restore 393 # the original after 394 original_seed = config.unittests.rseed 395 current_seed = original_seed 396 397 # If the decorator has received only one, unnamed, argument 398 # and that argument has an attribute _testMethodName, it means 399 # that the unit test on which the decorator is used is in a test 400 # class. This means that the setup() method of that class will 401 # need to be called before any attempts to execute the test in 402 # case it relies on data randomly generated in the class' setup() 403 # method. 404 if (len(args) == 1 and hasattr(args[0], "_testMethodName")): 405 test_in_class = True 406 class_instance = args[0] 407 else: 408 test_in_class = False 409 410 n_fail = 0 411 n_success = 0 412 413 # Attempt to call the test function multiple times. If it does 414 # raise any exception for at least one attempt, it passes. If it 415 # raises an exception at every attempt, it fails. 416 for i in range(self.n_attempts): 417 try: 418 # Attempt to make the test use the current seed 419 config.unittests.rseed = current_seed 420 if test_in_class and hasattr(class_instance, "setUp"): 421 class_instance.setUp() 422 423 fct(*args, **kwargs) 424 425 n_success += 1 426 if n_success == self.n_req_successes: 427 break 428 429 except Exception: 430 n_fail += 1 431 432 # If there is not enough attempts remaining to achieve the 433 # required number of successes, propagate the original 434 # exception 435 if n_fail + self.n_req_successes > self.n_attempts: 436 raise 437 438 finally: 439 # Clean up after the test 440 config.unittests.rseed = original_seed 441 if test_in_class and hasattr(class_instance, "tearDown"): 442 class_instance.tearDown() 443 444 # Update the current_seed 445 if current_seed not in [None, "random"]: 446 current_seed = str(int(current_seed) + 1) 447 448 return attempt_multiple_times 449 450 451def assertFailure_fast(f): 452 """A Decorator to handle the test cases that are failing when 453 THEANO_FLAGS =cycle_detection='fast'. 454 """ 455 if theano.config.cycle_detection == 'fast': 456 def test_with_assert(*args, **kwargs): 457 with assert_raises(Exception): 458 f(*args, **kwargs) 459 return test_with_assert 460 else: 461 return f 462