"""
Important note on tests in this module - the Aesara printing functions use a
global cache by default, which means that tests using it will modify global
state and thus not be independent from each other. Instead of using the "cache"
keyword argument each time, this module uses the aesara_code_ and
aesara_function_ functions defined below, which default to using a new, empty
cache instead.
"""

import logging

from sympy.external import import_module
from sympy.testing.pytest import raises, SKIP

# Silence Aesara's configuration warnings while it is being imported, then put
# the logger back to whatever level it had before rather than forcing one.
aesaralogger = logging.getLogger('aesara.configdefaults')
_previous_aesara_log_level = aesaralogger.level
aesaralogger.setLevel(logging.CRITICAL)
try:
    aesara = import_module('aesara')
finally:
    aesaralogger.setLevel(_previous_aesara_log_level)


if aesara:
    import numpy as np
    aet = aesara.tensor
    from aesara.scalar.basic import Scalar
    from aesara.graph.basic import Variable
    from aesara.tensor.var import TensorVariable
    from aesara.tensor.elemwise import Elemwise, DimShuffle
    from aesara.tensor.math import Dot

    xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz']
    Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ']
else:
    # bin/test will not execute any tests now
    disabled = True

import sympy as sy
from sympy import S
from sympy.abc import x, y, z, t
from sympy.printing.aesaracode import (aesara_code, dim_handling,
        aesara_function)


# Default set of matrix symbols for testing - make square so we can both
# multiply and perform elementwise operations between them.
X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']

# For testing AppliedUndef
f_t = sy.Function('f')(t)


def aesara_code_(expr, **kwargs):
    """ Wrapper for aesara_code that uses a new, empty cache by default. """
    kwargs.setdefault('cache', {})
    return aesara_code(expr, **kwargs)

def aesara_function_(inputs, outputs, **kwargs):
    """ Wrapper for aesara_function that uses a new, empty cache by default. """
    kwargs.setdefault('cache', {})
    return aesara_function(inputs, outputs, **kwargs)


def fgraph_of(*exprs):
    """ Transform SymPy expressions into an Aesara computation graph.

    All expressions are printed with one shared, fresh cache, so a symbol that
    appears in several of them becomes a single input of the graph.

    Parameters
    ==========
    exprs
        SymPy expressions. (test_Derivative also passes objects that are
        already Aesara variables; presumably the printer returns those as-is.)

    Returns
    =======
    aesara.graph.fg.FunctionGraph
    """
    cache = {}
    outs = [aesara_code_(expr, cache=cache) for expr in exprs]
    ins = list(aesara.graph.basic.graph_inputs(outs))
    ins, outs = aesara.graph.basic.clone(ins, outs)
    return aesara.graph.fg.FunctionGraph(ins, outs)


def aesara_simplify(fgraph):
    """ Simplify an Aesara computation graph.

    The graph is cloned first, so the argument is left untouched. The default
    compilation mode's optimizer is applied with the "fusion" rewrites
    excluded.

    Parameters
    ==========
    fgraph : aesara.graph.fg.FunctionGraph

    Returns
    =======
    aesara.graph.fg.FunctionGraph
    """
    mode = aesara.compile.get_default_mode().excluding("fusion")
    fgraph = fgraph.clone()
    mode.optimizer.optimize(fgraph)
    return fgraph


