"""Tests for the Python-family code printers: Python, mpmath, NumPy, SciPy and SymPy."""

from sympy.codegen import Assignment
from sympy.codegen.ast import none
from sympy.codegen.cfunctions import expm1, log1p
from sympy.codegen.scipy_nodes import cosm1
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow
from sympy.core.numbers import pi
from sympy.core.singleton import S
from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt
from sympy.logic import And, Or
from sympy.matrices import SparseMatrix, MatrixSymbol, Identity
from sympy.printing.pycode import (
    MpmathPrinter, PythonCodePrinter, pycode, SymPyPrinter
)
from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
from sympy.testing.pytest import raises
from sympy.tensor import IndexedBase
from sympy.testing.pytest import skip
from sympy.external import import_module
from sympy.functions.special.gamma_functions import loggamma

x, y, z = symbols('x y z')
p = IndexedBase("p")


def _check_doprint(printer, cases):
    """Assert ``printer.doprint(expr) == expected`` for each ``(expr, expected)`` pair.

    The pairs are printed in order on the *same* printer instance, so any
    state the printer accumulates (such as ``module_imports``) carries over
    from one case to the next.
    """
    for expr, expected in cases:
        assert printer.doprint(expr) == expected


def _check_print_pow(printer, cases):
    """Assert ``printer._print_Pow(expr, rational=flag) == expected`` for each
    ``(expr, flag, expected)`` triple, in order."""
    for expr, flag, expected in cases:
        assert printer._print_Pow(expr, rational=flag) == expected


def _assert_not_supported(printer, *exprs):
    """Assert that *printer* emits a "Not supported" marker for every expression."""
    for expr in exprs:
        assert "Not supported" in printer.doprint(expr)


def test_PythonCodePrinter():
    printer = PythonCodePrinter()

    # A freshly built printer has not recorded any imports.
    assert not printer.module_imports

    # Plain operators print without touching module_imports.
    _check_doprint(printer, [
        (x**y, 'x**y'),
        (Mod(x, 2), 'x % 2'),
        (And(x, y), 'x and y'),
        (Or(x, y), 'x or y'),
    ])
    assert not printer.module_imports

    # Every math-module name that gets printed is remembered.
    assert printer.doprint(pi) == 'math.pi'
    assert printer.module_imports == {'math': {'pi'}}

    _check_doprint(printer, [
        (x**Rational(1, 2), 'math.sqrt(x)'),
        (sqrt(x), 'math.sqrt(x)'),
    ])
    assert printer.module_imports == {'math': {'pi', 'sqrt'}}

    _check_doprint(printer, [
        (acos(x), 'math.acos(x)'),
        (Assignment(x, 2), 'x = 2'),
        (Piecewise((1, Eq(x, 0)), (2, x > 6)),
         '((1) if (x == 0) else (2) if (x > 6) else None)'),
        (Piecewise((2, Le(x, 0)), (3, Gt(x, 0)), evaluate=False),
         '((2) if (x <= 0) else (3) if (x > 0) else None)'),
        (sign(x), '(0.0 if x == 0 else math.copysign(1, x))'),
        (p[0, 1], 'p[0, 1]'),
        (KroneckerDelta(x, y), '(1 if x == y else 0)'),
    ])


def test_PythonCodePrinter_standard():
    import sys

    printer = PythonCodePrinter({'standard': None})

    # With no explicit standard, the printer follows the running interpreter.
    expected_standard = {2: 'python2', 3: 'python3'}.get(sys.version_info.major)
    if expected_standard is not None:
        assert printer.standard == expected_standard

    raises(ValueError, lambda: PythonCodePrinter({'standard': 'python4'}))


def test_MpmathPrinter():
    printer = MpmathPrinter()
    _check_doprint(printer, [
        (sign(x), 'mpmath.sign(x)'),
        (Rational(1, 2), 'mpmath.mpf(1)/mpmath.mpf(2)'),
        (S.Exp1, 'mpmath.e'),
        (S.Pi, 'mpmath.pi'),
        (S.GoldenRatio, 'mpmath.phi'),
        (S.EulerGamma, 'mpmath.euler'),
        (S.NaN, 'mpmath.nan'),
        (S.Infinity, 'mpmath.inf'),
        (S.NegativeInfinity, 'mpmath.ninf'),
        (loggamma(x), 'mpmath.loggamma(x)'),
    ])


