1from sympy import symbols, sin, Matrix, Interval, Piecewise, Sum, lambdify, \
2                  Expr, sqrt
3from sympy.testing.pytest import raises
4
5from sympy.printing.tensorflow import TensorflowPrinter
6from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter
7
8
# Module-level symbols shared by the tests in this file.  In the Sum tests,
# i and j are the summation indices, with (a, b) and (c, d) as their bounds.
(x, y, z), (i, a, b), (j, c, d) = map(symbols, ("x,y,z", "i,a,b", "j,c,d"))
12
13
def test_basic():
    """Check lambdarepr() on a product, a sum and a power of two symbols."""
    product_code = lambdarepr(x*y)
    assert product_code == "x*y"

    # Either ordering of the two terms is accepted for the sum.
    sum_code = lambdarepr(x + y)
    assert sum_code in ["y + x", "x + y"]

    power_code = lambdarepr(x**y)
    assert power_code == "x**y"
18
19
def test_matrix():
    """Check that Matrix entries are rendered with the lambda printer."""
    # Test printing a Matrix that has an element that is printed differently
    # with the LambdaPrinter than with the StrPrinter.
    p = Piecewise((x, True), evaluate=False)
    A = Matrix([p])
    assert lambdarepr(A) == "ImmutableDenseMatrix([[((x))]])"
28
29
def test_piecewise():
    """Check the code that lambdarepr() generates for Piecewise expressions.

    Besides being compared with the expected text, every generated string is
    passed to eval() as the body of a lambda: a wrong number of parentheses
    would show up there as a SyntaxError.
    """
    prefix = "lambda x: "

    def check(expr, expected):
        # eval() only has to build the lambda; the result is discarded.
        code = lambdarepr(expr)
        eval(prefix + code)
        assert code == expected

    # A single unevaluated branch whose condition is True.
    check(Piecewise((x, True), evaluate=False), "((x))")

    # A single conditional branch without a default.
    check(Piecewise((x, x < 0)), "((x) if (x < 0) else None)")

    # Two conditional branches followed by a default.
    check(
        Piecewise((1, x < 1), (2, x < 2), (0, True)),
        "((1) if (x < 1) else (2) if (x < 2) else (0))",
    )

    # Two conditional branches without a default.
    check(
        Piecewise((1, x < 1), (2, x < 2)),
        "((1) if (x < 1) else (2) if (x < 2) else None)",
    )

    # A condition built from Interval.contains().
    check(
        Piecewise(
            (x, x < 1),
            (x**2, Interval(3, 4, True, False).contains(x)),
            (0, True),
        ),
        "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))",
    )

    # Unevaluated three-way split, with a default ...
    check(
        Piecewise(
            (x**2, x < 0),
            (x, x < 1),
            (2 - x, x >= 1),
            (0, True), evaluate=False
        ),
        ("((x**2) if (x < 0) else (x) if (x < 1)"
         " else (2 - x) if (x >= 1) else (0))"),
    )

    # ... and without one.
    check(
        Piecewise(
            (x**2, x < 0),
            (x, x < 1),
            (2 - x, x >= 1), evaluate=False
        ),
        ("((x**2) if (x < 0) else (x) if (x < 1)"
         " else (2 - x) if (x >= 1) else None)"),
    )

    # Five-condition chains, one for each of >=, <=, > and <.
    check(
        Piecewise(
            (1, x >= 1),
            (2, x >= 2),
            (3, x >= 3),
            (4, x >= 4),
            (5, x >= 5),
            (6, True)
        ),
        ("((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"
         " else (4) if (x >= 4) else (5) if (x >= 5) else (6))"),
    )

    check(
        Piecewise(
            (1, x <= 1),
            (2, x <= 2),
            (3, x <= 3),
            (4, x <= 4),
            (5, x <= 5),
            (6, True)
        ),
        ("((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"
         " else (4) if (x <= 4) else (5) if (x <= 5) else (6))"),
    )

    check(
        Piecewise(
            (1, x > 1),
            (2, x > 2),
            (3, x > 3),
            (4, x > 4),
            (5, x > 5),
            (6, True)
        ),
        ("((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"
         " else (4) if (x > 4) else (5) if (x > 5) else (6))"),
    )

    check(
        Piecewise(
            (1, x < 1),
            (2, x < 2),
            (3, x < 3),
            (4, x < 4),
            (5, x < 5),
            (6, True)
        ),
        ("((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"
         " else (4) if (x < 4) else (5) if (x < 5) else (6))"),
    )

    # A Piecewise used as the value of another Piecewise's branch.
    inner = Piecewise((1, x > 0), (2, True))
    check(
        Piecewise((inner, y > 0), (3, True)),
        "((((1) if (x > 0) else (2))) if (y > 0) else (3))",
    )
