"""Tests for :func:`diofant.utilities.lambdify.lambdify` and its helpers."""

import itertools
import math

import mpmath
import pytest

import diofant
from diofant import (ITE, And, Float, Function, I, Integral, Lambda, Matrix,
                     Max, Min, Not, Or, Piecewise, Rational, cos, exp, false,
                     lambdify, oo, pi, sin, sqrt, symbols, true)
from diofant.abc import t, w, x, y, z
from diofant.external import import_module
from diofant.printing.lambdarepr import LambdaPrinter
from diofant.utilities.decorator import conserve_mpmath_dps
from diofant.utilities.lambdify import (MATH_TRANSLATIONS, MPMATH_TRANSLATIONS,
                                        NUMPY_TRANSLATIONS, _get_namespace,
                                        implemented_function, lambdastr)


__all__ = ()

# NOTE(review): this alias is not referenced in this module; presumably it is
# kept so the name is available at module level -- confirm before removing.
MutableDenseMatrix = Matrix

# numpy is optional: tests decorated with ``with_numpy`` are skipped without it.
numpy = import_module('numpy')
with_numpy = pytest.mark.skipif(numpy is None,
                                reason="Couldn't import numpy.")

# ================= Test different arguments =======================


def test_no_args():
    f = lambdify([], 1)
    pytest.raises(TypeError, lambda: f(-1))
    assert f() == 1


def test_single_arg():
    f = lambdify(x, 2*x)
    assert f(1) == 2


def test_list_args():
    f = lambdify([x, y], x + y)
    assert f(1, 2) == 3


def test_nested_args():
    # issue sympy/sympy#2790
    assert lambdify((x, (y, z)), x + y)(1, (2, 4)) == 3
    assert lambdify((x, (y, (w, z))), w + x + y + z)(1, (2, (3, 4))) == 10
    assert lambdify(x, x + 1, dummify=False)(1) == 2


def test_str_args():
    f = lambdify('x,y,z', 'z,y,x')
    assert f(3, 2, 1) == (1, 2, 3)
    assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
    # make sure correct number of args required
    pytest.raises(TypeError, lambda: f(0))


def test_own_namespace():
    def myfunc(x):
        return 1
    f = lambdify(x, sin(x), {'sin': myfunc})
    assert f(0.1) == 1
    assert f(100) == 1


def test_own_module():
    f = lambdify(x, sin(x), math)
    assert f(0) == 0.0


def test_bad_args():
    # no vargs given
    pytest.raises(TypeError, lambda: lambdify(1))
    # same with vector exprs
    pytest.raises(TypeError, lambda: lambdify([1, 2]))
    # reserved name
    pytest.raises(ValueError, lambda: lambdify((('__flatten_args__',),), 1))
    # 'spam' is not a recognized module name
    pytest.raises(NameError, lambda: lambdify(x, 1, 'spam'))


def test__get_namespace():
    pytest.raises(TypeError, lambda: _get_namespace(1))


def test_lambdastr():
    assert lambdastr(x, x**2) == 'lambda x: (x**2)'
    assert lambdastr(x, None, dummify=True).find('None') > 0


def test_atoms():
    # Non-Symbol atoms should not be pulled out from the expression namespace
    f = lambdify(x, pi + x, {'pi': 3.14})
    assert f(0) == 3.14
    f = lambdify(x, I + x, {'I': 1j})
    assert f(1) == 1 + 1j

# ================= Test different modules =========================

# A high-precision value of sin(0.2) is used to detect unwanted loss of
# precision in the generated functions.


@conserve_mpmath_dps
def test_diofant_lambda():
    mpmath.mp.dps = 50
    sin02 = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    f = lambdify(x, sin(x), 'diofant')
    assert f(x) == sin(x)
    prec = 1e-15
    assert -prec < f(Rational(1, 5)).evalf() - Float(str(sin02)) < prec


@conserve_mpmath_dps
def test_math_lambda():
    mpmath.mp.dps = 50
    sin02 = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    f = lambdify(x, sin(x), 'math')
    prec = 1e-15
    assert -prec < f(0.2) - sin02 < prec

    # A TypeError on symbolic input shows the generated function really is
    # backed by Python's math module (a symbolic sin would accept x).
    pytest.raises(TypeError, lambda: f(x))