def theq(a, b):
    """ Test two Aesara objects for equality.

    Also accepts numeric types and lists/tuples of supported types.

    Note - debugprint() has a bug where it will accept numeric types but does
    not respect the "file" argument; in this case it instead prints the number
    to stdout and returns an empty string. This can lead to tests passing where
    they should fail because any two numbers will always compare as equal. To
    prevent this we treat numbers as a separate case.
    """
    numeric_types = (int, float, complex, np.number)
    a_is_num = isinstance(a, numeric_types)
    b_is_num = isinstance(b, numeric_types)

    # Compare numeric types using regular equality
    if a_is_num or b_is_num:
        if not (a_is_num and b_is_num):
            return False

        return a == b

    # Compare sequences element-wise
    a_is_seq = isinstance(a, (tuple, list))
    b_is_seq = isinstance(b, (tuple, list))

    if a_is_seq or b_is_seq:
        if not (a_is_seq and b_is_seq) or type(a) is not type(b):
            return False

        # Pair up the elements of both sequences; theq takes two arguments,
        # so each sequence must not be mapped on its own.
        return len(a) == len(b) and all(map(theq, a, b))

    # Otherwise, assume debugprint() can handle it
    astr = aesara.printing.debugprint(a, file='str')
    bstr = aesara.printing.debugprint(b, file='str')

    # Check for bug mentioned above
    for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
        if argstr == '':
            raise TypeError(
                'aesara.printing.debugprint(%s) returned empty string '
                '(%s is instance of %r)'
                % (argname, argname, type(argval))
            )

    return astr == bstr


def test_example_symbols():
    """
    Check that the example symbols in this module print to their Aesara
    equivalents, as many of the other tests depend on this.
    """
    assert theq(xt, aesara_code_(x))
    assert theq(yt, aesara_code_(y))
    assert theq(zt, aesara_code_(z))
    assert theq(Xt, aesara_code_(X))
    assert theq(Yt, aesara_code_(Y))
    assert theq(Zt, aesara_code_(Z))


def test_Symbol():
    """ Test printing a Symbol to an Aesara variable. """
    xx = aesara_code_(x)
    assert isinstance(xx, Variable)
    assert xx.broadcastable == ()
    assert xx.name == x.name

    xx2 = aesara_code_(x, broadcastables={x: (False,)})
    assert xx2.broadcastable == (False,)
    assert xx2.name == x.name

def test_MatrixSymbol():
    """ Test printing a MatrixSymbol to an Aesara variable. """
    XX = aesara_code_(X)
    assert isinstance(XX, TensorVariable)
    assert XX.broadcastable == (False, False)

@SKIP  # TODO - this is currently not checked but should be implemented
def test_MatrixSymbol_wrong_dims():
    """ Test MatrixSymbol with invalid broadcastable. """
    bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)]
    for bc in bcs:
        with raises(ValueError):
            aesara_code_(X, broadcastables={X: bc})

def test_AppliedUndef():
    """ Test printing AppliedUndef instance, which works similarly to Symbol. """
    ftt = aesara_code_(f_t)
    assert isinstance(ftt, TensorVariable)
    assert ftt.broadcastable == ()
    assert ftt.name == 'f_t'


def test_add():
    expr = x + y
    comp = aesara_code_(expr)
    assert comp.owner.op == aesara.tensor.add

def test_trig():
    assert theq(aesara_code_(sy.sin(x)), aet.sin(xt))
    assert theq(aesara_code_(sy.tan(x)), aet.tan(xt))

def test_many():
    """ Test printing a complex expression with multiple symbols. """
    expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
    comp = aesara_code_(expr)
    expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt)
    assert theq(comp, expected)


def test_dtype():
    """ Test specifying specific data types through the dtype argument. """
    for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']:
        assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype

    # "floatX" type
    assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64')

    # Type promotion
    assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32'
    assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64'


def test_broadcastables():
    """ Test the "broadcastables" argument when printing symbol-like objects. """

    # No restrictions on shape
    for s in [x, f_t]:
        for bc in [(), (False,), (True,), (False, False), (True, False)]:
            assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc

    # TODO - matrix broadcasting?

def test_broadcasting():
    """ Test "broadcastable" attribute after applying element-wise binary op. """

    expr = x + y

    cases = [
        [(), (), ()],
        [(False,), (False,), (False,)],
        [(True,), (False,), (False,)],
        [(False, True), (False, False), (False, False)],
        [(True, False), (False, False), (False, False)],
    ]

    for bc1, bc2, bc3 in cases:
        comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2})
        assert comp.broadcastable == bc3


def test_MatMul():
    expr = X*Y*Z
    expr_t = aesara_code_(expr)
    assert isinstance(expr_t.owner.op, Dot)
    assert theq(expr_t, Xt.dot(Yt).dot(Zt))

