import os
import re
import functools
import operator

import numpy as np
from numpy.testing import assert_
import pytest

import scipy.special as sc

try:
    # distutils is deprecated (PEP 632) and was removed from the standard
    # library in Python 3.12.  The name is still re-exported where it exists
    # for backward compatibility, but this module no longer relies on it.
    from distutils.version import LooseVersion
except ImportError:
    LooseVersion = None

__all__ = ['with_special_errors', 'assert_func_equal', 'FuncData']


#------------------------------------------------------------------------------
# Check if a module is present to be used in tests
#------------------------------------------------------------------------------

class MissingModule:
    """Placeholder standing in for a module that could not be imported."""

    def __init__(self, name):
        # Name of the missing module, used in the skip reason.
        self.name = name


# Tokenizer with the same pattern as distutils.version.LooseVersion.component_re.
_LOOSE_VERSION_COMPONENT_RE = re.compile(r'(\d+ | [a-z]+ | \.)', re.VERBOSE)


def _loose_version_key(vstring):
    """
    Split a version string into a list that orders like ``LooseVersion``.

    Runs of digits become ints; runs of lowercase letters and any other
    text stay strings; dots and empty pieces are dropped.  This mirrors
    the parsing of ``distutils.version.LooseVersion`` so comparisons keep
    their historical meaning without depending on distutils.
    """
    key = []
    for part in _LOOSE_VERSION_COMPONENT_RE.split(str(vstring)):
        if not part or part == '.':
            continue
        try:
            key.append(int(part))
        except ValueError:
            key.append(part)
    return key


def check_version(module, min_ver):
    """
    Return a pytest mark that skips a test unless `module` is new enough.

    Parameters
    ----------
    module : module or MissingModule
        The imported module, or a `MissingModule` placeholder if the
        import failed.
    min_ver : str
        Minimum required version, compared against ``module.__version__``.

    Returns
    -------
    pytest.MarkDecorator
        ``pytest.mark.skip`` if the module is missing, otherwise a
        ``pytest.mark.skipif`` that triggers when the installed version
        is older than `min_ver`.
    """
    if isinstance(module, MissingModule):
        return pytest.mark.skip(reason="{} is not installed".format(module.name))
    too_old = _loose_version_key(module.__version__) < _loose_version_key(min_ver)
    return pytest.mark.skipif(too_old,
                              reason="{} version >= {} required".format(module.__name__, min_ver))


#------------------------------------------------------------------------------
# Enable convergence and loss of precision warnings -- turn off one by one
#------------------------------------------------------------------------------

def with_special_errors(func):
    """
    Enable special function errors (such as underflow, overflow,
    loss of precision, etc.)
    """
    @functools.wraps(func)
    def wrapper(*a, **kw):
        with sc.errstate(all='raise'):
            res = func(*a, **kw)
            return res
    return wrapper


#------------------------------------------------------------------------------
# Comparing function values at many data points at once, with helpful
# error reports
#------------------------------------------------------------------------------

def assert_func_equal(func, results, points, rtol=None, atol=None,
                      param_filter=None, knownfailure=None,
                      vectorized=True, dtype=None, nan_ok=False,
                      ignore_inf_sign=False, distinguish_nan_and_inf=True):
    """
    Check `func` against reference values at the given points.

    Parameters
    ----------
    func : callable
        Function under test.
    results : callable or array_like
        Either a reference function evaluated at the same points, or the
        expected values (one column per output).
    points : array_like or iterator
        Evaluation points.  A 1-D input is treated as a single parameter
        column.  Iterators/generators are materialized into a list first.
    dtype : optional
        Currently unused; accepted for backward compatibility.

    The remaining parameters are forwarded to `FuncData`.

    Raises
    ------
    AssertionError
        If any point fails the comparison.
    """
    if hasattr(points, '__next__') or hasattr(points, 'next'):
        # It's an iterator/generator (Python 3 exposes ``__next__``; the
        # ``next`` check is kept for backward compatibility).
        points = list(points)

    points = np.asarray(points)
    if points.ndim == 1:
        points = points[:, None]
    nparams = points.shape[1]

    if callable(results):
        # Reference function (any callable, including functools.partial)
        data = points
        result_columns = None
        result_func = results
    else:
        # Dataset: expected values are appended as extra columns
        data = np.c_[points, results]
        result_columns = list(range(nparams, data.shape[1]))
        result_func = None

    fdata = FuncData(func, data, list(range(nparams)),
                     result_columns=result_columns, result_func=result_func,
                     rtol=rtol, atol=atol, param_filter=param_filter,
                     knownfailure=knownfailure, nan_ok=nan_ok, vectorized=vectorized,
                     ignore_inf_sign=ignore_inf_sign,
                     distinguish_nan_and_inf=distinguish_nan_and_inf)
    fdata.check()


