1"""
2Important note on tests in this module - the Aesara printing functions use a
3global cache by default, which means that tests using it will modify global
4state and thus not be independent from each other. Instead of using the "cache"
5keyword argument each time, this module uses the aesara_code_ and
6aesara_function_ functions defined below which default to using a new, empty
7cache instead.
8"""
9
10import logging
11
12from sympy.external import import_module
13from sympy.testing.pytest import raises, SKIP
14
# Raise the threshold of the 'aesara.configdefaults' logger to CRITICAL while
# aesara is imported (presumably to hide messages logged during the import),
# then lower it to WARNING again afterwards.
aesaralogger = logging.getLogger('aesara.configdefaults')
aesaralogger.setLevel(logging.CRITICAL)
aesara = import_module('aesara')
aesaralogger.setLevel(logging.WARNING)


if aesara:
    import numpy as np
    aet = aesara.tensor
    from aesara.scalar.basic import Scalar
    from aesara.graph.basic import Variable
    from aesara.tensor.var import TensorVariable
    from aesara.tensor.elemwise import Elemwise, DimShuffle
    from aesara.tensor.math import Dot

    # Reference Aesara variables used by the tests: three 'floatX' scalars
    # named x, y, z and three 'floatX' tensors named X, Y, Z whose two
    # dimensions are both non-broadcastable.
    xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz']
    Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ']
else:
    # bin/test will not execute any tests now
    disabled = True
35
36import sympy as sy
37from sympy import S
38from sympy.abc import x, y, z, t
39from sympy.printing.aesaracode import (aesara_code, dim_handling,
40        aesara_function)
41
42
# Default set of matrix symbols for testing - make square so we can both
# multiply and perform elementwise operations between them. They carry the
# names 'X', 'Y' and 'Z' and are all 4x4.
X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']

# For testing AppliedUndef: the undefined function 'f' applied to the symbol t.
f_t = sy.Function('f')(t)
49
50
def aesara_code_(expr, **kwargs):
    """ Call aesara_code() with a fresh, empty cache unless one is given. """
    if 'cache' not in kwargs:
        kwargs['cache'] = {}
    return aesara_code(expr, **kwargs)
55
def aesara_function_(inputs, outputs, **kwargs):
    """ Call aesara_function() with a fresh, empty cache unless one is given. """
    if 'cache' not in kwargs:
        kwargs['cache'] = {}
    return aesara_function(inputs, outputs, **kwargs)
60
61
def fgraph_of(*exprs):
    """ Build an Aesara FunctionGraph from SymPy expressions.

    Each expression is printed with aesara_code_(); the graph inputs of the
    printed outputs are collected, inputs and outputs are cloned, and the
    clones are wrapped in a FunctionGraph.

    Parameters
    ==========
    exprs
        SymPy expressions

    Returns
    =======
    aesara.graph.fg.FunctionGraph
    """
    graph_basic = aesara.graph.basic
    printed = [aesara_code_(expr) for expr in exprs]
    graph_ins = list(graph_basic.graph_inputs(printed))
    cloned_ins, cloned_outs = graph_basic.clone(graph_ins, printed)
    return aesara.graph.fg.FunctionGraph(cloned_ins, cloned_outs)
78
79
def aesara_simplify(fgraph):
    """ Simplify an Aesara computation.

    Runs the optimizer of the default compilation mode, with "fusion"
    excluded, on a clone of the given graph; the argument itself is not
    passed to the optimizer.

    Parameters
    ==========
    fgraph : aesara.graph.fg.FunctionGraph

    Returns
    =======
    aesara.graph.fg.FunctionGraph
    """
    default_mode = aesara.compile.get_default_mode()
    no_fusion_mode = default_mode.excluding("fusion")
    optimized = fgraph.clone()
    no_fusion_mode.optimizer.optimize(optimized)
    return optimized