@conserve_mpmath_dps
def test_mpmath_lambda():
    mpmath.mp.dps = 50
    sin02 = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    f = lambdify(x, sin(x), 'mpmath')
    prec = 1e-49  # mpmath precision is around 50 decimal places
    assert -prec < f(mpmath.mpf('0.2')) - sin02 < prec

    # A TypeError on symbolic input shows the generated function really is
    # backed by mpmath (a symbolic sin would accept x).
    pytest.raises(TypeError, lambda: f(x))


@conserve_mpmath_dps
def test_number_precision():
    mpmath.mp.dps = 50
    sin02 = mpmath.mpf('0.19866933079506121545941262711838975037020672954020')
    f = lambdify(x, sin02, 'mpmath')
    prec = 1e-49  # mpmath precision is around 50 decimal places
    assert -prec < f(0) - sin02 < prec


@conserve_mpmath_dps
def test_mpmath_precision():
    mpmath.mp.dps = 100
    assert str(lambdify((), pi.evalf(100), 'mpmath')()) == str(pi.evalf(100))


# ================= Test Translations ==============================
# We can only check that all translated functions are valid; whether the
# translation tables are complete has to be checked by hand.


def test_math_transl():
    for sym, mat in MATH_TRANSLATIONS.items():
        assert sym in diofant.__dict__
        assert mat in math.__dict__


def test_mpmath_transl():
    for sym, mat in MPMATH_TRANSLATIONS.items():
        assert sym in diofant.__dict__ or sym == 'Matrix'
        assert mat in mpmath.__dict__


@with_numpy
def test_numpy_transl():
    for sym, nump in NUMPY_TRANSLATIONS.items():
        assert sym in diofant.__dict__
        assert nump in numpy.__dict__


@with_numpy
def test_numpy_translation_abs():
    f = lambdify(x, abs(x), 'numpy')
    assert f(-1) == 1
    assert f(1) == 1


# ================= Test some functions ============================


def test_exponentiation():
    f = lambdify(x, x**2)
    assert f(-1) == 1
    assert f(0) == 0
    assert f(1) == 1
    assert f(-2) == 4
    assert f(2) == 4
    assert f(2.5) == 6.25


def test_sqrt():
    f = lambdify(x, sqrt(x))
    assert f(0) == 0.0
    assert f(1) == 1.0
    assert f(4) == 2.0
    assert abs(f(2) - 1.414) < 0.001
    assert f(6.25) == 2.5


def test_trig():
    f = lambdify([x], [cos(x), sin(x)], 'math')
    d = f(pi)
    prec = 1e-11
    assert -prec < d[0] + 1 < prec
    assert -prec < d[1] < prec
    d = f(3.14159)
    prec = 1e-5
    assert -prec < d[0] + 1 < prec
    assert -prec < d[1] < prec

# ================= Test vectors ===================================


def test_vector_simple():
    f = lambdify((x, y, z), (z, y, x))
    assert f(3, 2, 1) == (1, 2, 3)
    assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0)
    # make sure correct number of args required
    pytest.raises(TypeError, lambda: f(0))


def test_vector_discontinuous():
    f = lambdify(x, (-1/x, 1/x))
    pytest.raises(ZeroDivisionError, lambda: f(0))
    assert f(1) == (-1.0, 1.0)
    assert f(2) == (-0.5, 0.5)
    assert f(-2) == (0.5, -0.5)


def test_trig_symbolic():
    f = lambdify([x], [cos(x), sin(x)], 'math')
    d = f(pi)
    assert abs(d[0] + 1) < 0.0001
    assert abs(d[1] - 0) < 0.0001


def test_trig_float():
    f = lambdify([x], [cos(x), sin(x)])
    d = f(3.14159)
    assert abs(d[0] + 1) < 0.0001
    assert abs(d[1] - 0) < 0.0001


