1import itertools
2import math
3
4import mpmath
5import pytest
6
7import diofant
8from diofant import (ITE, And, Float, Function, I, Integral, Lambda, Matrix,
9                     Max, Min, Not, Or, Piecewise, Rational, cos, exp, false,
10                     lambdify, oo, pi, sin, sqrt, symbols, true)
11from diofant.abc import t, w, x, y, z
12from diofant.external import import_module
13from diofant.printing.lambdarepr import LambdaPrinter
14from diofant.utilities.decorator import conserve_mpmath_dps
15from diofant.utilities.lambdify import (MATH_TRANSLATIONS, MPMATH_TRANSLATIONS,
16                                        NUMPY_TRANSLATIONS, _get_namespace,
17                                        implemented_function, lambdastr)
18
19
__all__ = ()

# Module-level alias for ``Matrix``.
# NOTE(review): presumably referenced by name from generated code or reprs;
# confirm it is still needed before removing.
MutableDenseMatrix = Matrix

# ``import_module`` gives None when numpy is unavailable; numpy-only tests are
# decorated with ``with_numpy`` so they are skipped rather than failing.
numpy = import_module('numpy')
_NUMPY_MISSING = numpy is None
with_numpy = pytest.mark.skipif(_NUMPY_MISSING, reason="Couldn't import numpy.")
27
28# ================= Test different arguments =======================
29
30
def test_no_args():
    """A zero-argument lambda returns its constant and rejects extra args."""
    constant = lambdify([], 1)
    with pytest.raises(TypeError):
        constant(-1)
    assert constant() == 1
35
36
def test_single_arg():
    """A single Symbol may be passed directly as the argument spec."""
    double = lambdify(x, 2*x)
    assert double(1) == 2
40
41
def test_list_args():
    """A list of Symbols becomes the positional parameters, in order."""
    add = lambdify([x, y], x + y)
    assert add(1, 2) == 3
45
46
def test_nested_args():
    """Nested argument tuples are unpacked (issue sympy/sympy#2790)."""
    outer_sum = lambdify((x, (y, z)), x + y)
    assert outer_sum(1, (2, 4)) == 3
    deep_sum = lambdify((x, (y, (w, z))), w + x + y + z)
    assert deep_sum(1, (2, (3, 4))) == 10
    increment = lambdify(x, x + 1, dummify=False)
    assert increment(1) == 2
