1from sympy.codegen import Assignment
2from sympy.codegen.ast import none
3from sympy.codegen.cfunctions import expm1, log1p
4from sympy.codegen.scipy_nodes import cosm1
5from sympy.codegen.matrix_nodes import MatrixSolve
6from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow
7from sympy.core.numbers import pi
8from sympy.core.singleton import S
9from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt
10from sympy.logic import And, Or
11from sympy.matrices import SparseMatrix, MatrixSymbol, Identity
12from sympy.printing.pycode import (
13    MpmathPrinter, PythonCodePrinter, pycode, SymPyPrinter
14)
15from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
16from sympy.testing.pytest import raises
17from sympy.tensor import IndexedBase
18from sympy.testing.pytest import skip
19from sympy.external import import_module
20from sympy.functions.special.gamma_functions import loggamma
21
# Symbols and the indexed base shared by the tests in this module.
x, y, z = symbols('x y z')
p = IndexedBase('p')
24
def test_PythonCodePrinter():
    """Check plain-Python output and the module imports the printer records."""
    printer = PythonCodePrinter()

    # A freshly created printer has not recorded any imports.
    assert not printer.module_imports

    # Expressions that map onto bare Python operators need no imports.
    operator_cases = [
        (x**y, 'x**y'),
        (Mod(x, 2), 'x % 2'),
        (And(x, y), 'x and y'),
        (Or(x, y), 'x or y'),
    ]
    for expression, expected in operator_cases:
        assert printer.doprint(expression) == expected
    assert not printer.module_imports

    # Printing pi registers ``math.pi`` as a required import.
    assert printer.doprint(pi) == 'math.pi'
    assert printer.module_imports == {'math': {'pi'}}

    # Both spellings of the square root are printed as math.sqrt.
    for root in (x**Rational(1, 2), sqrt(x)):
        assert printer.doprint(root) == 'math.sqrt(x)'
    assert printer.module_imports == {'math': {'pi', 'sqrt'}}

    assert printer.doprint(acos(x)) == 'math.acos(x)'
    assert printer.doprint(Assignment(x, 2)) == 'x = 2'

    # Piecewise becomes a chain of conditional expressions ending in None.
    evaluated_pw = Piecewise((1, Eq(x, 0)), (2, x > 6))
    assert printer.doprint(evaluated_pw) == \
        '((1) if (x == 0) else (2) if (x > 6) else None)'

    unevaluated_pw = Piecewise((2, Le(x, 0)), (3, Gt(x, 0)), evaluate=False)
    assert printer.doprint(unevaluated_pw) == (
        '((2) if (x <= 0) else'
        ' (3) if (x > 0) else None)'
    )

    remaining_cases = [
        (sign(x), '(0.0 if x == 0 else math.copysign(1, x))'),
        (p[0, 1], 'p[0, 1]'),
        (KroneckerDelta(x, y), '(1 if x == y else 0)'),
    ]
    for expression, expected in remaining_cases:
        assert printer.doprint(expression) == expected
53
54
def test_PythonCodePrinter_standard():
    """Check the default ``standard`` setting and rejection of unknown ones."""
    import sys

    printer = PythonCodePrinter({'standard': None})

    # With standard=None the printer is expected to follow the running
    # interpreter's major version; other major versions are not checked.
    expected_by_major = {2: 'python2', 3: 'python3'}
    major = sys.version_info.major
    if major in expected_by_major:
        assert printer.standard == expected_by_major[major]

    def build_with_unknown_standard():
        return PythonCodePrinter({'standard': 'python4'})

    raises(ValueError, build_with_unknown_standard)
66
def test_MpmathPrinter():
    """Check mpmath code for sign, rationals, constants and loggamma."""
    printer = MpmathPrinter()

    assert printer.doprint(sign(x)) == 'mpmath.sign(x)'
    assert printer.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)'

    # Singleton constants and the mpmath names they are printed as.
    constant_cases = [
        (S.Exp1, 'mpmath.e'),
        (S.Pi, 'mpmath.pi'),
        (S.GoldenRatio, 'mpmath.phi'),
        (S.EulerGamma, 'mpmath.euler'),
        (S.NaN, 'mpmath.nan'),
        (S.Infinity, 'mpmath.inf'),
        (S.NegativeInfinity, 'mpmath.ninf'),
    ]
    for constant, expected in constant_cases:
        assert printer.doprint(constant) == expected

    assert printer.doprint(loggamma(x)) == 'mpmath.loggamma(x)'