95
96
def theq(a, b):
    """ Test two Aesara objects for equality.

    Also accepts numeric types and lists/tuples of supported types.

    Numbers are compared with ``==``. Sequences must have the same type and
    length and are compared pairwise with theq(). Anything else is compared
    through the string produced by aesara.printing.debugprint().

    Note - debugprint() has a bug where it will accept numeric types but does
    not respect the "file" argument in this case and instead prints the number
    to stdout and returns an empty string. This can lead to tests passing where
    they should fail because any two numbers will always compare as equal. To
    prevent this we treat numbers as a separate case.

    Raises
    ======
    TypeError
        If debugprint() returns an empty string for either argument.
    """
    numeric_types = (int, float, np.number)
    a_is_num = isinstance(a, numeric_types)
    b_is_num = isinstance(b, numeric_types)

    # Compare numeric types using regular equality
    if a_is_num or b_is_num:
        if not (a_is_num and b_is_num):
            return False

        return a == b

    # Compare sequences element-wise
    a_is_seq = isinstance(a, (tuple, list))
    b_is_seq = isinstance(b, (tuple, list))

    if a_is_seq or b_is_seq:
        # A list never equals a tuple, and a sequence never equals a
        # non-sequence.
        if not (a_is_seq and b_is_seq) or type(a) is not type(b):
            return False

        # zip() would silently truncate to the shorter sequence
        if len(a) != len(b):
            return False

        return all(theq(a_item, b_item) for a_item, b_item in zip(a, b))

    # Otherwise, assume debugprint() can handle it
    astr = aesara.printing.debugprint(a, file='str')
    bstr = aesara.printing.debugprint(b, file='str')

    # Check for bug mentioned above
    for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
        if argstr == '':
            raise TypeError(
                'aesara.printing.debugprint(%s) returned empty string '
                '(%s is instance of %r)'
                % (argname, argname, type(argval))
            )

    return astr == bstr
143
144
def test_example_symbols():
    """
    Check that each example SymPy symbol in this module prints to its Aesara
    counterpart, as many of the other tests depend on this.
    """
    counterparts = [(xt, x), (yt, y), (zt, z), (Xt, X), (Yt, Y), (Zt, Z)]
    for aesara_var, sympy_symbol in counterparts:
        assert theq(aesara_var, aesara_code_(sympy_symbol))
156
157
def test_Symbol():
    """ Test printing a Symbol to an aesara variable. """
    # Default printing: no dimensions, name taken from the symbol
    scalar_var = aesara_code_(x)
    assert isinstance(scalar_var, Variable)
    assert scalar_var.broadcastable == ()
    assert scalar_var.name == x.name

    # An explicit broadcastable pattern is carried over to the variable
    vector_var = aesara_code_(x, broadcastables={x: (False,)})
    assert vector_var.broadcastable == (False,)
    assert vector_var.name == x.name
168
def test_MatrixSymbol():
    """ Test printing a MatrixSymbol to an aesara variable. """
    matrix_var = aesara_code_(X)
    assert isinstance(matrix_var, TensorVariable)
    # Neither of the two dimensions is broadcastable
    assert matrix_var.broadcastable == (False, False)
174
@SKIP  # TODO - this is currently not checked but should be implemented
def test_MatrixSymbol_wrong_dims():
    """ Test that a MatrixSymbol rejects an invalid broadcastable pattern. """
    invalid_patterns = (
        (),
        (False,),
        (True,),
        (True, False),
        (False, True),
        (True, True),
    )
    for pattern in invalid_patterns:
        with raises(ValueError):
            aesara_code_(X, broadcastables={X: pattern})
182
def test_AppliedUndef():
    """ Test printing an AppliedUndef instance, which behaves much like a Symbol. """
    printed = aesara_code_(f_t)
    assert isinstance(printed, TensorVariable)
    assert printed.broadcastable == ()
    assert printed.name == 'f_t'
189
190
def test_add():
    """ Test that a sum of two symbols prints to the aesara add op. """
    printed = aesara_code_(x + y)
    assert printed.owner.op == aesara.tensor.add
195
def test_trig():
    """ Test printing of the sin and tan functions. """
    for sympy_func, aesara_func in [(sy.sin, aet.sin), (sy.tan, aet.tan)]:
        assert theq(aesara_code_(sympy_func(x)), aesara_func(xt))
199
def test_many():
    """ Test printing an expression that combines several symbols and functions. """
    sympy_expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
    printed = aesara_code_(sympy_expr)
    reference = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt)
    assert theq(printed, reference)
206
207
def test_dtype():
    """ Test choosing data types through the "dtypes" argument. """
    explicit_dtypes = ('float32', 'float64', 'int8', 'int16', 'int32', 'int64')
    for name in explicit_dtypes:
        printed = aesara_code_(x, dtypes={x: name})
        assert printed.type.dtype == name

    # "floatX" must resolve to one of the concrete float types
    floatx_var = aesara_code_(x, dtypes={x: 'floatX'})
    assert floatx_var.type.dtype in ('float32', 'float64')

    # Type promotion
    with_constant = aesara_code_(x + 1, dtypes={x: 'float32'})
    assert with_constant.type.dtype == 'float32'
    mixed = aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'})
    assert mixed.type.dtype == 'float64'