52
53
def test_str_args():
    """Arguments and expression may both be given as strings."""
    reverse = lambdify('x,y,z', 'z,y,x')
    assert reverse(3, 2, 1) == (1, 2, 3)
    assert reverse(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
    # The generated function must enforce its arity.
    with pytest.raises(TypeError):
        reverse(0)
60
61
def test_own_namespace():
    """A user-supplied dict overrides the translation of ``sin``."""
    namespace = {'sin': lambda _arg: 1}
    f = lambdify(x, sin(x), namespace)
    for arg in (0.1, 100):
        assert f(arg) == 1
68
69
def test_own_module():
    """A module object (here ``math``) is accepted as the namespace."""
    assert lambdify(x, sin(x), math)(0) == 0.0
73
74
def test_bad_args():
    """Invalid ``lambdify`` calls raise the expected exception types."""
    bad_calls = (
        (TypeError, lambda: lambdify(1)),  # no vargs given
        (TypeError, lambda: lambdify([1, 2])),  # same with vector exprs
        (ValueError, lambda: lambdify((('__flatten_args__',),), 1)),  # reserved name
        (NameError, lambda: lambdify(x, 1, 'spam')),  # 'spam' is not a known module
    )
    for expected_error, call in bad_calls:
        with pytest.raises(expected_error):
            call()
84
85
def test__get_namespace():
    """``_get_namespace`` rejects a plain integer."""
    with pytest.raises(TypeError):
        _get_namespace(1)
88
89
def test_lambdastr():
    """``lambdastr`` renders source text, also for a ``None`` expression."""
    assert lambdastr(x, x**2) == 'lambda x: (x**2)'
    dummified = lambdastr(x, None, dummify=True)
    assert dummified.find('None') > 0
93
94
def test_atoms():
    """Non-Symbol atoms (pi, I) resolve through the supplied namespace."""
    cases = (
        (pi, {'pi': 3.14}, 0, 3.14),
        (I, {'I': 1j}, 1, 1 + 1j),
    )
    for atom, namespace, arg, expected in cases:
        assert lambdify(x, atom + x, namespace)(arg) == expected
101
102# ================= Test different modules =========================
103
# A high-precision value of sin(0.2) is used to detect unwanted loss of precision.
105
106
@conserve_mpmath_dps
def test_diofant_lambda():
    """The ``diofant`` module keeps results symbolic yet numerically exact."""
    mpmath.mp.dps = 50
    reference = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    symbolic_sin = lambdify(x, sin(x), 'diofant')
    assert symbolic_sin(x) == sin(x)
    tolerance = 1e-15
    error = symbolic_sin(Rational(1, 5)).evalf() - Float(str(reference))
    assert -tolerance < error < tolerance
115
116
@conserve_mpmath_dps
def test_math_lambda():
    """The ``math`` module evaluates at double precision and only numerically."""
    mpmath.mp.dps = 50
    reference = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    math_sin = lambdify(x, sin(x), 'math')
    tolerance = 1e-15
    assert -tolerance < math_sin(0.2) - reference < tolerance

    # The math backend cannot handle a symbolic argument; if this call
    # succeeded, the generated function would not be using ``math``.
    with pytest.raises(TypeError):
        math_sin(x)
127
128
@conserve_mpmath_dps
def test_mpmath_lambda():
    """The ``mpmath`` module evaluates at the current mpmath precision."""
    mpmath.mp.dps = 50
    reference = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    mp_sin = lambdify(x, sin(x), 'mpmath')
    tolerance = 1e-49  # roughly the 50-digit working precision
    assert -tolerance < mp_sin(mpmath.mpf('0.2')) - reference < tolerance

    # The mpmath backend cannot handle a symbolic argument; if this call
    # succeeded, the generated function would not be using ``mpmath``.
    with pytest.raises(TypeError):
        mp_sin(x)
139
140
@conserve_mpmath_dps
def test_number_precision():
    """A high-precision constant survives lambdify with ``mpmath``."""
    mpmath.mp.dps = 50
    reference = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    constant = lambdify(x, reference, 'mpmath')
    tolerance = 1e-49  # roughly the 50-digit working precision
    assert -tolerance < constant(0) - reference < tolerance
148
149
@conserve_mpmath_dps
def test_mpmath_precision():
    """A 100-digit Float prints identically after an mpmath round trip."""
    mpmath.mp.dps = 100
    pi100 = pi.evalf(100)
    roundtrip = lambdify((), pi100, 'mpmath')()
    assert str(roundtrip) == str(pi100)
154
155
156# ================= Test Translations ==============================
# We can only check that all translated functions are valid; whether the
# translations are complete has to be checked by hand.
159
160
def test_math_transl():
    """Every MATH_TRANSLATIONS entry maps a diofant name to a math name."""
    diofant_names = vars(diofant)
    math_names = vars(math)
    for diofant_name, math_name in MATH_TRANSLATIONS.items():
        assert diofant_name in diofant_names
        assert math_name in math_names
165
166
def test_mpmath_transl():
    """Every MPMATH_TRANSLATIONS entry maps a diofant name to an mpmath name."""
    diofant_names = vars(diofant)
    mpmath_names = vars(mpmath)
    for diofant_name, mpmath_name in MPMATH_TRANSLATIONS.items():
        # 'Matrix' is accepted even though it is checked separately here.
        assert diofant_name == 'Matrix' or diofant_name in diofant_names
        assert mpmath_name in mpmath_names
171
172
@with_numpy
def test_numpy_transl():
    """Every NUMPY_TRANSLATIONS entry maps a diofant name to a numpy name."""
    diofant_names = vars(diofant)
    numpy_names = vars(numpy)
    for diofant_name, numpy_name in NUMPY_TRANSLATIONS.items():
        assert diofant_name in diofant_names
        assert numpy_name in numpy_names
178
179
@with_numpy
def test_numpy_translation_abs():
    """``abs`` translates correctly under the numpy module."""
    absolute = lambdify(x, abs(x), 'numpy')
    for arg in (-1, 1):
        assert absolute(arg) == 1
185
186
187# ================= Test some functions ============================
188
189
def test_exponentiation():
    """Squaring works for negative, zero, positive and float inputs."""
    square = lambdify(x, x**2)
    expectations = ((-1, 1), (0, 0), (1, 1), (-2, 4), (2, 4), (2.5, 6.25))
    for arg, expected in expectations:
        assert square(arg) == expected
198
199
def test_sqrt():
    """``sqrt`` is exact on perfect squares and approximate otherwise."""
    root = lambdify(x, sqrt(x))
    for arg, expected in ((0, 0.0), (1, 1.0), (4, 2.0), (6.25, 2.5)):
        assert root(arg) == expected
    assert abs(root(2) - 1.414) < 0.001
207
208
def test_trig():
    """cos/sin near pi, for a symbolic pi and a float approximation of it."""
    trig = lambdify([x], [cos(x), sin(x)], 'math')
    for arg, tolerance in ((pi, 1e-11), (3.14159, 1e-5)):
        cos_val, sin_val = trig(arg)
        assert -tolerance < cos_val + 1 < tolerance
        assert -tolerance < sin_val < tolerance
219
220# ================= Test vectors ===================================
221
222
def test_vector_simple():
    """A tuple expression yields a tuple result, in expression order."""
    reverse = lambdify((x, y, z), (z, y, x))
    assert reverse(3, 2, 1) == (1, 2, 3)
    assert reverse(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
    # The generated function must enforce its arity.
    with pytest.raises(TypeError):
        reverse(0)
229
230
def test_vector_discontinuous():
    """Plain-Python division by zero surfaces as ZeroDivisionError."""
    reciprocals = lambdify(x, (-1/x, 1/x))
    with pytest.raises(ZeroDivisionError):
        reciprocals(0)
    cases = ((1, (-1.0, 1.0)), (2, (-0.5, 0.5)), (-2, (0.5, -0.5)))
    for arg, expected in cases:
        assert reciprocals(arg) == expected
237
238
def test_trig_symbolic():
    """The ``math`` backend accepts the symbolic constant pi as input."""
    cos_val, sin_val = lambdify([x], [cos(x), sin(x)], 'math')(pi)
    assert abs(cos_val + 1) < 0.0001
    assert abs(sin_val - 0) < 0.0001
244
245
def test_trig_float():
    """Default modules evaluate cos/sin at a float close to pi."""
    cos_val, sin_val = lambdify([x], [cos(x), sin(x)])(3.14159)
    assert abs(cos_val + 1) < 0.0001
    assert abs(sin_val - 0) < 0.0001
251
252
def test_docs():
    """Examples mirroring the ``lambdify`` documentation."""
    assert lambdify(x, x**2)(2) == 4
    assert lambdify([x, y, z], [z, y, x])(1, 2, 3) == [3, 2, 1]
    assert lambdify(x, sqrt(x))(4) == 2.0
    assert lambdify((x, y), sin(x*y)**2)(0, 5) == 0
262
263
def test_math():
    """Unused parameters are allowed with the ``math`` module."""
    assert lambdify((x, y), sin(x), modules='math')(0, 5) == 0
267
268
def test_sin():
    """sin(x)**2 evaluates to a float with default and ``math`` modules."""
    expr = sin(x)**2
    for options in ({}, {'modules': 'math'}):
        assert isinstance(lambdify(x, expr, **options)(2), float)
274
275
def test_matrix():
    """Matrices, nested containers and Matrix argument specs with diofant."""
    A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
    expected = Matrix([[1, 2], [sin(3) + 4, 1]])
    single = lambdify((x, y, z), A, modules='diofant')
    assert single(1, 2, 3) == expected
    nested = lambdify((x, y, z), (A, [A]), modules='diofant')
    assert nested(1, 2, 3) == (expected, [expected])

    jacobian = Matrix((x, x + y)).jacobian((x, y))
    v = Matrix((x, y))
    jacobian_expected = Matrix([[1, 0], [1, 1]])
    # Both a column and a row Matrix are accepted as the argument spec.
    for arg_spec in (v, v.T):
        assert lambdify(arg_spec, jacobian, modules='diofant')(1, 2) == jacobian_expected
288
289
@with_numpy
def test_numpy_matrix():
    """A Matrix lambdified with the numpy module evaluates to a numpy array."""
    A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
    sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
    # Use plain ['numpy'] with no custom Matrix mapping, so the default
    # translation of a Matrix is what gets exercised.
    f = lambdify((x, y, z), A, ['numpy'])
    result = f(1, 2, 3)
    numpy.testing.assert_allclose(result, sol_arr)
    # Check that the result is a numpy array.
    assert isinstance(result, numpy.ndarray)
299
300
@with_numpy
def test_numpy_transpose():
    """Transposing a Matrix before lambdify is reflected in the numpy result."""
    transposed = lambdify(x, Matrix([[1, x], [0, 1]]).T, modules='numpy')
    expected = numpy.array([[1, 0], [2, 1]])
    numpy.testing.assert_array_equal(transposed(2), expected)
306
307
@with_numpy
def test_numpy_inverse():
    """A symbolic Matrix inverse evaluates correctly under numpy."""
    inverse = lambdify(x, Matrix([[1, x], [0, 1]])**-1, modules='numpy')
    expected = numpy.array([[1, -2], [0, 1]])
    numpy.testing.assert_array_equal(inverse(2), expected)
313
314
@with_numpy
def test_numpy_old_matrix():
    """Mapping ``ImmutableMatrix`` to ``numpy.array`` explicitly still works."""
    A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
    expected = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
    modules = [{'ImmutableMatrix': numpy.array}, 'numpy']
    evaluate = lambdify((x, y, z), A, modules)
    numpy.testing.assert_allclose(evaluate(1, 2, 3), expected)
    assert isinstance(evaluate(1, 2, 3), numpy.ndarray)
322
323
@with_numpy
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
def test_python_div_zero_sympyissue_11306():
    """Unselected Piecewise branches dividing by zero must not abort evaluation.

    The ``1/x`` branches see x == 0 here. Presumably numpy evaluates every
    branch, hence the RuntimeWarning filter. With y == 1 the ``x`` branch
    is the one selected.
    """
    p = Piecewise((1 / x, y < -1), (x, y <= 1), (1 / x, True))
    result = lambdify([x, y], p, modules='numpy')(0, 1)
    # Beyond "no exception", pin the value of the selected branch (x == 0).
    numpy.testing.assert_array_equal(result, 0)
329
330
@with_numpy
def test_numpy_piecewise():
    """Piecewise under numpy is vectorized and gives NaN where no branch applies."""
    pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True))
    f = lambdify(x, pieces, modules='numpy')
    numpy.testing.assert_array_equal(f(numpy.arange(10)),
                                     numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81]))
    # If we evaluate somewhere all conditions are False, we should get back NaN.
    # Request numpy explicitly, as above, rather than relying on lambdify's
    # default module selection for this numpy-specific expectation.
    nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0)), modules='numpy')
    numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])),
                                     numpy.array([1, numpy.nan, 1]))
