1import math
2from numba.core import types, utils
3from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
4                                         signature, Registry)
5
# Typing registry for this module. `builtin_attr` decorates AttributeTemplate
# subclasses (attribute typing); `infer_global` pairs a global value with the
# numba type it should be given.
registry = Registry()
builtin_attr, infer_global = registry.register_attr, registry.register_global
9
10
@builtin_attr
class MathModuleAttribute(AttributeTemplate):
    """Attribute typing for the ``math`` module object.

    Each ``resolve_<name>`` method gives the numba type of ``math.<name>``:
    a ``types.Function`` wrapping the matching ``Math_<name>`` template for
    functions, or ``types.float64`` for the constants ``pi`` and ``e``.
    The ``mod`` argument is not used by any resolver.
    """
    key = types.Module(math)

    # -- constants ---------------------------------------------------------

    def resolve_pi(self, mod):
        return types.float64

    def resolve_e(self, mod):
        return types.float64

    # -- magnitude, powers, exponentials and logarithms --------------------

    def resolve_fabs(self, mod):
        return types.Function(Math_fabs)

    def resolve_sqrt(self, mod):
        return types.Function(Math_sqrt)

    def resolve_pow(self, mod):
        return types.Function(Math_pow)

    def resolve_exp(self, mod):
        return types.Function(Math_exp)

    def resolve_expm1(self, mod):
        return types.Function(Math_expm1)

    def resolve_log(self, mod):
        return types.Function(Math_log)

    def resolve_log1p(self, mod):
        return types.Function(Math_log1p)

    def resolve_log10(self, mod):
        return types.Function(Math_log10)

    # -- trigonometric -----------------------------------------------------

    def resolve_sin(self, mod):
        return types.Function(Math_sin)

    def resolve_cos(self, mod):
        return types.Function(Math_cos)

    def resolve_tan(self, mod):
        return types.Function(Math_tan)

    def resolve_asin(self, mod):
        return types.Function(Math_asin)

    def resolve_acos(self, mod):
        return types.Function(Math_acos)

    def resolve_atan(self, mod):
        return types.Function(Math_atan)

    def resolve_atan2(self, mod):
        return types.Function(Math_atan2)

    # -- hyperbolic --------------------------------------------------------

    def resolve_sinh(self, mod):
        return types.Function(Math_sinh)

    def resolve_cosh(self, mod):
        return types.Function(Math_cosh)

    def resolve_tanh(self, mod):
        return types.Function(Math_tanh)

    def resolve_asinh(self, mod):
        return types.Function(Math_asinh)

    def resolve_acosh(self, mod):
        return types.Function(Math_acosh)

    def resolve_atanh(self, mod):
        return types.Function(Math_atanh)

    # -- rounding ----------------------------------------------------------

    def resolve_floor(self, mod):
        return types.Function(Math_floor)

    def resolve_ceil(self, mod):
        return types.Function(Math_ceil)

    def resolve_trunc(self, mod):
        return types.Function(Math_trunc)

    # -- classification ----------------------------------------------------

    def resolve_isnan(self, mod):
        return types.Function(Math_isnan)

    def resolve_isinf(self, mod):
        return types.Function(Math_isinf)

    # -- angle conversion --------------------------------------------------

    def resolve_degrees(self, mod):
        return types.Function(Math_degrees)

    def resolve_radians(self, mod):
        return types.Function(Math_radians)

    # -- two-argument float operations -------------------------------------

    def resolve_copysign(self, mod):
        return types.Function(Math_copysign)

    def resolve_fmod(self, mod):
        return types.Function(Math_fmod)

    # hypot is disabled (its template is commented out in this file):
    # def resolve_hypot(self, mod):
    #     return types.Function(Math_hypot)

    # -- special functions -------------------------------------------------

    def resolve_erf(self, mod):
        return types.Function(Math_erf)

    def resolve_erfc(self, mod):
        return types.Function(Math_erfc)

    def resolve_gamma(self, mod):
        return types.Function(Math_gamma)

    def resolve_lgamma(self, mod):
        return types.Function(Math_lgamma)