def test_docs():
    f = lambdify(x, x**2)
    assert f(2) == 4
    f = lambdify([x, y, z], [z, y, x])
    assert f(1, 2, 3) == [3, 2, 1]
    f = lambdify(x, sqrt(x))
    assert f(4) == 2.0
    f = lambdify((x, y), sin(x*y)**2)
    assert f(0, 5) == 0


def test_math():
    f = lambdify((x, y), sin(x), modules='math')
    assert f(0, 5) == 0


def test_sin():
    f = lambdify(x, sin(x)**2)
    assert isinstance(f(2), float)
    f = lambdify(x, sin(x)**2, modules='math')
    assert isinstance(f(2), float)


def test_matrix():
    A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
    sol = Matrix([[1, 2], [sin(3) + 4, 1]])
    f = lambdify((x, y, z), A, modules='diofant')
    assert f(1, 2, 3) == sol
    f = lambdify((x, y, z), (A, [A]), modules='diofant')
    assert f(1, 2, 3) == (sol, [sol])
    J = Matrix((x, x + y)).jacobian((x, y))
    v = Matrix((x, y))
    sol = Matrix([[1, 0], [1, 1]])
    assert lambdify(v, J, modules='diofant')(1, 2) == sol
    assert lambdify(v.T, J, modules='diofant')(1, 2) == sol


@with_numpy
def test_numpy_matrix():
    A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
    sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
    # Lambdify array first, to ensure return to array as default
    f = lambdify((x, y, z), A, ['numpy'])
    numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
    # Check that the result is a plain numpy array
    assert isinstance(f(1, 2, 3), numpy.ndarray)


@with_numpy
def test_numpy_transpose():
    A = Matrix([[1, x], [0, 1]])
    f = lambdify(x, A.T, modules='numpy')
    numpy.testing.assert_array_equal(f(2), numpy.array([[1, 0], [2, 1]]))


@with_numpy
def test_numpy_inverse():
    A = Matrix([[1, x], [0, 1]])
    f = lambdify(x, A**-1, modules='numpy')
    numpy.testing.assert_array_equal(f(2), numpy.array([[1, -2], [0, 1]]))


@with_numpy
def test_numpy_old_matrix():
    A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
    sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
    f = lambdify((x, y, z), A, [{'ImmutableMatrix': numpy.array}, 'numpy'])
    numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
    assert isinstance(f(1, 2, 3), numpy.ndarray)


@with_numpy
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
def test_python_div_zero_sympyissue_11306():
    # Only checks that evaluating at x=0 does not raise.  For y=1 the
    # selected branch is ``x``, but the ``1/x`` branches may still be
    # evaluated by numpy, hence the RuntimeWarning filter.
    p = Piecewise((1 / x, y < -1), (x, y <= 1), (1 / x, True))
    lambdify([x, y], p, modules='numpy')(0, 1)


@with_numpy
def test_numpy_piecewise():
    pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True))
    f = lambdify(x, pieces, modules='numpy')
    numpy.testing.assert_array_equal(f(numpy.arange(10)),
                                     numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81]))
    # If we evaluate somewhere all conditions are False, we should get back NaN
    nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0)))
    numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])),
                                     numpy.array([1, numpy.nan, 1]))


@with_numpy
def test_numpy_logical_ops():
    and_func = lambdify((x, y), And(x, y), modules='numpy')
    or_func = lambdify((x, y), Or(x, y), modules='numpy')
    not_func = lambdify(x, Not(x), modules='numpy')
    arr1 = numpy.array([True, True])
    arr2 = numpy.array([False, True])
    numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True]))
    numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True]))
    numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False]))


@with_numpy
def test_numpy_matmul():
    xmat = Matrix([[x, y], [z, 1 + z]])
    ymat = Matrix([[x**2], [abs(x)]])
    mat_func = lambdify((x, y, z), xmat*ymat, modules='numpy')
    numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]]))
    numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]]))
    # Multiple matrices chained together in multiplication
    f = lambdify((x, y, z), xmat*xmat*xmat, modules='numpy')
    numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25],
                                                                [159, 251]]))


