1import numpy as np
2import math
3
4import unittest
5from numba import roc
6from numba.core import utils
7
8
9class TestMath(unittest.TestCase):
10    def _get_tol(self, math_fn, ty):
11        """gets the tolerance for functions when the input is of type 'ty'"""
12
13        low_res = {
14            (math.gamma, np.float64): 1e-14,
15            (math.lgamma, np.float64): 1e-13,
16            (math.asin, np.float64): 1e-9,
17            (math.acos, np.float64): 4e-9,
18            (math.sqrt, np.float64): 2e-8,
19        }
20        default = 1e-15 if ty == np.float64 else 1e-6
21        return low_res.get((math_fn, ty), default)
22
23    def _generic_test_unary(self, math_fn, npy_fn,
24                            cases=None,
25                            span=(-1., 1.), count=128,
26                            types=(np.float32, np.float64)):
27
28        @roc.jit
29        def fn(dst, src):
30            i = roc.get_global_id(0)
31            if i < dst.size:
32                dst[i] = math_fn(src[i])
33
34        for dtype in types:
35            if cases is None:
36                src = np.linspace(span[0], span[1], count).astype(dtype)
37            else:
38                src = np.array(cases, dtype=dtype)
39
40            dst = np.zeros_like(src)
41            fn[src.size, 1](dst, src)
42            np.testing.assert_allclose(dst, npy_fn(src),
43                                       rtol=self._get_tol(math_fn, dtype),
44                                       err_msg='{0} ({1})'.format(
45                                           math_fn.__name__,
46                                           dtype.__name__))
47
48    def _generic_test_binary(self, math_fn, npy_fn,
49                             cases=None,
50                             span=(-1., 1., 1., -1.), count=128,
51                             types=(np.float32, np.float64)):
52
53        @roc.jit
54        def fn(dst, src1, src2):
55            i = roc.get_global_id(0)
56            if i < dst.size:
57                dst[i] = math_fn(src1[i], src2[i])
58
59        for dtype in types:
60            if cases is None:
61                src1 = np.linspace(span[0], span[1], count).astype(dtype)
62                src2 = np.linspace(span[2], span[3], count).astype(dtype)
63            else:
64                src1 = np.array(cases[0], dtype=dtype)
65                src2 = np.array(cases[1], dtype=dtype)
66
67            dst = np.zeros_like(src1)
68            fn[dst.size, 1](dst, src1, src2)
69            np.testing.assert_allclose(dst, npy_fn(src1, src2),
70                                       rtol=self._get_tol(math_fn, dtype),
71                                       err_msg='{0} ({1})'.format(
72                                           math_fn.__name__,
73                                           dtype.__name__))
74
75    def test_trig(self):
76        funcs = [math.sin, math.cos, math.tan]
77
78        for fn in funcs:
79            self._generic_test_unary(fn, getattr(np, fn.__name__),
80                                     span=(-np.pi, np.pi))
81
82    def test_trig_inv(self):
83        funcs = [(math.asin, np.arcsin),
84                 (math.acos, np.arccos),
85                 (math.atan, np.arctan)]
86
87        for fn, np_fn in funcs:
88            self._generic_test_unary(fn, np_fn)
89
90    def test_trigh(self):
91        funcs = [math.sinh, math.cosh, math.tanh]
92        for fn in funcs:
93            self._generic_test_unary(fn, getattr(np, fn.__name__),
94                                     span=(-4.0, 4.0))
95
96    def test_trigh_inv(self):
97        funcs = [(math.asinh, np.arcsinh, (-4, 4)),
98                 (math.acosh, np.arccosh, (1, 9)),
99                 (math.atanh, np.arctanh, (-0.9, 0.9))]
100
101        for fn, np_fn, span in funcs:
102            self._generic_test_unary(fn, np_fn, span=span)
103
104    def test_classify(self):
105        funcs = [math.isnan, math.isinf]
106        cases = (float('nan'), float('inf'), float('-inf'), float('-nan'),
107                 0, 3, -2)
108        for fn in funcs:
109            self._generic_test_unary(fn, getattr(np, fn.__name__),
110                                     cases=cases)
111
112    def test_floor_ceil(self):
113        funcs = [math.ceil, math.floor]
114
115        for fn in funcs:
116            # cases with varied decimals
117            self._generic_test_unary(fn, getattr(np, fn.__name__),
118                                     span=(-1013.14, 843.21))
119            # cases that include "exact" integers
120            self._generic_test_unary(fn, getattr(np, fn.__name__),
121                                     span=(-16, 16), count=129)
122
123    def test_fabs(self):
124        funcs = [math.fabs]
125        for fn in funcs:
126            self._generic_test_unary(fn, getattr(np, fn.__name__),
127                                     span=(-63.3, 63.3))
128
129    def test_unary_exp(self):
130        funcs = [math.exp]
131        for fn in funcs:
132            self._generic_test_unary(fn, getattr(np, fn.__name__),
133                                     span=(-30, 30))
134
135    def test_unary_expm1(self):
136        funcs = [math.expm1]
137        for fn in funcs:
138            self._generic_test_unary(fn, getattr(np, fn.__name__),
139                                     span=(-30, 30))
140
141    def test_sqrt(self):
142        funcs = [math.sqrt]
143        for fn in funcs:
144            self._generic_test_unary(fn, getattr(np, fn.__name__),
145                                     span=(0, 1000))
146
147    def test_log(self):
148        funcs = [math.log, math.log10, math.log1p]
149        for fn in funcs:
150            self._generic_test_unary(fn, getattr(np, fn.__name__),
151                                     span=(0.1, 2500))
152
153    def test_binaries(self):
154        funcs = [math.copysign, math.fmod]
155        for fn in funcs:
156            self._generic_test_binary(fn, getattr(np, fn.__name__))
157
158    def test_pow(self):
159        funcs = [(math.pow, np.power)]
160        for fn, npy_fn in funcs:
161            self._generic_test_binary(fn, npy_fn)
162
163    def test_atan2(self):
164        funcs = [(math.atan2, np.arctan2)]
165        for fn, npy_fn in funcs:
166            self._generic_test_binary(fn, npy_fn)
167
168    def test_erf(self):
169        funcs = [math.erf, math.erfc]
170        for fn in funcs:
171            self._generic_test_unary(fn, np.vectorize(fn))
172
173    def test_gamma(self):
174        funcs = [math.gamma, math.lgamma]
175        for fn in funcs:
176            self._generic_test_unary(fn, np.vectorize(fn), span=(1e-4, 4.0))
177
178
# Hand control to unittest's command-line runner when this file is executed
# directly as a script.
if __name__ == '__main__':
    unittest.main()
181