219
220
def test_broadcastables():
    """ Test the "broadcastables" argument when printing symbol-like objects. """

    # Symbols and applied undefined functions accept any pattern
    patterns = [(), (False,), (True,), (False, False), (True, False)]
    for symbol_like in (x, f_t):
        for pattern in patterns:
            printed = aesara_code_(symbol_like,
                                   broadcastables={symbol_like: pattern})
            assert printed.broadcastable == pattern

    # TODO - matrix broadcasting?
230
def test_broadcasting():
    """ Test "broadcastable" attribute after applying element-wise binary op. """

    total = x + y

    # (pattern for x, pattern for y, expected pattern of x + y)
    cases = (
        ((), (), ()),
        ((False,), (False,), (False,)),
        ((True,), (False,), (False,)),
        ((False, True), (False, False), (False, False)),
        ((True, False), (False, False), (False, False)),
    )

    for x_pattern, y_pattern, expected in cases:
        printed = aesara_code_(
            total, broadcastables={x: x_pattern, y: y_pattern})
        assert printed.broadcastable == expected
247
248
def test_MatMul():
    """ Test that a matrix product prints to chained Dot operations. """
    product_t = aesara_code_(X*Y*Z)
    assert isinstance(product_t.owner.op, Dot)
    assert theq(product_t, Xt.dot(Yt).dot(Zt))
254
def test_Transpose():
    """ Test that a matrix transpose prints to a DimShuffle op. """
    transposed = aesara_code_(X.T)
    assert isinstance(transposed.owner.op, DimShuffle)
257
def test_MatAdd():
    """ Test that a matrix sum prints to an Elemwise op. """
    summed = aesara_code_(X+Y+Z)
    assert isinstance(summed.owner.op, Elemwise)
261
262
def test_Rationals():
    """ Test that rational numbers print as a true division. """
    two_thirds = sy.Integer(2) / 3
    assert theq(aesara_code_(two_thirds), aet.true_div(2, 3))
    assert theq(aesara_code_(S.Half), aet.true_div(1, 2))
266
def test_Integers():
    """ Test that a SymPy Integer prints to a value equal to the plain number. """
    printed = aesara_code_(sy.Integer(3))
    assert printed == 3
269
def test_factorial():
    """ Test that a factorial prints to something truthy without raising. """
    count = sy.Symbol('n')
    printed = aesara_code_(sy.factorial(count))
    assert printed
273
def test_Derivative():
    """ Test an unevaluated Derivative against aesara.grad after simplification. """
    def simplified(node):
        # Build a function graph for the node and run the optimizer on it
        return aesara_simplify(fgraph_of(node))

    unevaluated = sy.Derivative(sy.sin(x), x, evaluate=False)
    assert theq(simplified(aesara_code_(unevaluated)),
                simplified(aesara.grad(aet.sin(xt), xt)))
278
279
def test_aesara_function_simple():
    """ Test aesara_function() with single output. """
    add = aesara_function_([x, y], [x+y])
    result = add(2, 3)
    assert result == 5
284
def test_aesara_function_multi():
    """ Test aesara_function() with multiple outputs. """
    add_and_subtract = aesara_function_([x, y], [x+y, x-y])
    total, difference = add_and_subtract(2, 3)
    assert total == 5
    assert difference == -1
291
def test_aesara_function_numpy():
    """ Test aesara_function() vs Numpy implementation. """
    float_dtypes = {x: 'float64', y: 'float64'}

    # Plain Python lists as inputs
    add = aesara_function_([x, y], [x+y], dim=1, dtypes=float_dtypes)
    error = np.linalg.norm(add([1, 2], [3, 4]) - np.asarray([4, 6]))
    assert error < 1e-9

    # Numpy arrays as inputs
    add = aesara_function_([x, y], [x+y], dtypes=float_dtypes, dim=1)
    first = np.arange(3).astype('float64')
    second = 2*np.arange(3).astype('float64')
    error = np.linalg.norm(add(first, second) - 3*np.arange(3))
    assert error < 1e-9