80
81
def test_NumPyPrinter():
    """Check NumPy code for matrix expressions, powers and constants."""
    from sympy import (Lambda, ZeroMatrix, OneMatrix, FunctionMatrix,
        HadamardProduct, KroneckerProduct, Adjoint, DiagonalOf,
        DiagMatrix, DiagonalMatrix)
    from sympy.abc import a, b

    printer = NumPyPrinter()
    assert printer.doprint(sign(x)) == 'numpy.sign(x)'

    mat_a = MatrixSymbol("A", 2, 2)
    mat_b = MatrixSymbol("B", 2, 2)
    row_c = MatrixSymbol("C", 1, 5)
    rect_d = MatrixSymbol("D", 3, 4)

    basic_matrix_cases = [
        (mat_a**(-1), "numpy.linalg.inv(A)"),
        (mat_a**5, "numpy.linalg.matrix_power(A, 5)"),
        (Identity(3), "numpy.eye(3)"),
    ]
    for expression, expected in basic_matrix_cases:
        assert printer.doprint(expression) == expected

    # Linear solves: the column vectors are named 'x' and 'y'.
    col_x = MatrixSymbol('x', 2, 1)
    col_y = MatrixSymbol('y', 2, 1)
    solve = MatrixSolve(mat_a, col_x)
    assert printer.doprint(solve) == 'numpy.linalg.solve(A, x)'
    assert printer.doprint(MatrixSolve(mat_a, col_x) + col_y) == \
        'numpy.linalg.solve(A, x) + y'

    special_matrix_cases = [
        (ZeroMatrix(2, 3), "numpy.zeros((2, 3))"),
        (OneMatrix(2, 3), "numpy.ones((2, 3))"),
        (FunctionMatrix(4, 5, Lambda((a, b), a + b)),
         "numpy.fromfunction(lambda a, b: a + b, (4, 5))"),
        (HadamardProduct(mat_a, mat_b), "numpy.multiply(A, B)"),
        (KroneckerProduct(mat_a, mat_b), "numpy.kron(A, B)"),
        (Adjoint(mat_a), "numpy.conjugate(numpy.transpose(A))"),
        (DiagonalOf(mat_a), "numpy.reshape(numpy.diag(A), (-1, 1))"),
        (DiagMatrix(row_c), "numpy.diagflat(C)"),
        (DiagonalMatrix(rect_d), "numpy.multiply(D, numpy.eye(3, 4))"),
    ]
    for expression, expected in special_matrix_cases:
        assert printer.doprint(expression) == expected

    # Workaround for numpy negative integer power errors
    assert printer.doprint(x**-1) == 'x**(-1.0)'
    assert printer.doprint(x**-2) == 'x**(-2.0)'

    unevaluated_half = Pow(2, -1, evaluate=False)
    assert printer.doprint(unevaluated_half) == "2**(-1.0)"

    constant_cases = [
        (S.Exp1, 'numpy.e'),
        (S.Pi, 'numpy.pi'),
        (S.EulerGamma, 'numpy.euler_gamma'),
        (S.NaN, 'numpy.nan'),
        (S.Infinity, 'numpy.PINF'),
        (S.NegativeInfinity, 'numpy.NINF'),
    ]
    for constant, expected in constant_cases:
        assert printer.doprint(constant) == expected
126
127
def test_issue_18770():
    """Check lambdified Min/Max with the numpy backend on arrays and scalars."""
    np = import_module('numpy')
    if not np:
        skip("numpy not installed.")

    from sympy import lambdify, Min, Max

    # Min over three linear expressions.
    smallest = Min(0.1*x + 3, x + 1, 0.5*x + 1)
    min_func = lambdify(x, smallest, "numpy")
    min_samples = np.linspace(0, 3, 3)
    assert (min_func(min_samples) == [1.0, 1.75, 2.5]).all()
    assert min_func(4) == 3

    # Max over two powers.
    largest = Max(x**2, x**3)
    max_func = lambdify(x, largest, "numpy")
    max_samples = np.linspace(-1, 2, 4)
    assert (max_func(max_samples) == [1, 0, 1, 8]).all()
    assert max_func(4) == 64
