1from sympy import (
2    Piecewise, lambdify, Equality, Unequality, Sum, Mod, sqrt,
3    MatrixSymbol, BlockMatrix, Identity
4)
5from sympy import eye
6from sympy.abc import x, i, j, a, b, c, d
7from sympy.core import Pow
8from sympy.codegen.matrix_nodes import MatrixSolve
9from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
10from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
11from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
12    PermuteDims, ArrayDiagonal
13from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \
14    _numpy_known_functions, _scipy_known_constants, _scipy_known_functions
15from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
16
17from sympy.testing.pytest import warns_deprecated_sympy
18from sympy.testing.pytest import skip, raises
19from sympy.external import import_module
20
21np = import_module('numpy')
22
23
def test_numpy_piecewise_regression():
    """
    Check that NumPyPrinter emits the Piecewise choicelist as a plain list.

    Printing it as a list keeps compatibility with numpy 1.8; numpy 1.9+
    would not need this. See gh-9747 and gh-9749 for details.
    """
    pw = Piecewise((1, x < 0), (0, True))
    expected = 'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)'

    printer = NumPyPrinter()
    assert printer.doprint(pw) == expected
    # doprint records every numpy name it used
    assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}}
35
def test_numpy_logaddexp():
    """logaddexp and logaddexp2 print as the same-named numpy functions."""
    cases = [
        (logaddexp(a, b), 'numpy.logaddexp(a, b)'),
        (logaddexp2(a, b), 'numpy.logaddexp2(a, b)'),
    ]
    for expr, expected in cases:
        # a fresh printer per case, as in a one-shot doprint call
        assert NumPyPrinter().doprint(expr) == expected
41
42
def test_sum():
    """Lambdified single-index ``Sum`` objects match explicit Python sums."""
    if not np:
        skip("NumPy not installed")

    lo, hi = 0, 10
    grid = np.linspace(-1, +1, 10)
    terms = range(lo, hi + 1)  # inclusive upper bound, like Sum's limits

    power_sum = lambdify((a, b, x), Sum(x ** i, (i, a, b)), 'numpy')
    assert np.allclose(power_sum(lo, hi, grid), sum(grid ** k for k in terms))

    linear_sum = lambdify((a, b, x), Sum(i * x, (i, a, b)), 'numpy')
    assert np.allclose(linear_sum(lo, hi, grid), sum(k * grid for k in terms))
60
61
def test_multiple_sums():
    """A lambdified two-index ``Sum`` matches a nested Python summation."""
    if not np:
        skip("NumPy not installed")

    i_lo, i_hi = 0, 10
    j_lo, j_hi = 11, 21
    grid = np.linspace(-1, +1, 10)

    double_sum = lambdify((a, b, c, d, x),
                          Sum((x + j) * i, (i, a, b), (j, c, d)), 'numpy')

    # i is the outer loop and j the inner one, with inclusive upper bounds
    expected = sum((grid + jj) * ii
                   for ii in range(i_lo, i_hi + 1)
                   for jj in range(j_lo, j_hi + 1))
    assert np.allclose(double_sum(i_lo, i_hi, j_lo, j_hi, grid), expected)
74
75
def test_codegen_einsum():
    """A matrix product converted to an array expression lambdifies to matmul.

    The inputs are plain ``np.array`` objects. NumPy no longer recommends
    the ``np.matrix`` subclass, and constructing one emits a
    PendingDeprecationWarning.
    """
    if not np:
        skip("NumPy not installed")

    M = MatrixSymbol("M", 2, 2)
    N = MatrixSymbol("N", 2, 2)

    cg = convert_matrix_to_array(M * N)
    f = lambdify((M, N), cg, 'numpy')

    ma = np.array([[1, 2], [3, 4]])
    mb = np.array([[1, -2], [-1, 3]])
    # On ndarrays ``*`` is elementwise, so compare against an explicit matmul.
    assert (f(ma, mb) == ma @ mb).all()