303
304
def test_aesara_function_matrix():
    """ Test aesara_function() with dense matrix outputs. """
    matrix = sy.Matrix([[x, y], [z, x + y + z]])
    reference = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
    values = (1.0, 2.0, 3.0)

    # Single matrix output
    func = aesara_function_([x, y, z], [matrix])
    np.testing.assert_allclose(func(*values), reference)

    # The "scalar" flag gives the same result for a matrix output
    func = aesara_function_([x, y, z], [matrix], scalar=True)
    np.testing.assert_allclose(func(*values), reference)

    # Two outputs are returned as a list of matrices
    func = aesara_function_([x, y, z], [matrix, matrix])
    assert isinstance(func(*values), list)
    np.testing.assert_allclose(func(*values)[0], reference)
    np.testing.assert_allclose(func(*values)[1], reference)
316
def test_dim_handling():
    """ Test the "dim", "dims" and "broadcastables" arguments of dim_handling(). """
    from_dim = dim_handling([x], dim=2)
    assert from_dim == {x: (False, False)}

    from_dims = dim_handling([x, y], dims={x: 1, y: 2})
    assert from_dims == {x: (False, True), y: (False, False)}

    from_patterns = dim_handling([x], broadcastables={x: (False,)})
    assert from_patterns == {x: (False,)}
322
def test_aesara_function_kwargs():
    """
    Test that extra keyword arguments given to aesara_function() reach
    aesara.function().
    """
    import numpy as np
    float_dtypes = {x: 'float64', y: 'float64', z: 'float64'}

    # z is an input that the output x+y never uses
    add = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
                           dtypes=float_dtypes)
    error = np.linalg.norm(add([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6]))
    assert error < 1e-9

    add = aesara_function_([x, y, z], [x+y], dtypes=float_dtypes,
                           dim=1, on_unused_input='ignore')
    first = np.arange(3).astype('float64')
    second = 2*np.arange(3).astype('float64')
    unused = 2*np.arange(3).astype('float64')
    error = np.linalg.norm(add(first, second, unused) - 3*np.arange(3))
    assert error < 1e-9
339
def test_aesara_function_scalar():
    """ Test the "scalar" argument to aesara_function(). """
    from aesara.compile.function.types import Function

    # Each case: (inputs, outputs, input dims, expected output dimensions)
    cases = [
        ([x, y], [x + y], None, [0]),  # Single 0d output
        ([X, Y], [X + Y], None, [2]),  # Single 2d output
        ([x, y], [x + y], {x: 0, y: 1}, [1]),  # Single 1d output
        ([x, y], [x + y, x - y], None, [0, 0]),  # Two 0d outputs
        ([x, y, X, Y], [x + y, X + Y], None, [0, 2]),  # One 0d output, one 2d
    ]

    # Build and exercise every case both with and without the scalar setting
    for inputs, outputs, in_dims, out_dims in cases:
        for scalar in (False, True):

            func = aesara_function_(inputs, outputs, dims=in_dims,
                                    scalar=scalar)

            # The aesara_function attribute must be set whether wrapped or not
            assert isinstance(func.aesara_function, Function)

            # Arrays of ones: length 1 along broadcastable dimensions, 5 otherwise
            in_values = []
            for storage in func.aesara_function.input_storage:
                shape = [1 if bc else 5 for bc in storage.type.broadcastable]
                in_values.append(np.ones(shape))

            returned = func(*in_values)
            out_values = returned if isinstance(returned, list) else [returned]

            # Check output types and shapes
            assert len(out_dims) == len(out_values)
            for ndim, value in zip(out_dims, out_values):

                if scalar and ndim == 0:
                    # Should have been converted to a scalar value
                    assert isinstance(value, np.number)
                    continue

                # Otherwise should be an array
                assert isinstance(value, np.ndarray)
                assert value.ndim == ndim
382
def test_aesara_function_bad_kwarg():
    """
    An unknown keyword argument given to aesara_function() should raise an
    exception.
    """
    with raises(Exception):
        aesara_function_([x], [x+1], foobar=3)
389
390
def test_slice():
    """ Test printing of builtin slice objects, with and without symbols. """
    assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3)

    def theq_slice(s1, s2):
        """ Compare two slices component-wise, using theq() where both are set. """
        for attr in ['start', 'stop', 'step']:
            a1 = getattr(s1, attr)
            a2 = getattr(s2, attr)
            if a1 is None or a2 is None:
                # A missing component only matches another missing component
                if not (a1 is None and a2 is None):
                    return False
            elif not theq(a1, a2):
                return False
        return True

    # The helper must tell a missing component from a present one
    assert not theq_slice(slice(1), slice(1, 2))

    dtypes = {x: 'int32', y: 'int32'}
    assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
    assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