125
126
class Math_unary(ConcreteTemplate):
    """Shared signature set for one-argument math functions.

    int64 and uint64 arguments produce a float64 result; float32 and float64
    arguments produce a result of the same type. Subclasses set ``key``.
    """
    # Integer inputs first, then floats -- same order as before.
    cases = [signature(types.float64, argty)
             for argty in (types.int64, types.uint64)]
    cases += [signature(fty, fty) for fty in (types.float32, types.float64)]
134
135
class Math_fabs(Math_unary):
    """Typing template for ``math.fabs``; signatures come from Math_unary."""
    key = math.fabs
138
139
class Math_exp(Math_unary):
    """Typing template for ``math.exp``; signatures come from Math_unary."""
    key = math.exp
142
143
class Math_expm1(Math_unary):
    """Typing template for ``math.expm1``; signatures come from Math_unary."""
    key = math.expm1
146
147
class Math_sqrt(Math_unary):
    """Typing template for ``math.sqrt``; signatures come from Math_unary."""
    key = math.sqrt
150
151
class Math_log(Math_unary):
    """Typing template for one-argument ``math.log`` (Math_unary signatures).

    Only single-argument signatures are declared, so a two-argument
    ``math.log(x, base)`` call has no matching case here.
    """
    key = math.log
154
155
class Math_log1p(Math_unary):
    """Typing template for ``math.log1p``; signatures come from Math_unary."""
    key = math.log1p
158
159
class Math_log10(Math_unary):
    """Typing template for ``math.log10``; signatures come from Math_unary."""
    key = math.log10
162
163
class Math_sin(Math_unary):
    """Typing template for ``math.sin``; signatures come from Math_unary."""
    key = math.sin
166
167
class Math_cos(Math_unary):
    """Typing template for ``math.cos``; signatures come from Math_unary."""
    key = math.cos
170
171
class Math_tan(Math_unary):
    """Typing template for ``math.tan``; signatures come from Math_unary."""
    key = math.tan
174
175
class Math_sinh(Math_unary):
    """Typing template for ``math.sinh``; signatures come from Math_unary."""
    key = math.sinh
178
179
class Math_cosh(Math_unary):
    """Typing template for ``math.cosh``; signatures come from Math_unary."""
    key = math.cosh
182
183
class Math_tanh(Math_unary):
    """Typing template for ``math.tanh``; signatures come from Math_unary."""
    key = math.tanh
186
187
class Math_asin(Math_unary):
    """Typing template for ``math.asin``; signatures come from Math_unary."""
    key = math.asin
190
191
class Math_acos(Math_unary):
    """Typing template for ``math.acos``; signatures come from Math_unary."""
    key = math.acos
194
195
class Math_atan(Math_unary):
    """Typing template for ``math.atan``; signatures come from Math_unary."""
    key = math.atan
198
199
class Math_atan2(ConcreteTemplate):
    """Typing template for ``math.atan2(y, x)``.

    Two int64 or two uint64 arguments give float64; two float32 or two
    float64 arguments give a result of that same float type.
    """
    key = math.atan2
    # Integer pairs first, then float pairs -- same order as before.
    cases = [signature(types.float64, ity, ity)
             for ity in (types.int64, types.uint64)]
    cases += [signature(fty, fty, fty) for fty in (types.float32, types.float64)]
208
209
class Math_asinh(Math_unary):
    """Typing template for ``math.asinh``; signatures come from Math_unary."""
    key = math.asinh
212
213
class Math_acosh(Math_unary):
    """Typing template for ``math.acosh``; signatures come from Math_unary."""
    key = math.acosh
216
217
class Math_atanh(Math_unary):
    """Typing template for ``math.atanh``; signatures come from Math_unary."""
    key = math.atanh
220
221
class Math_floor(Math_unary):
    """Typing template for ``math.floor`` using the Math_unary signatures.

    Every declared result type is floating point (float32/float64), whereas
    CPython's ``math.floor`` returns an ``int``.
    """
    key = math.floor
224
225
class Math_ceil(Math_unary):
    """Typing template for ``math.ceil`` using the Math_unary signatures.

    Every declared result type is floating point (float32/float64), whereas
    CPython's ``math.ceil`` returns an ``int``.
    """
    key = math.ceil