89
90
def test_codegen_extra():
    """Array-expression nodes lambdify to the matching NumPy operations.

    Covers ArrayTensorProduct, ArrayAdd with 2-4 operands, PermuteDims and
    ArrayDiagonal. The inputs are plain ``np.array`` objects rather than the
    discouraged ``np.matrix`` subclass, whose construction emits a
    PendingDeprecationWarning.
    """
    if not np:
        skip("NumPy not installed")

    M = MatrixSymbol("M", 2, 2)
    N = MatrixSymbol("N", 2, 2)
    P = MatrixSymbol("P", 2, 2)
    Q = MatrixSymbol("Q", 2, 2)
    ma = np.array([[1, 2], [3, 4]])
    mb = np.array([[1, -2], [-1, 3]])
    mc = np.array([[2, 0], [1, 2]])
    md = np.array([[1, -1], [4, 7]])

    # Reference tensor product: the index labels of ma and mb are disjoint,
    # so einsum returns their rank-4 outer product.
    tensor = np.einsum(ma, [0, 1], mb, [2, 3])

    cg = ArrayTensorProduct(M, N)
    f = lambdify((M, N), cg, 'numpy')
    assert (f(ma, mb) == tensor).all()

    cg = ArrayAdd(M, N)
    f = lambdify((M, N), cg, 'numpy')
    assert (f(ma, mb) == ma+mb).all()

    cg = ArrayAdd(M, N, P)
    f = lambdify((M, N, P), cg, 'numpy')
    assert (f(ma, mb, mc) == ma+mb+mc).all()

    cg = ArrayAdd(M, N, P, Q)
    f = lambdify((M, N, P, Q), cg, 'numpy')
    assert (f(ma, mb, mc, md) == ma+mb+mc+md).all()

    cg = PermuteDims(M, [1, 0])
    f = lambdify((M,), cg, 'numpy')
    assert (f(ma) == ma.T).all()

    cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
    f = lambdify((M, N), cg, 'numpy')
    assert (f(ma, mb) == np.transpose(tensor, (1, 2, 3, 0))).all()

    cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
    f = lambdify((M, N), cg, 'numpy')
    assert (f(ma, mb) == np.diagonal(tensor, axis1=1, axis2=2)).all()
131
132
def test_relational():
    """Each relational operator lambdifies to an elementwise comparison."""
    if not np:
        skip("NumPy not installed")

    values = np.array([0, 1, 2])
    cases = [
        (Equality(x, 1), [False, True, False]),
        (Unequality(x, 1), [True, False, True]),
        (x < 1, [True, False, False]),
        (x <= 1, [True, True, False]),
        (x > 1, [False, False, True]),
        (x >= 1, [False, True, True]),
    ]
    for relation, expected in cases:
        compare = lambdify((x,), relation)
        assert np.array_equal(compare(values), expected)
172
173
def test_mod():
    """Lambdified ``Mod`` works with scalar and array divisors."""
    if not np:
        skip("NumPy not installed")

    remainder = lambdify((a, b), Mod(a, b))

    cases = (
        # (dividend, divisor, expected remainder)
        (np.array([0, 1, 2, 3]), 2, [0, 1, 0, 1]),
        (np.array([0, 1, 2, 3]), np.array([2, 2, 2, 2]), [0, 1, 0, 1]),
        (np.array([2, 3, 4, 5]), np.array([2, 3, 4, 5]), [0, 0, 0, 0]),
    )
    for dividend, divisor, expected in cases:
        assert np.array_equal(remainder(dividend, divisor), expected)
192
193
def test_pow():
    """An unevaluated ``Pow(2, -1)`` lambdifies to a callable returning 0.5."""
    if not np:
        skip('NumPy not installed')

    half = lambdify([], Pow(2, -1, evaluate=False), 'numpy')
    assert half() == 0.5
201
202
def test_expm1():
    """expm1 keeps the second-order term for a tiny argument."""
    if not np:
        skip("NumPy not installed")

    expm1_np = lambdify((a,), expm1(a), 'numpy')
    tiny = 1e-10
    # expm1(t) = t + t**2/2 + ..., so the residual should be close to 5e-21
    assert abs(expm1_np(tiny) - tiny - 5e-21) < 1e-22
209
210
def test_log1p():
    """log1p of a tiny argument returns (almost exactly) that argument."""
    if not np:
        skip("NumPy not installed")

    log1p_np = lambdify((a,), log1p(a), 'numpy')
    tiny = 1e-99
    assert abs(log1p_np(tiny) - tiny) < 1e-100
217
def test_hypot():
    """hypot(3, 4) evaluates to 5 through the numpy backend."""
    if not np:
        skip("NumPy not installed")
    hypot_np = lambdify((a, b), hypot(a, b), 'numpy')
    assert abs(hypot_np(3, 4) - 5) < 1e-16
222
def test_log10():
    """log10(100) evaluates to 2 through the numpy backend."""
    if not np:
        skip("NumPy not installed")
    log10_np = lambdify((a,), log10(a), 'numpy')
    assert abs(log10_np(100) - 2) < 1e-16
227
228
def test_exp2():
    """exp2(5) evaluates to 32 through the numpy backend."""
    if not np:
        skip("NumPy not installed")
    exp2_np = lambdify((a,), exp2(a), 'numpy')
    assert abs(exp2_np(5) - 32) < 1e-16
233
234
def test_log2():
    """log2(256) evaluates to 8 through the numpy backend."""
    if not np:
        skip("NumPy not installed")
    log2_np = lambdify((a,), log2(a), 'numpy')
    assert abs(log2_np(256) - 8) < 1e-16
239
240
def test_Sqrt():
    """The codegen ``Sqrt`` node evaluates sqrt(4) to 2 via numpy."""
    if not np:
        skip("NumPy not installed")
    root = lambdify((a,), Sqrt(a), 'numpy')
    assert abs(root(4) - 2) < 1e-16