def test_NumPyPrinter():
    from sympy import (Lambda, ZeroMatrix, OneMatrix, FunctionMatrix,
                       HadamardProduct, KroneckerProduct, Adjoint, DiagonalOf,
                       DiagMatrix, DiagonalMatrix)
    from sympy.abc import a, b

    printer = NumPyPrinter()
    assert printer.doprint(sign(x)) == 'numpy.sign(x)'

    A = MatrixSymbol("A", 2, 2)
    B = MatrixSymbol("B", 2, 2)
    C = MatrixSymbol("C", 1, 5)
    D = MatrixSymbol("D", 3, 4)

    _check_doprint(printer, [
        (A**(-1), "numpy.linalg.inv(A)"),
        (A**5, "numpy.linalg.matrix_power(A, 5)"),
        (Identity(3), "numpy.eye(3)"),
    ])

    # 2x1 matrix symbols whose printed names are 'x' and 'y'.
    rhs = MatrixSymbol('x', 2, 1)
    shift = MatrixSymbol('y', 2, 1)
    _check_doprint(printer, [
        (MatrixSolve(A, rhs), 'numpy.linalg.solve(A, x)'),
        (MatrixSolve(A, rhs) + shift, 'numpy.linalg.solve(A, x) + y'),
    ])

    _check_doprint(printer, [
        (ZeroMatrix(2, 3), "numpy.zeros((2, 3))"),
        (OneMatrix(2, 3), "numpy.ones((2, 3))"),
        (FunctionMatrix(4, 5, Lambda((a, b), a + b)),
         "numpy.fromfunction(lambda a, b: a + b, (4, 5))"),
        (HadamardProduct(A, B), "numpy.multiply(A, B)"),
        (KroneckerProduct(A, B), "numpy.kron(A, B)"),
        (Adjoint(A), "numpy.conjugate(numpy.transpose(A))"),
        (DiagonalOf(A), "numpy.reshape(numpy.diag(A), (-1, 1))"),
        (DiagMatrix(C), "numpy.diagflat(C)"),
        (DiagonalMatrix(D), "numpy.multiply(D, numpy.eye(3, 4))"),
    ])

    # Workaround for numpy negative integer power errors: negative integer
    # exponents are emitted as floats.
    _check_doprint(printer, [
        (x**-1, 'x**(-1.0)'),
        (x**-2, 'x**(-2.0)'),
        (Pow(2, -1, evaluate=False), "2**(-1.0)"),
    ])

    _check_doprint(printer, [
        (S.Exp1, 'numpy.e'),
        (S.Pi, 'numpy.pi'),
        (S.EulerGamma, 'numpy.euler_gamma'),
        (S.NaN, 'numpy.nan'),
        (S.Infinity, 'numpy.PINF'),
        (S.NegativeInfinity, 'numpy.NINF'),
    ])


def test_issue_18770():
    """Lambdified Min/Max must work both element-wise on arrays and on scalars."""
    numpy = import_module('numpy')
    if not numpy:
        skip("numpy not installed.")

    from sympy import lambdify, Min, Max

    min_func = lambdify(x, Min(0.1*x + 3, x + 1, 0.5*x + 1), "numpy")
    assert (min_func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5]).all()
    assert min_func(4) == 3

    max_func = lambdify(x, Max(x**2, x**3), "numpy")
    assert (max_func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8]).all()
    assert max_func(4) == 64


def test_SciPyPrinter():
    printer = SciPyPrinter()
    arccos_expr = acos(x)

    # numpy only becomes an import once something is printed through it,
    # and printing a plain numpy function must not drag in scipy.
    assert 'numpy' not in printer.module_imports
    assert printer.doprint(arccos_expr) == 'numpy.arccos(x)'
    assert 'numpy' in printer.module_imports
    assert not any(name.startswith('scipy') for name in printer.module_imports)

    sparse = SparseMatrix(2, 5, {(0, 1): 3})
    assert printer.doprint(sparse) == \
        'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))'
    assert 'scipy.sparse' in printer.module_imports

    _check_doprint(printer, [
        (S.GoldenRatio, 'scipy.constants.golden_ratio'),
        (S.Pi, 'scipy.constants.pi'),
        (S.Exp1, 'numpy.e'),
    ])


def test_pycode_reserved_words():
    kw_if, kw_else = symbols('if else')
    total = kw_if + kw_else
    raises(ValueError, lambda: pycode(total, error_on_reserved=True))
    # Without error_on_reserved, reserved names get a trailing underscore.
    assert pycode(total) in ('else_ + if_', 'if_ + else_')


def test_sqrt():
    _check_print_pow(PythonCodePrinter(), [
        (sqrt(x), False, 'math.sqrt(x)'),
        (1/sqrt(x), False, '1/math.sqrt(x)'),
    ])
    _check_print_pow(PythonCodePrinter({'standard': 'python2'}), [
        (sqrt(x), True, 'x**(1./2.)'),
        (1/sqrt(x), True, 'x**(-1./2.)'),
    ])
    _check_print_pow(PythonCodePrinter({'standard': 'python3'}), [
        (sqrt(x), True, 'x**(1/2)'),
        (1/sqrt(x), True, 'x**(-1/2)'),
    ])
    _check_print_pow(MpmathPrinter(), [
        (sqrt(x), False, 'mpmath.sqrt(x)'),
        (sqrt(x), True, "x**(mpmath.mpf(1)/mpmath.mpf(2))"),
    ])
    _check_print_pow(NumPyPrinter(), [
        (sqrt(x), False, 'numpy.sqrt(x)'),
        (sqrt(x), True, 'x**(1/2)'),
    ])
    _check_print_pow(SciPyPrinter(), [
        (sqrt(x), False, 'numpy.sqrt(x)'),
        (sqrt(x), True, 'x**(1/2)'),
    ])
    _check_print_pow(SymPyPrinter(), [
        (sqrt(x), False, 'sympy.sqrt(x)'),
        (sqrt(x), True, 'x**(1/2)'),
    ])