def test_Transpose():
    assert isinstance(aesara_code_(X.T).owner.op, DimShuffle)

def test_MatAdd():
    expr = X+Y+Z
    assert isinstance(aesara_code_(expr).owner.op, Elemwise)


def test_Rationals():
    assert theq(aesara_code_(sy.Integer(2) / 3), aet.true_div(2, 3))
    assert theq(aesara_code_(S.Half), aet.true_div(1, 2))

def test_Integers():
    assert aesara_code_(sy.Integer(3)) == 3

def test_factorial():
    n = sy.Symbol('n')
    assert aesara_code_(sy.factorial(n))

def test_Derivative():
    simp = lambda expr: aesara_simplify(fgraph_of(expr))
    assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))),
                simp(aesara.grad(aet.sin(xt), xt)))


def test_aesara_function_simple():
    """ Test aesara_function() with single output. """
    f = aesara_function_([x, y], [x+y])
    assert f(2, 3) == 5

def test_aesara_function_multi():
    """ Test aesara_function() with multiple outputs. """
    f = aesara_function_([x, y], [x+y, x-y])
    o1, o2 = f(2, 3)
    assert o1 == 5
    assert o2 == -1

def test_aesara_function_numpy():
    """ Test aesara_function() vs Numpy implementation. """
    f = aesara_function_([x, y], [x+y], dim=1,
                         dtypes={x: 'float64', y: 'float64'})
    assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9

    f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
                         dim=1)
    xx = np.arange(3).astype('float64')
    yy = 2*np.arange(3).astype('float64')
    assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9


def test_aesara_function_matrix():
    m = sy.Matrix([[x, y], [z, x + y + z]])
    expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
    f = aesara_function_([x, y, z], [m])
    np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
    f = aesara_function_([x, y, z], [m], scalar=True)
    np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
    f = aesara_function_([x, y, z], [m, m])
    assert isinstance(f(1.0, 2.0, 3.0), list)
    np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected)
    np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected)

def test_dim_handling():
    assert dim_handling([x], dim=2) == {x: (False, False)}
    assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True),
                                                       y: (False, False)}
    assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)}

def test_aesara_function_kwargs():
    """
    Test passing additional kwargs from aesara_function() to aesara.function().
    """
    f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
                         dtypes={x: 'float64', y: 'float64', z: 'float64'})
    assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9

    f = aesara_function_([x, y, z], [x+y],
                         dtypes={x: 'float64', y: 'float64', z: 'float64'},
                         dim=1, on_unused_input='ignore')
    xx = np.arange(3).astype('float64')
    yy = 2*np.arange(3).astype('float64')
    zz = 2*np.arange(3).astype('float64')
    assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9

def test_aesara_function_scalar():
    """ Test the "scalar" argument to aesara_function(). """
    from aesara.compile.function.types import Function

    args = [
        ([x, y], [x + y], None, [0]),  # Single 0d output
        ([X, Y], [X + Y], None, [2]),  # Single 2d output
        ([x, y], [x + y], {x: 0, y: 1}, [1]),  # Single 1d output
        ([x, y], [x + y, x - y], None, [0, 0]),  # Two 0d outputs
        ([x, y, X, Y], [x + y, X + Y], None, [0, 2]),  # One 0d output, one 2d
    ]

    # Create and test functions with and without the scalar setting
    for inputs, outputs, in_dims, out_dims in args:
        for scalar in [False, True]:

            f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar)

            # Check the aesara_function attribute is set whether wrapped or not
            assert isinstance(f.aesara_function, Function)

            # Feed in inputs of the appropriate size and get outputs
            in_values = [
                np.ones([1 if bc else 5 for bc in i.type.broadcastable])
                for i in f.aesara_function.input_storage
            ]
            out_values = f(*in_values)
            if not isinstance(out_values, list):
                out_values = [out_values]

            # Check output types and shapes
            assert len(out_dims) == len(out_values)
            for d, value in zip(out_dims, out_values):

                if scalar and d == 0:
                    # Should have been converted to a scalar value
                    assert isinstance(value, np.number)

                else:
                    # Otherwise should be an array
                    assert isinstance(value, np.ndarray)
                    assert value.ndim == d