341
342
@with_numpy
def test_numpy_logical_ops():
    """And/Or/Not operate element-wise on boolean numpy arrays."""
    arr1 = numpy.array([True, True])
    arr2 = numpy.array([False, True])
    for connective, expected in ((And, [False, True]), (Or, [True, True])):
        binary = lambdify((x, y), connective(x, y), modules='numpy')
        numpy.testing.assert_array_equal(binary(arr1, arr2), numpy.array(expected))
    negate = lambdify(x, Not(x), modules='numpy')
    numpy.testing.assert_array_equal(negate(arr2), numpy.array([True, False]))
353
354
@with_numpy
def test_numpy_matmul():
    """Matrix products, including chained ones, evaluate correctly under numpy."""
    xmat = Matrix([[x, y], [z, 1 + z]])
    ymat = Matrix([[x**2], [abs(x)]])
    product = lambdify((x, y, z), xmat*ymat, modules='numpy')
    for x_val, expected in ((0.5, [[1.625], [3.5]]), (-0.5, [[1.375], [3.5]])):
        numpy.testing.assert_array_equal(product(x_val, 3, 4), numpy.array(expected))
    # Multiple matrices chained together in multiplication.
    cube = lambdify((x, y, z), xmat*xmat*xmat, modules='numpy')
    expected_cube = numpy.array([[72.125, 119.25],
                                 [159, 251]])
    numpy.testing.assert_array_equal(cube(0.5, 3, 4), expected_cube)