def test_integral():
    f = Lambda(x, exp(-x**2))
    l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules='diofant')
    assert l(x) == Integral(exp(-x**2), (x, -oo, oo))

# ================= Test symbolic ==================================


def test_sym_single_arg():
    f = lambdify(x, x * y)
    assert f(z) == z * y


def test_sym_list_args():
    f = lambdify([x, y], x + y + z)
    assert f(1, 2) == 3 + z


def test_sym_integral():
    f = Lambda(x, exp(-x**2))
    l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules='diofant')
    assert l(y).doit() == sqrt(pi)


def test_namespace_order():
    # lambdify had a bug, such that module dictionaries or cached module
    # dictionaries would pull earlier namespaces into themselves.
    # Because the module dictionaries form the namespace of the
    # generated lambda, this meant that the behavior of a previously
    # generated lambda function could change as a result of later calls
    # to lambdify.
    n1 = {'f': lambda x: 'first f'}
    n2 = {'f': lambda x: 'second f',
          'g': lambda x: 'function g'}
    f = diofant.Function('f')
    g = diofant.Function('g')
    if1 = lambdify(x, f(x), modules=(n1, 'diofant'))
    assert if1(1) == 'first f'
    if2 = lambdify(x, g(x), modules=(n2, 'diofant'))
    assert if2(1) == 'function g'
    # previously gave 'second f'
    assert if1(1) == 'first f'


def test_imps():
    # Here we check if the default returned functions are anonymous - in
    # the sense that we can have more than one function with the same name
    f = implemented_function('f', lambda x: 2*x)
    g = implemented_function('f', lambda x: math.sqrt(x))
    l1 = lambdify(x, f(x))
    l2 = lambdify(x, g(x))
    assert str(f(x)) == str(g(x))
    assert l1(3) == 6
    assert l2(3) == math.sqrt(3)
    # check that we can pass in a Function as input
    func = diofant.Function('myfunc')
    assert not hasattr(func, '_imp_')
    my_f = implemented_function(func, lambda x: 2*x)
    assert hasattr(func, '_imp_') and hasattr(my_f, '_imp_')
    # Error for functions with same name and different implementation
    f2 = implemented_function('f', lambda x: x + 101)
    pytest.raises(ValueError, lambda: lambdify(x, f(f2(x))))


def test_imps_errors():
    # Test errors that implemented functions can return, and still be
    # able to form expressions.  See issue sympy/sympy#10810.
    for val, error_class in itertools.product((0, 0., 2, 2.0),
                                              (AttributeError, TypeError,
                                               ValueError)):

        # myfunc is defined and used within the same iteration, so the
        # closure over error_class always sees the current value.
        def myfunc(a):
            if a == 0:
                raise error_class
            return 1

        f = implemented_function('f', myfunc)
        expr = f(val)
        assert expr == f(val)


def test_imps_wrong_args():
    pytest.raises(ValueError, lambda: implemented_function(sin, lambda x: x))


def test_lambdify_imps():
    # Test lambdify with implemented functions
    # first test basic (diofant) lambdify
    f = diofant.cos
    assert lambdify(x, f(x))(0) == 1
    assert lambdify(x, 1 + f(x))(0) == 2
    assert lambdify((x, y), y + f(x))(0, 1) == 2
    # make an implemented function and test
    f = implemented_function('f', lambda x: x + 100)
    assert lambdify(x, f(x))(0) == 100
    assert lambdify(x, 1 + f(x))(0) == 101
    assert lambdify((x, y), y + f(x))(0, 1) == 101
    # Can also handle tuples, lists, dicts as expressions
    lam = lambdify(x, (f(x), x))
    assert lam(3) == (103, 3)
    lam = lambdify(x, [f(x), x])
    assert lam(3) == [103, 3]
    lam = lambdify(x, [f(x), (f(x), x)])
    assert lam(3) == [103, (103, 3)]
    lam = lambdify(x, {f(x): x})
    assert lam(3) == {103: 3}
    lam = lambdify(x, {x: f(x)})
    assert lam(3) == {3: 103}
    # Check that imp preferred to other namespaces by default
    d = {'f': lambda x: x + 99}
    lam = lambdify(x, f(x), d)
    assert lam(3) == 103
    # Unless flag passed
    lam = lambdify(x, f(x), d, use_imps=False)
    assert lam(3) == 102


