"""Accuracy tests for ``math`` functions compiled for the ROC target.

Each test compiles a small kernel with ``roc.jit``. The kernel applies a
``math`` function element-wise, with one work-item per element, and its
output is compared against a NumPy reference using
``np.testing.assert_allclose``.
"""
import numpy as np
import math

import unittest
from numba import roc
from numba.core import utils


class TestMath(unittest.TestCase):
    def _get_tol(self, math_fn, ty):
        """Return the relative tolerance for ``math_fn`` evaluated on ``ty``.

        A handful of float64 functions have hand-tuned, looser tolerances.
        Everything else falls back to 1e-15 for float64 and 1e-6 for any
        other dtype.
        """
        key = (math_fn, ty)
        # Per-function relaxations; only float64 entries exist.
        overrides = {
            (math.gamma, np.float64): 1e-14,
            (math.lgamma, np.float64): 1e-13,
            (math.asin, np.float64): 1e-9,
            (math.acos, np.float64): 4e-9,
            (math.sqrt, np.float64): 2e-8,
        }
        if key in overrides:
            return overrides[key]
        return 1e-15 if ty == np.float64 else 1e-6

    def _generic_test_unary(self, math_fn, npy_fn,
                            cases=None,
                            span=(-1., 1.), count=128,
                            types=(np.float32, np.float64)):
        """Compare a one-argument ``math_fn`` on the device with ``npy_fn``.

        For each dtype in ``types``, the input is ``cases`` when given.
        Otherwise it is ``count`` evenly spaced points from ``span[0]`` to
        ``span[1]``.
        """

        @roc.jit
        def kernel(out, inp):
            gid = roc.get_global_id(0)
            # Guard against work-items beyond the end of the array.
            if gid < out.size:
                out[gid] = math_fn(inp[gid])

        for dtype in types:
            if cases is not None:
                src = np.array(cases, dtype=dtype)
            else:
                src = np.linspace(span[0], span[1], count).astype(dtype)

            dst = np.zeros_like(src)
            # src.size work-groups of one work-item each.
            kernel[src.size, 1](dst, src)
            expected = npy_fn(src)
            np.testing.assert_allclose(
                dst, expected,
                rtol=self._get_tol(math_fn, dtype),
                err_msg='{0} ({1})'.format(math_fn.__name__,
                                           dtype.__name__))

    def _generic_test_binary(self, math_fn, npy_fn,
                             cases=None,
                             span=(-1., 1., 1., -1.), count=128,
                             types=(np.float32, np.float64)):
        """Compare a two-argument ``math_fn`` on the device with ``npy_fn``.

        When ``cases`` is given, ``cases[0]`` and ``cases[1]`` are the two
        operand arrays. Otherwise:

        - the first operand spans ``span[0]``..``span[1]``;
        - the second operand spans ``span[2]``..``span[3]``;
        - each operand has ``count`` points.
        """

        @roc.jit
        def kernel(out, lhs, rhs):
            gid = roc.get_global_id(0)
            # Guard against work-items beyond the end of the array.
            if gid < out.size:
                out[gid] = math_fn(lhs[gid], rhs[gid])

        for dtype in types:
            if cases is not None:
                src1 = np.array(cases[0], dtype=dtype)
                src2 = np.array(cases[1], dtype=dtype)
            else:
                src1 = np.linspace(span[0], span[1], count).astype(dtype)
                src2 = np.linspace(span[2], span[3], count).astype(dtype)

            dst = np.zeros_like(src1)
            kernel[dst.size, 1](dst, src1, src2)
            expected = npy_fn(src1, src2)
            np.testing.assert_allclose(
                dst, expected,
                rtol=self._get_tol(math_fn, dtype),
                err_msg='{0} ({1})'.format(math_fn.__name__,
                                           dtype.__name__))

    def test_trig(self):
        for fn in (math.sin, math.cos, math.tan):
            self._generic_test_unary(fn, getattr(np, fn.__name__),
                                     span=(-np.pi, np.pi))

    def test_trig_inv(self):
        pairs = (
            (math.asin, np.arcsin),
            (math.acos, np.arccos),
            (math.atan, np.arctan),
        )
        for fn, reference in pairs:
            self._generic_test_unary(fn, reference)

    def test_trigh(self):
        for fn in (math.sinh, math.cosh, math.tanh):
            self._generic_test_unary(fn, getattr(np, fn.__name__),
                                     span=(-4.0, 4.0))

    def test_trigh_inv(self):
        # Each entry: (math function, NumPy reference, input range).
        table = (
            (math.asinh, np.arcsinh, (-4, 4)),
            (math.acosh, np.arccosh, (1, 9)),
            (math.atanh, np.arctanh, (-0.9, 0.9)),
        )
        for fn, reference, interval in table:
            self._generic_test_unary(fn, reference, span=interval)

    def test_classify(self):
        special = (float('nan'), float('inf'), float('-inf'), float('-nan'),
                   0, 3, -2)
        for fn in (math.isnan, math.isinf):
            self._generic_test_unary(fn, getattr(np, fn.__name__),
                                     cases=special)

    def test_floor_ceil(self):
        for fn in (math.ceil, math.floor):
            reference = getattr(np, fn.__name__)
            # Inputs with fractional parts.
            self._generic_test_unary(fn, reference,
                                     span=(-1013.14, 843.21))
            # An odd count over a symmetric range includes exact integers.
            self._generic_test_unary(fn, reference,
                                     span=(-16, 16), count=129)

    def test_fabs(self):
        self._generic_test_unary(math.fabs, np.fabs, span=(-63.3, 63.3))

    def test_unary_exp(self):
        self._generic_test_unary(math.exp, np.exp, span=(-30, 30))

    def test_unary_expm1(self):
        self._generic_test_unary(math.expm1, np.expm1, span=(-30, 30))

    def test_sqrt(self):
        self._generic_test_unary(math.sqrt, np.sqrt, span=(0, 1000))

    def test_log(self):
        for fn in (math.log, math.log10, math.log1p):
            self._generic_test_unary(fn, getattr(np, fn.__name__),
                                     span=(0.1, 2500))

    def test_binaries(self):
        for fn in (math.copysign, math.fmod):
            self._generic_test_binary(fn, getattr(np, fn.__name__))

    def test_pow(self):
        self._generic_test_binary(math.pow, np.power)

    def test_atan2(self):
        self._generic_test_binary(math.atan2, np.arctan2)

    def test_erf(self):
        # NumPy has no erf/erfc, so the host math function is vectorized
        # to serve as the reference.
        for fn in (math.erf, math.erfc):
            self._generic_test_unary(fn, np.vectorize(fn))

    def test_gamma(self):
        for fn in (math.gamma, math.lgamma):
            self._generic_test_unary(fn, np.vectorize(fn),
                                     span=(1e-4, 4.0))


if __name__ == '__main__':
    unittest.main()