1from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio,
2                        EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le,
3                        Lt, Gt, Ge)
4from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
5                             sinh, cosh, tanh, asin, acos, acosh, Max, Min)
6from sympy.testing.pytest import raises
7from sympy.printing.jscode import JavascriptCodePrinter
8from sympy.utilities.lambdify import implemented_function
9from sympy.tensor import IndexedBase, Idx
10from sympy.matrices import Matrix, MatrixSymbol
11
12from sympy import jscode
13
# Module-wide commutative symbols reused by most of the tests below.
(x, y, z) = symbols('x,y,z')
15
16
def test_printmethod():
    """Abs is emitted through the JavaScript Math object."""
    expected = "Math.abs(x)"
    assert jscode(Abs(x)) == expected
19
20
def test_jscode_sqrt():
    """Square and cube roots map onto Math.sqrt and Math.cbrt."""
    for expr, expected in [
        (sqrt(x), "Math.sqrt(x)"),
        (x**0.5, "Math.sqrt(x)"),
        (x**(S.One/3), "Math.cbrt(x)"),
    ]:
        assert jscode(expr) == expected
25
26
def test_jscode_Pow():
    """Powers print via Math.pow; x**-1.0 is printed as a plain reciprocal."""
    g = implemented_function('g', Lambda(x, 2*x))
    assert jscode(x**3) == "Math.pow(x, 3)"
    assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))"

    # The implemented function g is inlined as 2*x inside the power base.
    nested = 1/(g(x)*3.5)**(x - y**x)/(x**2 + y)
    nested_expected = "Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)"
    assert jscode(nested) == nested_expected

    assert jscode(x**-1.0) == '1/x'
34
35
def test_jscode_constants_mathh():
    """Constants with a JavaScript builtin are printed by that builtin's name."""
    cases = [
        (exp(1), "Math.E"),
        (pi, "Math.PI"),
        (oo, "Number.POSITIVE_INFINITY"),
        (-oo, "Number.NEGATIVE_INFINITY"),
    ]
    for expr, expected in cases:
        assert jscode(expr) == expected