def test_aesara_function_bad_kwarg():
    """
    Passing an unknown keyword argument to aesara_function() should raise an
    exception.
    """
    raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3))


def test_slice():
    assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3)

    def theq_slice(s1, s2):
        """ Compare two slices attribute-by-attribute using theq(). """
        for attr in ['start', 'stop', 'step']:
            a1 = getattr(s1, attr)
            a2 = getattr(s2, attr)
            if a1 is None or a2 is None:
                # A missing bound only matches another missing bound.
                if not (a1 is None and a2 is None):
                    return False
            elif not theq(a1, a2):
                return False
        return True

    dtypes = {x: 'int32', y: 'int32'}
    assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
    assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))

def test_MatrixSlice():
    from aesara.graph.basic import Constant

    cache = {}

    n = sy.Symbol('n', integer=True)
    X = sy.MatrixSymbol('X', n, n)

    Y = X[1:2:3, 4:5:6]
    Yt = aesara_code_(Y, cache=cache)

    s = Scalar('int64')
    assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
    assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache)
    # == doesn't work in Aesara like it does in SymPy. You have to use
    # equals.
    assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7))

    # Smoke test: a slice with a symbolic stop must print without error.
    k = sy.Symbol('k')
    aesara_code_(k, dtypes={k: 'int32'})
    start, stop, step = 4, k, 2
    Y = X[start:stop:step]
    Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'})
    # assert Yt.owner.op.idx_list[0].stop == kt

def test_BlockMatrix():
    n = sy.Symbol('n', integer=True)
    A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
    At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D))
    Block = sy.BlockMatrix([[A, B], [C, D]])
    Blockt = aesara_code_(Block)
    solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)),
                 aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))]
    assert any(theq(Blockt, solution) for solution in solutions)

@SKIP
def test_BlockMatrix_Inverse_execution():
    k, n = 2, 4
    dtype = 'float32'
    A = sy.MatrixSymbol('A', n, k)
    B = sy.MatrixSymbol('B', n, n)
    inputs = A, B
    output = B.I*A

    cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
                B: [(n//2, n//2), (n//2, n//2)]}
    cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
    cutoutput = output.subs(dict(zip(inputs, cutinputs)))

    dtypes = dict(zip(inputs, [dtype]*len(inputs)))
    f = aesara_function_(inputs, [output], dtypes=dtypes, cache={})
    fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)],
                                dtypes=dtypes, cache={})

    # B is made near-identity so that it is safely invertible.
    ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
               np.eye(n).astype(dtype)]
    ninputs[1] += np.ones(B.shape)*1e-5

    assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)

def test_DenseMatrix():
    from aesara.tensor.basic import Join

    t = sy.Symbol('theta')
    for MatrixType in [sy.Matrix, sy.ImmutableMatrix]:
        X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]])
        tX = aesara_code_(X)
        assert isinstance(tX, TensorVariable)
        assert isinstance(tX.owner.op, Join)


def test_cache_basic():
    """ Test single symbol-like objects are cached when printed by themselves. """

    # Pairs of objects which should be considered equivalent with respect to caching
    pairs = [
        (x, sy.Symbol('x')),
        (X, sy.MatrixSymbol('X', *X.shape)),
        (f_t, sy.Function('f')(sy.Symbol('t'))),
    ]

    for s1, s2 in pairs:
        cache = {}
        st = aesara_code_(s1, cache=cache)

        # Test hit with same instance
        assert aesara_code_(s1, cache=cache) is st

        # Test miss with same instance but new cache
        assert aesara_code_(s1, cache={}) is not st

        # Test hit with different but equivalent instance
        assert aesara_code_(s2, cache=cache) is st