def test_dummification():
    F = Function('F')
    G = Function('G')
    # "\alpha" is not a valid python variable name
    # lambdify should sub in a dummy for it, and return
    # without a syntax error
    alpha = symbols(r'\alpha')
    some_expr = 2 * F(t)**2 / G(t)
    lam = lambdify((F(t), G(t)), some_expr)
    assert lam(3, 9) == 2
    lam = lambdify(sin(t), 2 * sin(t)**2)
    assert lam(F(t)) == 2 * F(t)**2
    # Test that \alpha was properly dummified
    lam = lambdify((alpha, t), 2*alpha + t)
    assert lam(2, 1) == 5
    # Non-atomic expressions (products/multiples) cannot be used as arguments
    pytest.raises(SyntaxError, lambda: lambdify(F(t) * G(t), F(t) * G(t) + 5))
    pytest.raises(SyntaxError, lambda: lambdify(2 * F(t), 2 * F(t) + 5))
    pytest.raises(SyntaxError, lambda: lambdify(2 * F(t), 4 * F(t) + 5))


def test_python_keywords():
    # Test for issue sympy/sympy#7452. The automatic dummification should ensure use of
    # Python reserved keywords as symbol names will create valid lambda
    # functions. This is an additional regression test.
    python_if = symbols('if')
    expr = python_if / 2
    f = lambdify(python_if, expr)
    assert f(4.0) == 2.0


def test_lambdify_docstring():
    func = lambdify((w, x, y, z), w + x + y + z)
    assert func.__doc__ == (
        'Created with lambdify. Signature:\n\n'
        'func(w, x, y, z)\n\n'
        'Expression:\n\n'
        'w + x + y + z')
    # Long signatures are wrapped and long expressions are truncated with '...'
    syms = symbols('a1:26')
    func = lambdify(syms, sum(syms))
    assert func.__doc__ == (
        'Created with lambdify. Signature:\n\n'
        'func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n'
        '        a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n'
        'Expression:\n\n'
        'a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +...')


# ================= Test special printers ==========================


def test_special_printers():
    class IntervalPrinter(LambdaPrinter):
        """Use ``lambda`` printer but print numbers as ``mpi`` intervals."""

        def _print_Integer(self, expr):
            return f"mpi('{super()._print_Integer(expr)}')"

        def _print_Rational(self, expr):
            return f"mpi('{super()._print_Rational(expr)}')"

    def intervalrepr(expr):
        return IntervalPrinter().doprint(expr)

    expr = diofant.sqrt(diofant.sqrt(2) + diofant.sqrt(3)) + diofant.Rational(1, 2)

    # The printer may be given as a function, a printer class or an instance.
    func0 = lambdify((), expr, modules='mpmath', printer=intervalrepr)
    func1 = lambdify((), expr, modules='mpmath', printer=IntervalPrinter)
    func2 = lambdify((), expr, modules='mpmath', printer=IntervalPrinter())

    mpi = type(mpmath.mpi(1, 2))

    assert isinstance(func0(), mpi)
    assert isinstance(func1(), mpi)
    assert isinstance(func2(), mpi)


def test_true_false():
    # We want exact is comparison here, not just ==
    assert lambdify([], true)() is True
    assert lambdify([], false)() is False


def test_ITE():
    assert lambdify((x, y, z), ITE(x, y, z))(True, 5, 3) == 5
    assert lambdify((x, y, z), ITE(x, y, z))(False, 5, 3) == 3


def test_Min_Max():
    # see sympy/sympy#10375
    assert lambdify((x, y, z), Min(x, y, z))(1, 2, 3) == 1
    assert lambdify((x, y, z), Max(x, y, z))(1, 2, 3) == 3


def test_sympyissue_12092():
    f = implemented_function('f', lambda x: x**2)
    assert f(f(2)).evalf() == Float(16)