366
367
def test_integral():
    """An unevaluated Integral is returned unchanged with the diofant module."""
    gauss = Lambda(x, exp(-x**2))
    integral = Integral(gauss(x), (x, -oo, oo))
    evaluate = lambdify(x, integral, modules='diofant')
    assert evaluate(x) == Integral(exp(-x**2), (x, -oo, oo))
372
373# ================= Test symbolic ==================================
374
375
def test_sym_single_arg():
    """Symbolic inputs pass through; free symbols stay in the result."""
    scale_by_y = lambdify(x, x * y)
    assert scale_by_y(z) == z * y
379
380
def test_sym_list_args():
    """Symbols not listed as arguments remain symbolic in the result."""
    partial_sum = lambdify([x, y], x + y + z)
    assert partial_sum(1, 2) == 3 + z
384
385
def test_sym_integral():
    """The returned Integral can be evaluated afterwards with ``doit``."""
    gauss = Lambda(x, exp(-x**2))
    integral = Integral(gauss(x), (x, -oo, oo))
    evaluate = lambdify(x, integral, modules='diofant')
    assert evaluate(y).doit() == sqrt(pi)
390
391
def test_namespace_order():
    # lambdify had a bug, such that module dictionaries or cached module
    # dictionaries would pull earlier namespaces into themselves.
    # Because the module dictionaries form the namespace of the
    # generated lambda, this meant that the behavior of a previously
    # generated lambda function could change as a result of later calls
    # to lambdify.
    # The statement order below is the point of the test: if1 is built and
    # checked, then if2 is built, then if1 is checked again.
    n1 = {'f': lambda x: 'first f'}
    # n2 deliberately also defines 'f', to detect leakage into if1.
    n2 = {'f': lambda x: 'second f',
          'g': lambda x: 'function g'}
    f = diofant.Function('f')
    g = diofant.Function('g')
    # n1 is listed before 'diofant', so its 'f' is the one used by if1.
    if1 = lambdify(x, f(x), modules=(n1, 'diofant'))
    assert if1(1) == 'first f'
    # Building if2 with n2 must not alter the namespace if1 already uses.
    if2 = lambdify(x, g(x), modules=(n2, 'diofant'))
    assert if2(1) == 'function g'
    # previously gave 'second f'
    assert if1(1) == 'first f'
