1from sympy.external import import_module
2from sympy.testing.pytest import raises
3import ctypes
4
5
6if import_module('llvmlite'):
7    import sympy.printing.llvmjitcode as g
8else:
9    disabled = True
10
11import sympy
12from sympy.abc import a, b, n
13
14
def isclose(a, b, *, rtol=1e-5, atol=1e-8):
    """Return True if ``a`` is within tolerance of the reference value ``b``.

    This is the comparison given in the ``numpy.isclose`` documentation::

        abs(a - b) <= atol + rtol*abs(b)

    Unlike ``math.isclose`` the test is asymmetric: the relative tolerance
    scales with ``b`` only, so ``b`` should be the reference value. In this
    module ``b`` is the SymPy-evaluated result.

    Parameters
    ----------
    a, b : numbers
        Value under test and reference value.
    rtol : float, optional
        Relative tolerance (default 1e-5, matching numpy).
    atol : float, optional
        Absolute tolerance (default 1e-8, matching numpy).
    """
    return abs(a - b) <= atol + rtol*abs(b)
20
21
def test_simple_expr():
    # Evaluate a + 1.0 at a = 4.0 with SymPy and with the JIT and compare.
    expr = a + 1.0
    point = 4.0
    jitted = g.llvm_callable([a], expr)
    expected = float(expr.subs({a: point}).evalf())
    assert isclose(jitted(point), expected)
29
30
def test_two_arg():
    # Two-argument linear expression; arguments are passed positionally.
    expr = 4.0*a + b + 3.0
    args = (4.0, 3.0)
    jitted = g.llvm_callable([a, b], expr)
    expected = float(expr.subs(dict(zip((a, b), args))).evalf())
    assert isclose(jitted(*args), expected)
38
39
def test_func():
    # A single call to a math function (exp) inside the jitted expression.
    x_val = 1.5
    expr = 4.0*sympy.exp(-a)
    jitted = g.llvm_callable([a], expr)
    expected = float(expr.subs({a: x_val}).evalf())
    assert isclose(jitted(x_val), expected)
47
48
def test_two_func():
    # Two exp calls, each depending on a different argument.
    values = [1.5, 2.0]
    expr = 4.0*sympy.exp(-a) + sympy.exp(b)
    jitted = g.llvm_callable([a, b], expr)
    expected = float(expr.subs(dict(zip([a, b], values))).evalf())
    assert isclose(jitted(*values), expected)
56
57
def test_two_sqrt():
    # Two square roots, one per argument.
    expr = 4.0*sympy.sqrt(a) + sympy.sqrt(b)
    a_val, b_val = 1.5, 2.0
    jitted = g.llvm_callable([a, b], expr)
    expected = float(expr.subs({a: a_val, b: b_val}).evalf())
    assert isclose(jitted(a_val, b_val), expected)
65
66
def test_two_pow():
    # Powers with a float exponent (1.5) and an integer exponent (7).
    expr = a**1.5 + b**7
    a_val, b_val = 1.5, 2.0
    jitted = g.llvm_callable([a, b], expr)
    expected = float(expr.subs({a: a_val, b: b_val}).evalf())
    assert isclose(jitted(a_val, b_val), expected)
74
75
def test_callback():
    # The 'scipy.integrate.test' callback is called with an int count and a
    # ctypes double array holding the argument values.
    expr = a + 1.2
    cb = g.llvm_callable([a], expr, callback_type='scipy.integrate.test')
    point = {a: 2.2}
    n_args = ctypes.c_int(1)
    xx = (ctypes.c_double * 1)(point[a])
    got = cb(n_args, xx)

    want = float(expr.subs(point).evalf())
    assert isclose(got, want)
88
89
def test_callback_cubature():
    # The 'cubature' callback is invoked here as
    # f(ndim, x, None, fdim, out); the test expects it to write its result
    # into `out` and to return 0.
    e = a + 1.2
    f = g.llvm_callable([a], e, callback_type='cubature')

    # Number of input variables and number of output values. Both are 1 in
    # this test, but they are distinct callback arguments, so use separate
    # objects. This keeps the call correct if either count changes.
    ndim = ctypes.c_int(1)
    fdim = ctypes.c_int(1)

    array_type = ctypes.c_double * 1
    inp = {a: 2.2}
    array = array_type(inp[a])
    out_array = array_type(0.0)
    jit_ret = f(ndim, array, None, fdim, out_array)

    assert jit_ret == 0

    res = float(e.subs(inp).evalf())

    assert isclose(out_array[0], res)
105
106
def test_callback_two():
    # Two-variable expression through the 'scipy.integrate.test' callback;
    # values are packed into the double array in argument order.
    expr = 3*a*b
    cb = g.llvm_callable([a, b], expr, callback_type='scipy.integrate.test')
    point = {a: 0.2, b: 1.7}
    n_args = len(point)
    xx = (ctypes.c_double * n_args)(*point.values())
    got = cb(ctypes.c_int(n_args), xx)

    want = float(expr.subs(point).evalf())
    assert isclose(got, want)