def test_frac():
    from sympy import frac

    expr = frac(x)
    for printer_cls, expected in (
        (NumPyPrinter, 'numpy.mod(x, 1)'),
        (SciPyPrinter, 'numpy.mod(x, 1)'),
        (PythonCodePrinter, 'x % 1'),
        (MpmathPrinter, 'mpmath.frac(x)'),
        (SymPyPrinter, 'sympy.functions.elementary.integers.frac(x)'),
    ):
        assert printer_cls().doprint(expr) == expected


class CustomPrintedObject(Expr):
    """Expression that supplies its own output for the NumPy and mpmath printers."""

    def _numpycode(self, printer):
        # Hook consulted by NumPyPrinter.
        return 'numpy'

    def _mpmathcode(self, printer):
        # Hook consulted by MpmathPrinter.
        return 'mpmath'


def test_printmethod():
    obj = CustomPrintedObject()
    for printer_cls, marker in ((NumPyPrinter, 'numpy'), (MpmathPrinter, 'mpmath')):
        assert printer_cls().doprint(obj) == marker


def test_codegen_ast_nodes():
    # The codegen ``none`` node prints as Python's None literal.
    assert pycode(none) == 'None'


def test_issue_14283():
    _check_doprint(PythonCodePrinter(), [
        (zoo, "float('nan')"),
        (-oo, "float('-inf')"),
    ])


def test_NumPyPrinter_print_seq():
    # Sequences are rendered as a tuple literal with a trailing comma.
    assert NumPyPrinter()._print_seq(range(2)) == '(0, 1,)'


def test_issue_16535_16536():
    from sympy import lowergamma, uppergamma

    a = symbols('a')
    lower = lowergamma(a, x)
    upper = uppergamma(a, x)

    _check_doprint(SciPyPrinter(), [
        (lower, 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)'),
        (upper, 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)'),
    ])
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        _assert_not_supported(printer_cls(), lower, upper)


def test_Integral():
    from sympy import Integral, exp

    single = Integral(exp(-x), (x, 0, oo))
    double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z))
    indefinite = Integral(x**2, x)
    evaluateat = Integral(x**2, (x, 1))

    expectations = (
        (SciPyPrinter,
         'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.PINF)[0]',
         'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]'),
        (MpmathPrinter,
         'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))',
         'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))'),
    )
    for printer_cls, single_code, double_code in expectations:
        printer = printer_cls()
        assert printer.doprint(single) == single_code
        assert printer.doprint(double) == double_code
        # Indefinite and single-limit integrals are expected to be rejected.
        raises(NotImplementedError, lambda: printer.doprint(indefinite))
        raises(NotImplementedError, lambda: printer.doprint(evaluateat))


def test_fresnel_integrals():
    from sympy import fresnelc, fresnels

    c_expr = fresnelc(x)
    s_expr = fresnels(x)

    _check_doprint(SciPyPrinter(), [
        (c_expr, 'scipy.special.fresnel(x)[1]'),
        (s_expr, 'scipy.special.fresnel(x)[0]'),
    ])
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        _assert_not_supported(printer_cls(), c_expr, s_expr)
    _check_doprint(MpmathPrinter(), [
        (c_expr, 'mpmath.fresnelc(x)'),
        (s_expr, 'mpmath.fresnels(x)'),
    ])


def test_beta():
    from sympy import beta

    expr = beta(x, y)
    gamma_form = 'math.gamma(x)*math.gamma(y)/math.gamma(x + y)'
    for make_printer, expected in (
        (SciPyPrinter, 'scipy.special.beta(x, y)'),
        (NumPyPrinter, gamma_form),
        (PythonCodePrinter, gamma_form),
        (lambda: PythonCodePrinter({'allow_unknown_functions': True}), gamma_form),
        (MpmathPrinter, 'mpmath.beta(x, y)'),
    ):
        assert make_printer().doprint(expr) == expected


def test_airy():
    from sympy import airyai, airybi

    ai = airyai(x)
    bi = airybi(x)

    _check_doprint(SciPyPrinter(), [
        (ai, 'scipy.special.airy(x)[0]'),
        (bi, 'scipy.special.airy(x)[2]'),
    ])
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        _assert_not_supported(printer_cls(), ai, bi)


def test_airy_prime():
    from sympy import airyaiprime, airybiprime

    ai_prime = airyaiprime(x)
    bi_prime = airybiprime(x)

    _check_doprint(SciPyPrinter(), [
        (ai_prime, 'scipy.special.airy(x)[1]'),
        (bi_prime, 'scipy.special.airy(x)[3]'),
    ])
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        _assert_not_supported(printer_cls(), ai_prime, bi_prime)


def test_numerical_accuracy_functions():
    _check_doprint(SciPyPrinter(), [
        (expm1(x), 'numpy.expm1(x)'),
        (log1p(x), 'numpy.log1p(x)'),
        (cosm1(x), 'scipy.special.cosm1(x)'),
    ])