228
229
class Math_trunc(Math_unary):
    """Typing template for ``math.trunc`` using the Math_unary signatures.

    Every declared result type is floating point (float32/float64), whereas
    CPython's ``math.trunc`` returns an ``int``.
    """
    key = math.trunc
232
233
class Math_radians(Math_unary):
    """Typing template for ``math.radians``; signatures come from Math_unary."""
    key = math.radians
236
237
class Math_degrees(Math_unary):
    """Typing template for ``math.degrees``; signatures come from Math_unary."""
    key = math.degrees
240
241
# NOTE(review): math.hypot typing is disabled; the reason is not recorded here.
# class Math_hypot(ConcreteTemplate):
#     key = math.hypot
#     cases = [
#         signature(types.float64, types.int64, types.int64),
#         signature(types.float64, types.uint64, types.uint64),
#         signature(types.float32, types.float32, types.float32),
#         signature(types.float64, types.float64, types.float64),
#     ]
250
251
class Math_erf(Math_unary):
    """Typing template for ``math.erf``; signatures come from Math_unary."""
    key = math.erf
254
class Math_erfc(Math_unary):
    """Typing template for ``math.erfc``; signatures come from Math_unary."""
    key = math.erfc
257
class Math_gamma(Math_unary):
    """Typing template for ``math.gamma``; signatures come from Math_unary."""
    key = math.gamma
260
class Math_lgamma(Math_unary):
    """Typing template for ``math.lgamma``; signatures come from Math_unary."""
    key = math.lgamma
263
264
class Math_binary(ConcreteTemplate):
    """Shared signature set for two-argument float-only math functions.

    Both arguments and the result share one type: float32 or float64.
    No integer signatures are declared. Subclasses set ``key``.
    """
    cases = [signature(fty, fty, fty) for fty in (types.float32, types.float64)]
270
271
class Math_copysign(Math_binary):
    """Typing template for ``math.copysign``; signatures from Math_binary."""
    key = math.copysign
274
275
class Math_fmod(Math_binary):
    """Typing template for ``math.fmod``; signatures from Math_binary."""
    key = math.fmod
278
279
class Math_pow(ConcreteTemplate):
    """Typing template for ``math.pow(x, y)``.

    The base and result are float32 or float64. The exponent is either the
    same float type or an int32.
    """
    key = math.pow
    _float_types = (types.float32, types.float64)
    # Same-type float exponents first, then int32 exponents -- original order.
    cases = [signature(fty, fty, fty) for fty in _float_types]
    cases += [signature(fty, fty, types.int32) for fty in _float_types]
    del _float_types
288
289
class Math_isnan(ConcreteTemplate):
    """Typing template for ``math.isnan``.

    Returns boolean for an int64, uint64, float32 or float64 argument.
    """
    key = math.isnan
    cases = [signature(types.boolean, argty)
             for argty in (types.int64, types.uint64,
                           types.float32, types.float64)]
298
299
class Math_isinf(ConcreteTemplate):
    """Typing template for ``math.isinf``.

    Returns boolean for an int64, uint64, float32 or float64 argument.
    """
    key = math.isinf
    cases = [signature(types.boolean, argty)
             for argty in (types.int64, types.uint64,
                           types.float32, types.float64)]
308
309
# Register the math module object, then every supported math function.
# Each template's `key` is the exact math function it types, so it is used
# directly as the registered value. The tuple order matches the historical
# registration order.
infer_global(math, types.Module(math))

for _template in (
    Math_fabs, Math_exp, Math_expm1, Math_sqrt, Math_log, Math_log1p,
    Math_log10, Math_sin, Math_cos, Math_tan, Math_sinh, Math_cosh,
    Math_tanh, Math_asin, Math_acos, Math_atan, Math_atan2, Math_asinh,
    Math_acosh, Math_atanh,
    # Math_hypot,  # not registered: its template is commented out
    Math_floor, Math_ceil, Math_trunc, Math_isnan, Math_isinf,
    Math_degrees, Math_radians, Math_copysign, Math_fmod, Math_pow,
    Math_erf, Math_erfc, Math_gamma, Math_lgamma,
):
    infer_global(_template.key, types.Function(_template))
del _template  # keep the module namespace free of the loop variable
346