119
120
def test_callback_alt_two():
    # Arguments are given as a symbol n plus an IndexedBase d instead of one
    # symbol per input; the callback is still invoked with an int and a
    # ctypes double array whose entries correspond to d[0], d[1].
    d = sympy.IndexedBase('d')
    expr = 3*d[0]*d[1]
    cb = g.llvm_callable([n, d], expr, callback_type='scipy.integrate.test')
    values = (0.2, 1.7)
    point = {d[i]: v for i, v in enumerate(values)}
    xx = (ctypes.c_double * 2)(*values)
    got = cb(ctypes.c_int(2), xx)

    want = float(expr.subs(point).evalf())
    assert isclose(got, want)
134
135
def test_multiple_statements():
    # Hand-built input in the same (replacements, reduced expressions) layout
    # that sympy.cse produces: the temporary b = 4.0*a, then b + 5.
    x_val = 1.5
    e = [[(b, 4.0*a)], [b + 5]]
    direct = g.llvm_callable([a], e)

    replacements, (final_expr,) = e
    tmp_value = replacements[0][1].subs({a: x_val})
    expected = float(final_expr.subs({b: tmp_value}).evalf())
    assert isclose(direct(x_val), expected)

    # The same statements compiled as a scipy.integrate-style callback.
    callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
    xx = (ctypes.c_double * 1)(x_val)
    assert isclose(callback(ctypes.c_int(1), xx), expected)
151
152
def test_cse():
    # Compile the output of sympy.cse and compare it with direct evaluation
    # of the original (un-eliminated) expression.
    expr = a*a + b*b + sympy.exp(-a*a - b*b)
    point = {a: 2.3, b: 0.1}
    jitted = g.llvm_callable([a, b], sympy.cse(expr))
    expected = float(expr.subs(point).evalf())
    assert isclose(jitted(point[a], point[b]), expected)
161
162
def eval_cse(e, sub_dict):
    """Substitute ``sub_dict`` into a CSE result and return the final values.

    Parameters
    ----------
    e : sequence
        ``(replacements, reduced_exprs)`` in the layout returned by
        ``sympy.cse``. Here ``replacements`` is a sequence of
        ``(tmp_symbol, tmp_expr)`` pairs.
    sub_dict : dict
        Values for the free symbols, e.g. ``{a: 0.2, b: 1.5}``.

    Returns
    -------
    list
        One substituted expression per entry of ``reduced_exprs``, in order.
    """
    replacements, reduced_exprs = e[0], e[1]
    tmp_values = {}
    for tmp_name, tmp_expr in replacements:
        # A temporary may refer to earlier temporaries. Substitute the input
        # values first, then the temporaries computed so far.
        tmp_values[tmp_name] = tmp_expr.subs(sub_dict).subs(tmp_values)
    # Use a distinct loop name rather than shadowing the parameter `e`.
    return [expr.subs(sub_dict).subs(tmp_values) for expr in reduced_exprs]
170
171
def test_cse_multiple():
    cse_result = sympy.cse([a*a, a*a + b*b])

    def build_scipy_callback():
        return g.llvm_callable([a, b], cse_result,
                               callback_type='scipy.integrate')

    # A multi-expression CSE result combined with the 'scipy.integrate'
    # callback type is expected to raise NotImplementedError.
    raises(NotImplementedError, build_scipy_callback)

    # XXX: The commented lines below lead to a segfault in Python 3.9 although
    # they work fine in Python 3.8. It is not sufficient to mark the test as
    # XFAIL because it crashes the test runner.

    #f = g.llvm_callable([a, b], cse_result)

    #jit_res = f(0.1, 1.5)
    #assert len(jit_res) == 2
    #res = eval_cse(cse_result, {a: 0.1, b: 1.5})
    #assert isclose(res[0], jit_res[0])
    #assert isclose(res[1], jit_res[1])
191
192
def test_callback_cubature_multiple():
    # Three outputs from a CSE result, compiled as a 'cubature' callback that
    # fills an output array of length fdim.
    first = a*a
    second = a*a + b*b
    cse_result = sympy.cse([first, second, 4*second])
    cb = g.llvm_callable([a, b], cse_result, callback_type='cubature')

    n_in, n_out = 2, 3  # input variables / output expression values
    point = {a: 0.2, b: 1.5}
    xx = (ctypes.c_double * n_in)(point[a], point[b])
    out = (ctypes.c_double * n_out)()
    status = cb(ctypes.c_int(n_in), xx, None, ctypes.c_int(n_out), out)

    assert status == 0

    expected = eval_cse(cse_result, point)
    for idx, value in enumerate(out):
        assert isclose(value, expected[idx])
220
221
def test_symbol_not_found():
    expr = a*a + b

    def build():
        return g.llvm_callable([a], expr)

    # `b` appears in the expression but is missing from the argument list.
    raises(LookupError, build)
225
226
def test_bad_callback():
    expr = a

    def build():
        return g.llvm_callable([a], expr, callback_type='bad_callback')

    # An unrecognised callback_type must be rejected with ValueError.
    raises(ValueError, build)
230