def test_global_cache():
    """ Test use of the global cache. """
    from sympy.printing.aesaracode import global_cache

    backup = dict(global_cache)
    try:
        # Temporarily empty global cache
        global_cache.clear()

        for s in [x, X, f_t]:
            st = aesara_code(s)
            assert aesara_code(s) is st

    finally:
        # Restore global cache exactly, dropping entries created by this test
        global_cache.clear()
        global_cache.update(backup)

def test_cache_types_distinct():
    """
    Test that symbol-like objects of different types (Symbol, MatrixSymbol,
    AppliedUndef) are distinguished by the cache even if they have the same
    name.
    """
    symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]

    cache = {}  # Single shared cache
    printed = {}

    for s in symbols:
        st = aesara_code_(s, cache=cache)
        assert st not in printed.values()
        printed[s] = st

    # Check all printed objects are distinct
    assert len(set(map(id, printed.values()))) == len(symbols)

    # Check retrieving
    for s, st in printed.items():
        assert aesara_code_(s, cache=cache) is st

def test_symbols_are_created_once():
    """
    Test that a symbol is cached and reused when it appears in an expression
    more than once.
    """
    expr = sy.Add(x, x, evaluate=False)
    comp = aesara_code_(expr)

    assert theq(comp, xt + xt)
    assert not theq(comp, xt + aesara_code_(x))

def test_cache_complex():
    """
    Test caching on a complicated expression with multiple symbols appearing
    multiple times.
    """
    expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
    symbol_names = {s.name for s in expr.free_symbols}
    expr_t = aesara_code_(expr)

    # Iterate through variables in the Aesara computational graph that the
    # printed expression depends on
    seen = set()
    for v in aesara.graph.basic.ancestors([expr_t]):
        # Owner-less, non-constant variables should be our symbols
        if v.owner is None and not isinstance(v, aesara.graph.basic.Constant):
            # Check it corresponds to a symbol and appears only once
            assert v.name in symbol_names
            assert v.name not in seen
            seen.add(v.name)

    # Check all were present
    assert seen == symbol_names


def test_Piecewise():
    # A piecewise linear
    expr = sy.Piecewise((0, x<0), (x, x<2), (1, True))  # ___/III
    result = aesara_code_(expr)
    assert result.owner.op == aet.switch

    expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1))
    assert theq(result, expected)

    expr = sy.Piecewise((x, x < 0))
    result = aesara_code_(expr)
    expected = aet.switch(xt < 0, xt, np.nan)
    assert theq(result, expected)

    expr = sy.Piecewise((0, sy.And(x>0, x<2)),
                        (x, sy.Or(x>2, x<0)))
    result = aesara_code_(expr)
    expected = aet.switch(aet.and_(xt>0,xt<2), 0,
                          aet.switch(aet.or_(xt>2, xt<0), xt, np.nan))
    assert theq(result, expected)


def test_Relationals():
    assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt))
    # assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt))  # TODO - implement
    assert theq(aesara_code_(x > y), xt > yt)
    assert theq(aesara_code_(x < y), xt < yt)
    assert theq(aesara_code_(x >= y), xt >= yt)
    assert theq(aesara_code_(x <= y), xt <= yt)


def test_complexfunctions():
    # Use one local cache (not the global one) and the same complex dtypes for
    # every print, so the expressions below reuse these exact variables.
    dtypes = {x: 'complex128', y: 'complex128'}
    cache = {}
    xt, yt = (aesara_code_(x, dtypes=dtypes, cache=cache),
              aesara_code_(y, dtypes=dtypes, cache=cache))
    from sympy import conjugate
    from aesara.tensor import as_tensor_variable as atv
    from aesara.tensor import complex as cplx
    assert theq(aesara_code_(y*conjugate(x), dtypes=dtypes, cache=cache),
                yt*(xt.conj()))
    assert theq(aesara_code_((1+2j)*x, dtypes=dtypes, cache=cache),
                xt*(atv(1.0)+atv(2.0)*cplx(0,1)))


def test_constantfunctions():
    tf = aesara_function_([], [1+1j])
    assert tf() == 1+1j