410
411
def test_imps():
    """Implemented functions: anonymity, Function input, and name clashes."""
    # Here we check if the default returned functions are anonymous - in
    # the sense that we can have more than one function with the same name
    f = implemented_function('f', lambda x: 2*x)
    g = implemented_function('f', lambda x: math.sqrt(x))
    l1 = lambdify(x, f(x))
    l2 = lambdify(x, g(x))
    assert str(f(x)) == str(g(x))
    assert l1(3) == 6
    assert l2(3) == math.sqrt(3)
    # check that we can pass in a Function as input
    func = diofant.Function('myfunc')
    assert not hasattr(func, '_imp_')
    my_f = implemented_function(func, lambda x: 2*x)
    # Separate asserts so a failure reports which object lacks ``_imp_``.
    assert hasattr(func, '_imp_')
    assert hasattr(my_f, '_imp_')
    # Error for functions with same name and different implementation
    f2 = implemented_function('f', lambda x: x + 101)
    pytest.raises(ValueError, lambda: lambdify(x, f(f2(x))))
430
431
def test_imps_errors():
    """Implementations that raise must still allow forming expressions.

    See issue sympy/sympy#10810: for each error type an implementation may
    raise, ``f(val)`` must still build an expression that compares equal
    to itself.
    """
    sample_values = (0, 0., 2, 2.0)
    raised_types = (AttributeError, TypeError, ValueError)
    for val, error_class in itertools.product(sample_values, raised_types):

        def failing_at_zero(a):
            if a == 0:
                raise error_class
            return 1

        f = implemented_function('f', failing_at_zero)
        assert f(val) == f(val)
447
448
def test_imps_wrong_args():
    """A built-in function such as ``sin`` cannot be given an implementation."""
    with pytest.raises(ValueError):
        implemented_function(sin, lambda x: x)
451
452
def test_lambdify_imps():
    """lambdify with implemented functions, in expressions and containers."""
    # first test basic (diofant) lambdify
    f = diofant.cos
    assert lambdify(x, f(x))(0) == 1
    assert lambdify(x, 1 + f(x))(0) == 2
    assert lambdify((x, y), y + f(x))(0, 1) == 2
    # make an implemented function and test
    f = implemented_function('f', lambda x: x + 100)
    assert lambdify(x, f(x))(0) == 100
    assert lambdify(x, 1 + f(x))(0) == 101
    assert lambdify((x, y), y + f(x))(0, 1) == 101
    # Can also handle tuples, lists, dicts as expressions
    lam = lambdify(x, (f(x), x))
    assert lam(3) == (103, 3)
    lam = lambdify(x, [f(x), x])
    assert lam(3) == [103, 3]
    lam = lambdify(x, [f(x), (f(x), x)])
    assert lam(3) == [103, (103, 3)]
    # Implemented function used as a dict key, then as a dict value.
    lam = lambdify(x, {f(x): x})
    assert lam(3) == {103: 3}
    lam = lambdify(x, {x: f(x)})
    assert lam(3) == {3: 103}
    # Check that imp preferred to other namespaces by default
    d = {'f': lambda x: x + 99}
    lam = lambdify(x, f(x), d)
    assert lam(3) == 103
    # Unless flag passed
    lam = lambdify(x, f(x), d, use_imps=False)
    assert lam(3) == 102