41
42
43def test_jscode_constants_other():
44    assert jscode(
45        2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
46    assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
47    assert jscode(
48        2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
49
50
def test_jscode_Rational():
    """Rationals print as num/den after SymPy normalises sign and common factors."""
    for (num, den), expected in [
        ((3, 7), "3/7"),
        ((18, 9), "2"),
        ((3, -7), "-3/7"),
        ((-3, -7), "3/7"),
    ]:
        assert jscode(Rational(num, den)) == expected
56
57
def test_Relational():
    """Each relational class maps onto the matching JavaScript operator."""
    cases = [
        (Eq(x, y), "x == y"),
        (Ne(x, y), "x != y"),
        (Le(x, y), "x <= y"),
        (Lt(x, y), "x < y"),
        (Gt(x, y), "x > y"),
        (Ge(x, y), "x >= y"),
    ]
    for relation, expected in cases:
        assert jscode(relation) == expected
65
66
67
def test_jscode_Integer():
    """Integers, including negative ones, print as bare literals."""
    for value, expected in ((67, "67"), (-1, "-1")):
        assert jscode(Integer(value)) == expected
71
72
def test_jscode_functions():
    """Elementary functions, Max and Min map onto Math.* calls."""
    cases = [
        (sin(x) ** cos(x), "Math.pow(Math.sin(x), Math.cos(x))"),
        (sinh(x) * cosh(x), "Math.sinh(x)*Math.cosh(x)"),
        (Max(x, y) + Min(x, y), "Math.max(x, y) + Math.min(x, y)"),
        (tanh(x)*acosh(y), "Math.tanh(x)*Math.acosh(y)"),
        # Terms come out in SymPy's canonical Add ordering.
        (asin(x)-acos(y), "-Math.acos(y) + Math.asin(x)"),
    ]
    for expr, expected in cases:
        assert jscode(expr) == expected
79
80
def test_jscode_inline_function():
    """The bodies of implemented_function calls are inlined at the call site."""
    x = symbols('x')
    double = implemented_function('g', Lambda(x, 2*x))
    assert jscode(double(x)) == "2*x"

    # A constant inside the inlined body is declared as a `var` up front.
    scaled = implemented_function('g', Lambda(x, 2*x/Catalan))
    expected = "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
    assert jscode(scaled(x)) == expected

    # Inlining also applies to Indexed arguments inside a generated loop.
    A = IndexedBase('A')
    i = Idx('i', symbols('n', integer=True))
    cubic = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
    expected_loop = (
        "for (var i=0; i<n; i++){\n"
        "   A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
        "}"
    )
    assert jscode(cubic(A[i]), assign_to=A[i]) == expected_loop
95
96
def test_jscode_exceptions():
    """ceiling and Abs map onto their Math counterparts."""
    cases = ((ceiling(x), "Math.ceil(x)"), (Abs(x), "Math.abs(x)"))
    for expr, expected in cases:
        assert jscode(expr) == expected
100
101
def test_jscode_boolean():
    """Logical operators print as &&, || and !, parenthesised only when needed."""
    cases = [
        (x & y, "x && y"),
        (x | y, "x || y"),
        (~x, "!x"),
        (x & y & z, "x && y && z"),
        (x | y | z, "x || y || z"),
        ((x & y) | z, "z || x && y"),
        ((x | y) & z, "z && (x || y)"),
    ]
    for expr, expected in cases:
        assert jscode(expr) == expected
110
111
def test_jscode_Piecewise():
    """Piecewise prints as a ternary chain, or as if/else when assigned."""
    expr = Piecewise((x, x < 1), (x**2, True))

    ternary = (
        "((x < 1) ? (\n"
        "   x\n"
        ")\n"
        ": (\n"
        "   Math.pow(x, 2)\n"
        "))"
    )
    assert jscode(expr) == ternary

    if_else = (
        "if (x < 1) {\n"
        "   c = x;\n"
        "}\n"
        "else {\n"
        "   c = Math.pow(x, 2);\n"
        "}"
    )
    assert jscode(expr, assign_to="c") == if_else

    # A Piecewise without a True (default) condition cannot be printed and
    # raises ValueError.
    no_default = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
    raises(ValueError, lambda: jscode(no_default))
135
136
def test_jscode_Piecewise_deep():
    """A Piecewise nested inside a product keeps its ternary parenthesised."""
    expected = (
        "2*((x < 1) ? (\n"
        "   x\n"
        ")\n"
        ": (\n"
        "   Math.pow(x, 2)\n"
        "))"
    )
    printed = jscode(2*Piecewise((x, x < 1), (x**2, True)))
    assert printed == expected
149
150
def test_jscode_settings():
    """An unknown printer setting is rejected with TypeError."""
    def print_with_bogus_setting():
        return jscode(sin(x), method="garbage")

    raises(TypeError, print_with_bogus_setting)
153
154
def test_jscode_Indexed():
    """Multi-dimensional Indexed objects print as flattened row-major offsets.

    IndexedBase, Idx and symbols come from the module-level imports; the
    previous function-local re-imports were redundant.
    """
    n, m, o = symbols('n m o', integer=True)
    i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
    p = JavascriptCodePrinter()
    p._not_c = set()

    x = IndexedBase('x')[j]
    assert p._print_Indexed(x) == 'x[j]'
    A = IndexedBase('A')[i, j]
    assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
    B = IndexedBase('B')[i, j, k]
    assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)

    # Printing the indices must leave the printer's `_not_c` set empty.
    # NOTE(review): `_not_c` presumably tracks unsupported expressions; confirm.
    assert p._not_c == set()
171
172
def test_jscode_loops_matrix_vector():
    """A matrix-vector product assigned to y[i] becomes nested loops.

    The expected code zero-initialises y, then accumulates the contraction
    over the repeated index j with A read at a flattened offset.
    """
    n, m = symbols('n m', integer=True)
    A, x, y = IndexedBase('A'), IndexedBase('x'), IndexedBase('y')
    i, j = Idx('i', m), Idx('j', n)

    expected = (
        'for (var i=0; i<m; i++){\n'
        '   y[i] = 0;\n'
        '}\n'
        'for (var i=0; i<m; i++){\n'
        '   for (var j=0; j<n; j++){\n'
        '      y[i] = A[n*i + j]*x[j] + y[i];\n'
        '   }\n'
        '}'
    )
    code = jscode(A[i, j]*x[j], assign_to=y[i])
    assert code == expected
193
194
def test_dummy_loops():
    """Dummy loop labels and bounds print with a `_<dummy_index>` suffix."""
    label, bound = symbols('i m', integer=True, cls=Dummy)
    x = IndexedBase('x')
    y = IndexedBase('y')
    i = Idx(label, bound)

    suffixes = {'icount': i.label.dummy_index, 'mcount': bound.dummy_index}
    expected = (
        'for (var i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
        '   y[i_%(icount)i] = x[i_%(icount)i];\n'
        '}'
    ) % suffixes
    assert jscode(x[i], assign_to=y[i]) == expected
208
209
def test_jscode_loops_add():
    """Free-index terms and a contraction assigned to y[i] print as two loops.

    The expected code assigns the free terms x[i] + z[i] in a first loop,
    then accumulates the contraction over j in a nested loop. The redundant
    function-local re-imports of IndexedBase, Idx and symbols were dropped,
    because the module already imports them.
    """
    n, m = symbols('n m', integer=True)
    A = IndexedBase('A')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')
    i = Idx('i', m)
    j = Idx('j', n)

    expected = (
        'for (var i=0; i<m; i++){\n'
        '   y[i] = x[i] + z[i];\n'
        '}\n'
        'for (var i=0; i<m; i++){\n'
        '   for (var j=0; j<n; j++){\n'
        '      y[i] = A[n*i + j]*x[j] + y[i];\n'
        '   }\n'
        '}'
    )
    code = jscode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
    assert code == expected
233
234
def test_jscode_loops_multiple_contractions():
    """Contracting over several repeated indices nests one loop per index.

    The expected code zero-initialises y first. Inside the loops, a and b
    are read at the row-major offsets computed below from the index ranges.
    """
    n, m, o, p = symbols('n m o p', integer=True)
    a = IndexedBase('a')
    b = IndexedBase('b')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)

    a_offset = i*n*o*p + j*o*p + k*p + l
    b_offset = j*o*p + k*p + l
    expected = (
        'for (var i=0; i<m; i++){\n'
        '   y[i] = 0;\n'
        '}\n'
        'for (var i=0; i<m; i++){\n'
        '   for (var j=0; j<n; j++){\n'
        '      for (var k=0; k<o; k++){\n'
        '         for (var l=0; l<p; l++){\n'
        '            y[i] = a[%s]*b[%s] + y[i];\n' % (a_offset, b_offset) +
        '         }\n'
        '      }\n'
        '   }\n'
        '}'
    )
    code = jscode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
    assert code == expected