144
145
def test_SciPyPrinter():
    """Check SciPy printer output and which modules it records as imports."""
    printer = SciPyPrinter()
    arccos_expr = acos(x)

    # numpy is only registered once something numpy-based has been printed.
    assert 'numpy' not in printer.module_imports
    assert printer.doprint(arccos_expr) == 'numpy.arccos(x)'
    assert 'numpy' in printer.module_imports

    # No scipy module has been registered so far.
    scipy_modules = [m for m in printer.module_imports if m.startswith('scipy')]
    assert not scipy_modules

    sparse = SparseMatrix(2, 5, {(0, 1): 3})
    expected_sparse = 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))'
    assert printer.doprint(sparse) == expected_sparse
    assert 'scipy.sparse' in printer.module_imports

    constant_cases = [
        (S.GoldenRatio, 'scipy.constants.golden_ratio'),
        (S.Pi, 'scipy.constants.pi'),
        (S.Exp1, 'numpy.e'),
    ]
    for constant, expected in constant_cases:
        assert printer.doprint(constant) == expected
161
162
def test_pycode_reserved_words():
    """Check that symbols named after Python keywords are rejected or renamed."""
    kw_if, kw_else = symbols('if else')
    total = kw_if + kw_else

    def print_strictly():
        return pycode(kw_if + kw_else, error_on_reserved=True)

    raises(ValueError, print_strictly)

    # By default the names get a trailing underscore; either term order
    # is accepted.
    printed = pycode(total)
    assert printed in ('else_ + if_', 'if_ + else_')
168
169
def test_sqrt():
    """Check ``_print_Pow`` output for square roots across the printers."""
    root = sqrt(x)
    inverse_root = 1/sqrt(x)

    printer = PythonCodePrinter()
    assert printer._print_Pow(root, rational=False) == 'math.sqrt(x)'
    assert printer._print_Pow(inverse_root, rational=False) == '1/math.sqrt(x)'

    # With rational=True the exponent format depends on the target standard.
    rational_cases = [
        ('python2', 'x**(1./2.)', 'x**(-1./2.)'),
        ('python3', 'x**(1/2)', 'x**(-1/2)'),
    ]
    for standard, expected_root, expected_inverse in rational_cases:
        printer = PythonCodePrinter({'standard': standard})
        assert printer._print_Pow(root, rational=True) == expected_root
        assert printer._print_Pow(inverse_root, rational=True) == expected_inverse

    printer = MpmathPrinter()
    assert printer._print_Pow(root, rational=False) == 'mpmath.sqrt(x)'
    assert printer._print_Pow(root, rational=True) == \
        "x**(mpmath.mpf(1)/mpmath.mpf(2))"

    # These printers differ only in the function used for rational=False.
    library_cases = [
        (NumPyPrinter, 'numpy.sqrt(x)'),
        (SciPyPrinter, 'numpy.sqrt(x)'),
        (SymPyPrinter, 'sympy.sqrt(x)'),
    ]
    for printer_cls, expected_call in library_cases:
        printer = printer_cls()
        assert printer._print_Pow(root, rational=False) == expected_call
        assert printer._print_Pow(root, rational=True) == 'x**(1/2)'
199
200
def test_frac():
    """Check how each printer renders the fractional-part function."""
    from sympy import frac

    fractional = frac(x)

    expected_by_printer = [
        (NumPyPrinter, 'numpy.mod(x, 1)'),
        (SciPyPrinter, 'numpy.mod(x, 1)'),
        (PythonCodePrinter, 'x % 1'),
        (MpmathPrinter, 'mpmath.frac(x)'),
        (SymPyPrinter, 'sympy.functions.elementary.integers.frac(x)'),
    ]
    for printer_cls, expected in expected_by_printer:
        printer = printer_cls()
        assert printer.doprint(fractional) == expected
220
221
class CustomPrintedObject(Expr):
    """Expression that supplies its own code through printer hook methods."""

    def _mpmathcode(self, printer):
        """Return the fixed string ``'mpmath'`` regardless of *printer*."""
        return 'mpmath'

    def _numpycode(self, printer):
        """Return the fixed string ``'numpy'`` regardless of *printer*."""
        return 'numpy'
228
229
def test_printmethod():
    """Check that printers use the object's own ``_<name>code`` hooks."""
    custom = CustomPrintedObject()

    numpy_code = NumPyPrinter().doprint(custom)
    assert numpy_code == 'numpy'

    mpmath_code = MpmathPrinter().doprint(custom)
    assert mpmath_code == 'mpmath'
234
235
def test_codegen_ast_nodes():
    """Check that the codegen ``none`` node prints as Python's ``None``."""
    printed = pycode(none)
    assert printed == 'None'
238
239
def test_issue_14283():
    """Check the float literals printed for ``zoo`` and ``-oo``."""
    printer = PythonCodePrinter()

    expected_literals = [
        (zoo, "float('nan')"),
        (-oo, "float('-inf')"),
    ]
    for value, expected in expected_literals:
        assert printer.doprint(value) == expected