485
486
def test_dummification():
    """Non-identifier arguments are replaced by dummies in generated code."""
    F = Function('F')
    G = Function('G')
    # "\alpha" is not a valid python variable name, so lambdify must
    # substitute a dummy for it instead of producing a syntax error.
    alpha = symbols(r'\alpha')
    ratio = lambdify((F(t), G(t)), 2 * F(t)**2 / G(t))
    assert ratio(3, 9) == 2
    doubled_square = lambdify(sin(t), 2 * sin(t)**2)
    assert doubled_square(F(t)) == 2 * F(t)**2
    assert lambdify((alpha, t), 2*alpha + t)(2, 1) == 5
    # Compound argument expressions are rejected with SyntaxError.
    rejected = ((F(t) * G(t), F(t) * G(t) + 5),
                (2 * F(t), 2 * F(t) + 5),
                (2 * F(t), 4 * F(t) + 5))
    for arg_spec, expr in rejected:
        with pytest.raises(SyntaxError):
            lambdify(arg_spec, expr)
505
506
def test_python_keywords():
    """Symbols named after Python keywords still give valid lambdas.

    Regression test for issue sympy/sympy#7452: automatic dummification
    must hide the reserved name ``if``.
    """
    python_if = symbols('if')
    halve = lambdify(python_if, python_if / 2)
    assert halve(4.0) == 2.0
515
516
def test_lambdify_docstring():
    """Generated docstrings show the signature and a (truncated) expression."""
    header = 'Created with lambdify. Signature:\n\n'
    expr_header = 'Expression:\n\n'

    short = lambdify((w, x, y, z), w + x + y + z)
    assert short.__doc__ == (header +
                             'func(w, x, y, z)\n\n' +
                             expr_header +
                             'w + x + y + z')

    # Many arguments: the signature wraps and the expression is truncated.
    syms = symbols('a1:26')
    long_func = lambdify(syms, sum(syms))
    assert long_func.__doc__ == (
        header +
        'func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n'
        '        a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n' +
        expr_header +
        'a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +...')
532
533
534# ================= Test special printers ==========================
535
536
def test_special_printers():
    """``printer`` may be a function, a printer class or a printer instance."""
    class IntervalPrinter(LambdaPrinter):
        """Use ``lambda`` printer but print numbers as ``mpi`` intervals."""

        def _print_Integer(self, expr):
            return f"mpi('{super()._print_Integer(expr)}')"

        def _print_Rational(self, expr):
            return f"mpi('{super()._print_Rational(expr)}')"

    expr = sqrt(sqrt(2) + sqrt(3)) + Rational(1, 2)
    mpi = type(mpmath.mpi(1, 2))

    printers = (
        lambda e: IntervalPrinter().doprint(e),  # plain function
        IntervalPrinter,  # printer class
        IntervalPrinter(),  # printer instance
    )
    for printer in printers:
        func = lambdify((), expr, modules='mpmath', printer=printer)
        assert isinstance(func(), mpi)
561
562
def test_true_false():
    """diofant booleans become the Python singletons True and False."""
    # Identity (``is``) is required here, not mere equality.
    for value, expected in ((true, True), (false, False)):
        assert lambdify([], value)() is expected
567
568
def test_ITE():
    """ITE picks the second or third argument based on the first."""
    choose = lambdify((x, y, z), ITE(x, y, z))
    assert choose(True, 5, 3) == 5
    assert choose(False, 5, 3) == 3
572
573
def test_Min_Max():
    """Min and Max accept more than two arguments (sympy/sympy#10375)."""
    for extremum, expected in ((Min, 1), (Max, 3)):
        assert lambdify((x, y, z), extremum(x, y, z))(1, 2, 3) == expected
578
579
def test_sympyissue_12092():
    """Nested implemented functions evaluate numerically with ``evalf``."""
    square = implemented_function('f', lambda x: x**2)
    nested = square(square(2))
    assert nested.evalf() == Float(16)
583