263
264
def test_jscode_loops_addfactor():
    """A sum factor inside a contraction keeps its parentheses in the loop body.

    The printed result is stored in `code` so that `c` keeps referring to
    the IndexedBase and is not silently rebound to a string.
    """
    n, m, o, p = symbols('n m o p', integer=True)
    a = IndexedBase('a')
    b = IndexedBase('b')
    c = IndexedBase('c')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)

    # a and b are rank 4, c is rank 3; both use row-major flattened offsets.
    offset4 = i*n*o*p + j*o*p + k*p + l
    offset3 = j*o*p + k*p + l
    expected = (
        'for (var i=0; i<m; i++){\n'
        '   y[i] = 0;\n'
        '}\n'
        'for (var i=0; i<m; i++){\n'
        '   for (var j=0; j<n; j++){\n'
        '      for (var k=0; k<o; k++){\n'
        '         for (var l=0; l<p; l++){\n'
        '            y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (offset4, offset4, offset3) +
        '         }\n'
        '      }\n'
        '   }\n'
        '}'
    )
    code = jscode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
    assert code == expected
294
295
def test_jscode_loops_multiple_terms():
    """Each contraction term gets its own loop nest after y is zero-initialised.

    The test accepts every ordering of the three term loop nests. The printed
    result is stored in `code` so that `c` keeps referring to the IndexedBase.
    """
    from itertools import permutations

    n, m, o = symbols('n m o', integer=True)
    a = IndexedBase('a')
    b = IndexedBase('b')
    c = IndexedBase('c')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)

    s0 = (
        'for (var i=0; i<m; i++){\n'
        '   y[i] = 0;\n'
        '}\n'
    )
    s1 = (
        'for (var i=0; i<m; i++){\n'
        '   for (var j=0; j<n; j++){\n'
        '      for (var k=0; k<o; k++){\n'
        '         y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +
        '      }\n'
        '   }\n'
        '}\n'
    )
    s2 = (
        'for (var i=0; i<m; i++){\n'
        '   for (var k=0; k<o; k++){\n'
        '      y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +
        '   }\n'
        '}\n'
    )
    s3 = (
        'for (var i=0; i<m; i++){\n'
        '   for (var j=0; j<n; j++){\n'
        '      y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +
        '   }\n'
        '}\n'
    )
    code = jscode(
        b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
    # The last block in the expected output carries no trailing newline.
    assert any(code == s0 + first + second + last[:-1]
               for first, second, last in permutations((s1, s2, s3)))
344
345
def test_Matrix_printing():
    """Matrices and MatrixElements print as flat JavaScript array accesses."""
    # A column Matrix assigned to a MatrixSymbol: one statement per element,
    # with the Piecewise entry expanded into if/else.
    column = Matrix([x*y, Piecewise((2 + x, y > 0), (y, True)), sin(z)])
    A = MatrixSymbol('A', 3, 1)
    expected_column = (
        "A[0] = x*y;\n"
        "if (y > 0) {\n"
        "   A[1] = x + 2;\n"
        "}\n"
        "else {\n"
        "   A[1] = y;\n"
        "}\n"
        "A[2] = Math.sin(z);"
    )
    assert jscode(column, A) == expected_column

    # MatrixElements used inside an ordinary expression.
    expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
    expected_expr = (
        "((x > 0) ? (\n"
        "   2*A[2]\n"
        ")\n"
        ": (\n"
        "   A[2]\n"
        ")) + Math.sin(A[1]) + A[0]"
    )
    assert jscode(expr) == expected_expr

    # MatrixElements inside a Matrix assigned to a 3x3 MatrixSymbol; the
    # target is flattened in row-major order.
    q = MatrixSymbol('q', 5, 1)
    M = MatrixSymbol('M', 3, 3)
    grid = Matrix([[sin(q[1, 0]), 0, cos(q[2, 0])],
                   [q[1, 0] + q[2, 0], q[3, 0], 5],
                   [2*q[4, 0]/q[1, 0], sqrt(q[0, 0]) + 4, 0]])
    expected_grid = (
        "M[0] = Math.sin(q[1]);\n"
        "M[1] = 0;\n"
        "M[2] = Math.cos(q[2]);\n"
        "M[3] = q[1] + q[2];\n"
        "M[4] = q[3];\n"
        "M[5] = 5;\n"
        "M[6] = 2*q[4]/q[1];\n"
        "M[7] = Math.sqrt(q[0]) + 4;\n"
        "M[8] = 0;"
    )
    assert jscode(grid, M) == expected_grid
384
385
def test_MatrixElement_printing():
    """MatrixElements print with a flat index (regression for issue #11821)."""
    A, B, C = (MatrixSymbol(name, 1, 3) for name in ("A", "B", "C"))

    assert jscode(A[0, 0]) == "A[0]"
    assert jscode(3 * A[0, 0]) == "3*A[0]"

    # An element taken from a matrix expression keeps that expression
    # parenthesised before the index.
    F = C[0, 0].subs(C, A - B)
    assert jscode(F) == "(A - B)[0]"
397