class FuncData:
    """
    Data set for checking a special function.

    Parameters
    ----------
    func : function
        Function to test
    data : numpy array
        columnar data to use for testing
    param_columns : int or tuple of ints
        Column indices in which the parameters to `func` lie.
        Can be imaginary integers to indicate that the parameter
        should be cast to complex.
    result_columns : int or tuple of ints, optional
        Column indices for expected results from `func`.
    result_func : callable, optional
        Function to call to obtain results.
    rtol : float, optional
        Required relative tolerance. Default is 5*eps.
    atol : float, optional
        Required absolute tolerance. Default is 5*tiny.
    param_filter : function, or tuple of functions/Nones, optional
        Filter functions to exclude some parameter ranges.
        If omitted, no filtering is done.
    knownfailure : str, optional
        Known failure message. If given, the check is marked as an
        expected failure (``pytest.xfail``) with this message.
        If omitted, the check runs normally.
    dataname : str, optional
        Name of the data set (e.g. a file path); its basename is shown
        in the ``repr``.
    nan_ok : bool, optional
        If nan is always an accepted result.
    vectorized : bool, optional
        Whether all functions passed in are vectorized.
    ignore_inf_sign : bool, optional
        Whether to ignore signs of infinities.
        (Doesn't matter for complex-valued functions.)
    distinguish_nan_and_inf : bool, optional
        If False, treat numbers which contain nans or infs as
        equal. Setting it to False also sets ignore_inf_sign to True.

    """

    def __init__(self, func, data, param_columns, result_columns=None,
                 result_func=None, rtol=None, atol=None, param_filter=None,
                 knownfailure=None, dataname=None, nan_ok=False, vectorized=True,
                 ignore_inf_sign=False, distinguish_nan_and_inf=True):
        self.func = func
        self.data = data
        self.dataname = dataname
        if not hasattr(param_columns, '__len__'):
            param_columns = (param_columns,)
        self.param_columns = tuple(param_columns)
        if result_columns is not None:
            if not hasattr(result_columns, '__len__'):
                result_columns = (result_columns,)
            self.result_columns = tuple(result_columns)
            if result_func is not None:
                raise ValueError("Only result_func or result_columns should be provided")
        elif result_func is not None:
            self.result_columns = None
        else:
            raise ValueError("Either result_func or result_columns should be provided")
        self.result_func = result_func
        self.rtol = rtol
        self.atol = atol
        if not hasattr(param_filter, '__len__'):
            param_filter = (param_filter,)
        self.param_filter = param_filter
        self.knownfailure = knownfailure
        self.nan_ok = nan_ok
        self.vectorized = vectorized
        self.ignore_inf_sign = ignore_inf_sign
        self.distinguish_nan_and_inf = distinguish_nan_and_inf
        if not self.distinguish_nan_and_inf:
            # Comparing nan with inf only makes sense sign-agnostically.
            self.ignore_inf_sign = True

    def get_tolerances(self, dtype):
        """
        Return ``(rtol, atol)`` for comparing values of `dtype`.

        Non-inexact dtypes use float64 limits; unset tolerances default to
        ``5*eps`` and ``5*tiny`` of that floating type.
        """
        if not np.issubdtype(dtype, np.inexact):
            dtype = np.dtype(float)
        info = np.finfo(dtype)
        rtol, atol = self.rtol, self.atol
        if rtol is None:
            rtol = 5*info.eps
        if atol is None:
            atol = 5*info.tiny
        return rtol, atol

    def check(self, data=None, dtype=None, dtypes=None):
        """
        Check the special function against the data.

        Parameters
        ----------
        data : numpy array, optional
            Data to use instead of ``self.data``.
        dtype : dtype, optional
            If given, cast the data to this dtype; it also selects the
            default tolerances.
        dtypes : sequence of dtypes, optional
            Per-parameter dtypes for the non-complex parameter columns.

        Raises
        ------
        AssertionError
            If any output differs from the expected values.
        """
        __tracebackhide__ = operator.methodcaller(
            'errisinstance', AssertionError
        )

        if self.knownfailure:
            pytest.xfail(reason=self.knownfailure)

        if data is None:
            data = self.data

        if dtype is None:
            dtype = data.dtype
        else:
            data = data.astype(dtype)

        rtol, atol = self.get_tolerances(dtype)

        # Apply given filter functions
        if self.param_filter:
            param_mask = np.ones((data.shape[0],), np.bool_)
            for j, filt in zip(self.param_columns, self.param_filter):
                if filt:
                    param_mask &= list(filt(data[:, j]))
            data = data[param_mask]

        # Pick parameters from the correct columns
        params = []
        for idx, j in enumerate(self.param_columns):
            if np.iscomplexobj(j):
                # Imaginary column index => cast the parameter to complex
                j = int(j.imag)
                params.append(data[:, j].astype(complex))
            elif dtypes and idx < len(dtypes):
                params.append(data[:, j].astype(dtypes[idx]))
            else:
                params.append(data[:, j])

        # Helper for evaluating results
        def eval_func_at_params(func, skip_mask=None):
            if self.vectorized:
                got = func(*params)
            else:
                got = []
                for j in range(len(params[0])):
                    if skip_mask is not None and skip_mask[j]:
                        got.append(np.nan)
                        continue
                    got.append(func(*[p[j] for p in params]))
                got = np.asarray(got)
            if not isinstance(got, tuple):
                got = (got,)
            return got

        # Evaluate function to be tested
        got = eval_func_at_params(self.func)

        # Grab the correct results
        if self.result_columns is not None:
            # Correct results passed in with the data
            wanted = tuple([data[:, icol] for icol in self.result_columns])
        else:
            # Function producing correct results passed in
            skip_mask = None
            if self.nan_ok and len(got) == 1:
                # Don't spend time evaluating what doesn't need to be evaluated
                skip_mask = np.isnan(got[0])
            wanted = eval_func_at_params(self.result_func, skip_mask=skip_mask)

        # Check the validity of each output returned
        assert_(len(got) == len(wanted))

        for output_num, (x, y) in enumerate(zip(got, wanted)):
            if np.issubdtype(x.dtype, np.complexfloating) or self.ignore_inf_sign:
                pinf_x = np.isinf(x)
                pinf_y = np.isinf(y)
                minf_x = np.isinf(x)
                minf_y = np.isinf(y)
            else:
                pinf_x = np.isposinf(x)
                pinf_y = np.isposinf(y)
                minf_x = np.isneginf(x)
                minf_y = np.isneginf(y)
            nan_x = np.isnan(x)
            nan_y = np.isnan(y)

            with np.errstate(all='ignore'):
                abs_y = np.absolute(y)
                abs_y[~np.isfinite(abs_y)] = 0
                diff = np.absolute(x - y)
                diff[~np.isfinite(diff)] = 0

                rdiff = diff / np.absolute(y)
                rdiff[~np.isfinite(rdiff)] = 0

            tol_mask = (diff <= atol + rtol*abs_y)
            pinf_mask = (pinf_x == pinf_y)
            minf_mask = (minf_x == minf_y)

            nan_mask = (nan_x == nan_y)

            bad_j = ~(tol_mask & pinf_mask & minf_mask & nan_mask)

            point_count = bad_j.size
            if self.nan_ok:
                bad_j &= ~nan_x
                bad_j &= ~nan_y
                point_count -= (nan_x | nan_y).sum()

            if not self.distinguish_nan_and_inf and not self.nan_ok:
                # If nan's are okay we've already covered all these cases
                inf_x = np.isinf(x)
                inf_y = np.isinf(y)
                both_nonfinite = (inf_x & nan_y) | (nan_x & inf_y)
                bad_j &= ~both_nonfinite
                point_count -= both_nonfinite.sum()

            if np.any(bad_j):
                # Some bad results: inform what, where, and how bad
                msg = [""]
                msg.append("Max |adiff|: %g" % diff[bad_j].max())
                msg.append("Max |rdiff|: %g" % rdiff[bad_j].max())
                msg.append("Bad results (%d out of %d) for the following points (in output %d):"
                           % (np.sum(bad_j), point_count, output_num,))
                for j in np.nonzero(bad_j)[0]:
                    j = int(j)
                    fmt = lambda x: "%30s" % np.array2string(x[j], precision=18)
                    a = "  ".join(map(fmt, params))
                    b = "  ".join(map(fmt, got))
                    c = "  ".join(map(fmt, wanted))
                    d = fmt(rdiff)
                    msg.append("%s => %s != %s  (rdiff %s)" % (a, b, c, d))
                assert_(False, "\n".join(msg))

    def __repr__(self):
        """Pretty-printing, e.g. for test-runner output"""
        if np.any(list(map(np.iscomplexobj, self.param_columns))):
            is_complex = " (complex)"
        else:
            is_complex = ""
        if self.dataname:
            return "<Data for %s%s: %s>" % (self.func.__name__, is_complex,
                                            os.path.basename(self.dataname))
        else:
            return "<Data for %s%s>" % (self.func.__name__, is_complex)