245
def test_NumPyPrinter_print_seq():
    """Check that ``_print_seq`` renders a tuple with a trailing comma."""
    printer = NumPyPrinter()
    rendered = printer._print_seq(range(2))

    assert rendered == '(0, 1,)'
250
251
def test_issue_16535_16536():
    """Check incomplete gamma functions: SciPy prints them, others do not."""
    from sympy import lowergamma, uppergamma

    a = symbols('a')
    lower = lowergamma(a, x)
    upper = uppergamma(a, x)

    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(lower) == \
        'scipy.special.gamma(a)*scipy.special.gammainc(a, x)'
    assert scipy_printer.doprint(upper) == \
        'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)'

    # The remaining printers flag both functions as unsupported.
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (lower, upper):
            assert "Not supported" in printer.doprint(expression)
270
271
def test_Integral():
    """Check definite-integral code and rejection of other integral forms."""
    from sympy import Integral, exp

    single = Integral(exp(-x), (x, 0, oo))
    double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z))
    indefinite = Integral(x**2, x)
    evaluateat = Integral(x**2, (x, 1))

    cases = [
        (SciPyPrinter,
         'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.PINF)[0]',
         'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]'),
        (MpmathPrinter,
         'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))',
         'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))'),
    ]
    for printer_cls, expected_single, expected_double in cases:
        printer = printer_cls()
        assert printer.doprint(single) == expected_single
        assert printer.doprint(double) == expected_double
        # Integrals without both limits cannot be printed numerically.
        raises(NotImplementedError, lambda: printer.doprint(indefinite))
        raises(NotImplementedError, lambda: printer.doprint(evaluateat))
291
292
def test_fresnel_integrals():
    """Check Fresnel integral output for each printer."""
    from sympy import fresnelc, fresnels

    cosine_integral = fresnelc(x)
    sine_integral = fresnels(x)

    # SciPy selects an element of the scipy.special.fresnel result.
    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(cosine_integral) == 'scipy.special.fresnel(x)[1]'
    assert scipy_printer.doprint(sine_integral) == 'scipy.special.fresnel(x)[0]'

    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (cosine_integral, sine_integral):
            assert "Not supported" in printer.doprint(expression)

    mpmath_printer = MpmathPrinter()
    assert mpmath_printer.doprint(cosine_integral) == 'mpmath.fresnelc(x)'
    assert mpmath_printer.doprint(sine_integral) == 'mpmath.fresnels(x)'
314
315
def test_beta():
    """Check beta-function output: native where available, else via gamma."""
    from sympy import beta

    beta_expr = beta(x, y)
    gamma_form = 'math.gamma(x)*math.gamma(y)/math.gamma(x + y)'

    cases = [
        (SciPyPrinter(), 'scipy.special.beta(x, y)'),
        (NumPyPrinter(), gamma_form),
        (PythonCodePrinter(), gamma_form),
        (PythonCodePrinter({'allow_unknown_functions': True}), gamma_form),
        (MpmathPrinter(), 'mpmath.beta(x, y)'),
    ]
    for printer, expected in cases:
        assert printer.doprint(beta_expr) == expected
335
def test_airy():
    """Check Airy Ai/Bi output: SciPy indexes airy(), others are unsupported."""
    from sympy import airyai, airybi

    ai = airyai(x)
    bi = airybi(x)

    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(ai) == 'scipy.special.airy(x)[0]'
    assert scipy_printer.doprint(bi) == 'scipy.special.airy(x)[2]'

    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (ai, bi):
            assert "Not supported" in printer.doprint(expression)
353
def test_airy_prime():
    """Check Airy derivative output: SciPy indexes airy(), others unsupported."""
    from sympy import airyaiprime, airybiprime

    ai_prime = airyaiprime(x)
    bi_prime = airybiprime(x)

    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(ai_prime) == 'scipy.special.airy(x)[1]'
    assert scipy_printer.doprint(bi_prime) == 'scipy.special.airy(x)[3]'

    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (ai_prime, bi_prime):
            assert "Not supported" in printer.doprint(expression)
371
372
def test_numerical_accuracy_functions():
    """Check SciPy printer output for expm1, log1p and cosm1."""
    printer = SciPyPrinter()

    cases = [
        (expm1(x), 'numpy.expm1(x)'),
        (log1p(x), 'numpy.log1p(x)'),
        (cosm1(x), 'scipy.special.cosm1(x)'),
    ]
    for expression, expected in cases:
        assert printer.doprint(expression) == expected
378