245
246
def test_sqrt():
    """The core ``sqrt`` function evaluates sqrt(4) to 2 via numpy."""
    if not np:
        skip("NumPy not installed")
    root = lambdify((a,), sqrt(a), 'numpy')
    assert abs(root(4) - 2) < 1e-16
251
252
def test_matsolve():
    """``MatrixSolve(M, x) + x`` lambdifies to the same values as ``M**-1*x + x``.

    ``x`` is declared as a 3x1 MatrixSymbol, so the numeric right-hand side
    is a (3, 1) column vector that matches the declared shape, not a flat
    (3,) array.
    """
    if not np:
        skip("NumPy not installed")

    M = MatrixSymbol("M", 3, 3)
    x = MatrixSymbol("x", 3, 1)  # shadows sympy.abc.x within this test

    expr = M**(-1) * x + x
    matsolve_expr = MatrixSolve(M, x) + x

    f = lambdify((M, x), expr)
    f_matsolve = lambdify((M, x), matsolve_expr)

    m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
    # M must be full rank for both the inverse and the solve to be defined.
    assert np.linalg.matrix_rank(m0) == 3

    x0 = np.array([[3], [4], [5]])

    assert np.allclose(f_matsolve(m0, x0), f(m0, x0))
272
273
def test_issue_15601():
    """Numpy-lambdified MatMul fed SymPy matrices emits a deprecation warning."""
    if not np:
        skip("Numpy not installed")

    lhs = MatrixSymbol("M", 3, 3)
    rhs = MatrixSymbol("N", 3, 3)
    matmul = lambdify((lhs, rhs), lhs*rhs, "numpy")
    expected = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])

    # Both the call and the comparison stay inside the warning context.
    with warns_deprecated_sympy():
        result = matmul(eye(3), eye(3))
        assert np.array_equal(result, expected)
286
def test_16857():
    """A BlockMatrix prints as ``numpy.block`` with its nested block layout.

    The test only exercises NumPyPrinter and never imports or calls NumPy,
    so it is not skipped when NumPy is missing (the same holds for the
    other printer-only tests in this module).
    """
    a_1 = MatrixSymbol('a_1', 10, 3)
    a_2 = MatrixSymbol('a_2', 10, 3)
    a_3 = MatrixSymbol('a_3', 10, 3)
    a_4 = MatrixSymbol('a_4', 10, 3)
    A = BlockMatrix([[a_1, a_2], [a_3, a_4]])
    assert A.shape == (20, 6)

    printer = NumPyPrinter()
    assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])'
300
301
def test_issue_17006():
    """``M + Identity`` lambdifies for concrete sizes; symbolic sizes raise."""
    if not np:
        skip("NumPy not installed")

    M = MatrixSymbol("M", 2, 2)
    add_identity = lambdify(M, M + Identity(2))
    given = np.array([[1, 2], [3, 4]])
    wanted = np.array([[2, 2], [3, 5]])
    assert (add_identity(given) == wanted).all()

    from sympy import symbols
    size = symbols('n', integer=True)
    symbolic_M = MatrixSymbol("M", size, size)
    # Identity of symbolic size has no numeric counterpart.
    raises(NotImplementedError,
           lambda: lambdify(symbolic_M, symbolic_M + Identity(size)))
318
def test_numpy_known_funcs_consts():
    """Spot-check the SymPy-name -> numpy-name translation tables."""
    expected_constants = {'NaN': 'numpy.nan', 'EulerGamma': 'numpy.euler_gamma'}
    for name, target in expected_constants.items():
        assert _numpy_known_constants[name] == target

    expected_functions = {'acos': 'numpy.arccos', 'log': 'numpy.log'}
    for name, target in expected_functions.items():
        assert _numpy_known_functions[name] == target
325
def test_scipy_known_funcs_consts():
    """Spot-check the SymPy-name -> scipy-name translation tables."""
    expected_constants = {
        'GoldenRatio': 'scipy.constants.golden_ratio',
        'Pi': 'scipy.constants.pi',
    }
    for name, target in expected_constants.items():
        assert _scipy_known_constants[name] == target

    expected_functions = {
        'erf': 'scipy.special.erf',
        'factorial': 'scipy.special.factorial',
    }
    for name, target in expected_functions.items():
        assert _scipy_known_functions[name] == target
332
def test_numpy_print_methods():
    """NumPyPrinter exposes generated ``_print_*`` methods for known functions."""
    printer = NumPyPrinter()
    for method in ('_print_acos', '_print_log'):
        assert hasattr(printer, method)
337
def test_scipy_print_methods():
    """SciPyPrinter has numpy's print methods plus scipy-specific ones."""
    printer = SciPyPrinter()
    required = (
        '_print_acos',
        '_print_log',
        '_print_erf',
        '_print_factorial',
        '_print_chebyshevt',
    )
    for method in required:
        assert hasattr(printer, method)
345