408
def test_MatrixSlice():
    """ Test printing slices of a MatrixSymbol. """
    from aesara.graph.basic import Constant

    shared_cache = {}

    n = sy.Symbol('n', integer=True)
    X = sy.MatrixSymbol('X', n, n)

    Y = X[1:2:3, 4:5:6]
    Yt = aesara_code_(Y, cache=shared_cache)

    int_scalar = Scalar('int64')
    expected_slice = slice(int_scalar, int_scalar, int_scalar)
    assert tuple(Yt.owner.op.idx_list) == (expected_slice, expected_slice)
    assert Yt.owner.inputs[0] == aesara_code_(X, cache=shared_cache)
    # == doesn't work in Aesara like it does in SymPy. You have to use
    # equals.
    assert all(Yt.owner.inputs[i].equals(Constant(int_scalar, i))
               for i in range(1, 7))

    # A slice with a symbolic stop is only checked to print without error
    k = sy.Symbol('k')
    aesara_code_(k, dtypes={k: 'int32'})
    start, stop, step = 4, k, 2
    Y = X[start:stop:step]
    Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'})
    # assert Yt.owner.op.idx_list[0].stop == kt
433
def test_BlockMatrix():
    """ Test that a 2x2 BlockMatrix prints to nested joins. """
    n = sy.Symbol('n', integer=True)
    A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
    At, Bt, Ct, Dt = [aesara_code_(block) for block in (A, B, C, D)]
    printed = aesara_code_(sy.BlockMatrix([[A, B], [C, D]]))

    # Either joining order is accepted
    rows_then_stack = aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt))
    columns_then_stack = aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))
    candidates = [rows_then_stack, columns_then_stack]
    assert any(theq(printed, candidate) for candidate in candidates)
443
@SKIP
def test_BlockMatrix_Inverse_execution():
    """
    Compare B.I*A compiled directly with the same expression compiled after
    cutting the inputs into blocks and collapsing the result.
    """
    k, n = 2, 4
    dtype = 'float32'
    A = sy.MatrixSymbol('A', n, k)
    B = sy.MatrixSymbol('B', n, n)
    inputs = A, B
    output = B.I*A

    # Cut each input into a 2x2 grid of equally sized blocks
    cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
                B: [(n//2, n//2), (n//2, n//2)]}
    cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
    cutoutput = output.subs(dict(zip(inputs, cutinputs)))

    dtypes = dict(zip(inputs, [dtype]*len(inputs)))
    f = aesara_function_(inputs, [output], dtypes=dtypes, cache={})
    fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)],
                                dtypes=dtypes, cache={})

    # Fixed numeric inputs: A counts upwards, B is the identity plus a small
    # constant offset.
    ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
               np.eye(n).astype(dtype)]
    ninputs[1] += np.ones(B.shape)*1e-5

    assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
469
def test_DenseMatrix():
    """ Test that mutable and immutable dense matrices print to a Join. """
    from aesara.tensor.basic import Join

    theta = sy.Symbol('theta')
    rotation_rows = [[sy.cos(theta), -sy.sin(theta)],
                     [sy.sin(theta), sy.cos(theta)]]
    for matrix_class in (sy.Matrix, sy.ImmutableMatrix):
        printed = aesara_code_(matrix_class(rotation_rows))
        assert isinstance(printed, TensorVariable)
        assert isinstance(printed.owner.op, Join)
479
480
def test_cache_basic():
    """ Test single symbol-like objects are cached when printed by themselves. """

    # Each pair holds an object and a separately built equivalent that the
    # cache should treat as the same thing
    equivalents = [
        (x, sy.Symbol('x')),
        (X, sy.MatrixSymbol('X', *X.shape)),
        (f_t, sy.Function('f')(sy.Symbol('t'))),
    ]

    for original, rebuilt in equivalents:
        shared_cache = {}
        first = aesara_code_(original, cache=shared_cache)

        # Same instance, same cache: hit
        assert aesara_code_(original, cache=shared_cache) is first

        # Same instance, new cache: miss
        assert aesara_code_(original, cache={}) is not first

        # Equivalent instance, same cache: hit
        assert aesara_code_(rebuilt, cache=shared_cache) is first