155
156
def test_sum__1():
    """Check the generated code and numeric value of a sum of powers."""
    expr = Sum(x ** i, (i, a, b))
    code = lambdarepr(expr)
    assert code == "(builtins.sum(x**i for i in range(a, b+1)))"

    # The lambdified function must give the same result as substituting the
    # values into the symbolic sum and evaluating it.
    params = x, a, b
    numeric = lambdify(params, expr)
    point = 2, 3, 8
    assert numeric(*point) == expr.subs(zip(params, point)).doit()
168
def test_sum__2():
    """Check the generated code and numeric value of a sum of i*x terms."""
    expr = Sum(i * x, (i, a, b))
    code = lambdarepr(expr)
    assert code == "(builtins.sum(i*x for i in range(a, b+1)))"

    # The lambdified function must give the same result as substituting the
    # values into the symbolic sum and evaluating it.
    params = x, a, b
    numeric = lambdify(params, expr)
    point = 2, 3, 8
    assert numeric(*point) == expr.subs(zip(params, point)).doit()
178
179
def test_multiple_sums():
    """Check code generation and evaluation of a Sum over two indices."""
    double_sum = Sum(i * x + j, (i, a, b), (j, c, d))

    code = lambdarepr(double_sum)
    expected_code = (
        "(builtins.sum(i*x + j for i in range(a, b+1) for j in range(c, d+1)))"
    )
    assert code == expected_code

    # Compare the lambdified function with the symbolic sum at one point.
    params = x, a, b, c, d
    numeric = lambdify(params, double_sum)
    point = 2, 3, 4, 5, 6
    reference = double_sum.subs(zip(params, point)).doit()
    result = numeric(*point)
    assert result == reference
192
193
def test_sqrt():
    """Check how LambdaPrinter._print_Pow renders a square root."""
    root = sqrt(x)

    # 'python2' standard: function form, or a power with float exponent.
    py2_printer = LambdaPrinter(dict(standard='python2'))
    assert py2_printer._print_Pow(root, rational=False) == 'sqrt(x)'
    assert py2_printer._print_Pow(root, rational=True) == 'x**(1./2.)'

    # 'python3' standard: the rational exponent is written as 1/2.
    py3_printer = LambdaPrinter(dict(standard='python3'))
    assert py3_printer._print_Pow(root, rational=True) == 'x**(1/2)'
200
201
def test_settings():
    """lambdarepr() called with method="garbage" must raise TypeError."""
    with raises(TypeError):
        lambdarepr(sin(x), method="garbage")
204
205
def test_numexpr():
    """Check the NumExprPrinter output for an unevaluated ITE."""
    from sympy.logic.boolalg import ITE

    # The ITE is expected to come out as a numexpr ``where`` call
    # (original note: "test ITE rewrite as Piecewise").
    ite = ITE(x > 0, True, False, evaluate=False)
    generated = NumExprPrinter().doprint(ite)
    expected = "evaluate('where((x > 0), True, False)', truediv=True)"
    assert generated == expected
212
213
class CustomPrintedObject(Expr):
    """Expr subclass whose printing hooks each return a fixed marker string.

    Every hook ignores the printer it receives and returns a distinct
    literal, so a test can tell from the output which hook was called.
    """

    def _lambdacode(self, printer):
        return 'lambda'

    def _numexprcode(self, printer):
        return 'numexpr'

    def _tensorflowcode(self, printer):
        return 'tensorflow'

    # NOTE(review): the two hooks below are presumably for the NumPy and
    # mpmath printers; no test in this part of the file exercises them.
    def _numpycode(self, printer):
        return 'numpy'

    def _mpmathcode(self, printer):
        return 'mpmath'
229
230
def test_printmethod():
    """Check that each printer uses the object's own printing hook."""
    custom = CustomPrintedObject()

    # Printer class paired with the text its doprint() must return.
    expected_by_printer = (
        (LambdaPrinter, 'lambda'),
        (TensorflowPrinter, 'tensorflow'),
        (NumExprPrinter, "evaluate('numexpr', truediv=True)"),
    )
    for printer_cls, expected in expected_by_printer:
        assert printer_cls().doprint(custom) == expected

    # A two-branch Piecewise printed by NumExprPrinter.
    two_branch = Piecewise((y, x >= 0), (z, x < 0))
    numexpr_code = NumExprPrinter().doprint(two_branch)
    assert numexpr_code == "evaluate('where((x >= 0), y, z)', truediv=True)"
242