503
def test_global_cache():
    """ Test use of the global cache. """
    from sympy.printing.aesaracode import global_cache

    backup = dict(global_cache)
    try:
        # Temporarily empty global cache
        global_cache.clear()

        for s in [x, X, f_t]:
            st = aesara_code(s)
            assert aesara_code(s) is st

    finally:
        # Restore the global cache exactly: drop the entries created above
        # first, otherwise they would survive alongside the backup.
        global_cache.clear()
        global_cache.update(backup)
520
def test_cache_types_distinct():
    """
    Test that the cache keeps symbol-like objects of different types (Symbol,
    MatrixSymbol, AppliedUndef) apart even when they share a name.
    """
    same_name = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]

    shared_cache = {}  # One cache used for every object
    printed = {}

    for sym in same_name:
        var = aesara_code_(sym, cache=shared_cache)
        assert var not in printed.values()
        printed[sym] = var

    # Every printed object must be a distinct instance
    distinct_ids = {id(var) for var in printed.values()}
    assert len(distinct_ids) == len(same_name)

    # Printing again with the same cache returns the stored instance
    for sym, var in printed.items():
        assert aesara_code_(sym, cache=shared_cache) is var
543
def test_symbols_are_created_once():
    """
    Test that a symbol appearing more than once in an expression is cached
    and reused.
    """
    doubled = sy.Add(x, x, evaluate=False)
    printed = aesara_code_(doubled)

    # Matches a sum that uses one variable twice ...
    assert theq(printed, xt + xt)
    # ... but not a sum of two separately created variables
    assert not theq(printed, xt + aesara_code_(x))
554
def test_cache_complex():
    """
    Test caching on a larger expression in which several symbols appear
    several times.
    """
    sympy_expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
    expected_names = {sym.name for sym in sympy_expr.free_symbols}
    printed = aesara_code_(sympy_expr)

    # Walk the variables in the Aesara graph that the printed expression
    # depends on
    found = set()
    constant_type = aesara.graph.basic.Constant
    for var in aesara.graph.basic.ancestors([printed]):
        # Only owner-less, non-constant variables should be our symbols
        if var.owner is not None or isinstance(var, constant_type):
            continue

        # Must correspond to a symbol and appear only once
        assert var.name in expected_names
        assert var.name not in found
        found.add(var.name)

    # Check all were present
    assert found == expected_names
577
578
def test_Piecewise():
    """ Test printing Piecewise expressions to nested switch operations. """
    # A piecewise linear function with a final catch-all branch
    piecewise = sy.Piecewise((0, x<0), (x, x<2), (1, True))  # ___/III
    printed = aesara_code_(piecewise)
    assert printed.owner.op == aet.switch

    reference = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1))
    assert theq(printed, reference)

    # Without a catch-all branch the remaining case is nan
    piecewise = sy.Piecewise((x, x < 0))
    printed = aesara_code_(piecewise)
    reference = aet.switch(xt < 0, xt, np.nan)
    assert theq(printed, reference)

    # Conditions built from And / Or
    piecewise = sy.Piecewise(
        (0, sy.And(x>0, x<2)),
        (x, sy.Or(x>2, x<0)),
    )
    printed = aesara_code_(piecewise)
    reference = aet.switch(
        aet.and_(xt>0,xt<2),
        0,
        aet.switch(aet.or_(xt>2, xt<0), xt, np.nan),
    )
    assert theq(printed, reference)
599
600
def test_Relationals():
    """ Test printing of equality and inequality relations. """
    assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt))
    # assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt))  # TODO - implement

    # The comparison operators are spelled the same on both sides
    comparisons = (
        lambda a, b: a > b,
        lambda a, b: a < b,
        lambda a, b: a >= b,
        lambda a, b: a <= b,
    )
    for compare in comparisons:
        assert theq(aesara_code_(compare(x, y)), compare(xt, yt))
608
609
def test_complexfunctions():
    """ Test printing of conjugation and of a complex coefficient. """
    from sympy import conjugate
    from aesara.tensor import as_tensor_variable as atv
    from aesara.tensor import complex as cplx

    # Use the fresh-cache wrapper so this test leaves the global cache alone
    xt = aesara_code_(x, dtypes={x: 'complex128'})
    yt = aesara_code_(y, dtypes={y: 'complex128'})
    assert theq(aesara_code_(y*conjugate(x)), yt*(xt.conj()))
    assert theq(aesara_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1)))
617
618
def test_constantfunctions():
    """ Test compiling a function with no inputs and a constant complex output. """
    # Use the fresh-cache wrapper so this test leaves the global cache alone
    tf = aesara_function_([], [1+1j])
    assert tf() == 1+1j
622