1import os
2import functools
3import operator
4from distutils.version import LooseVersion
5
6import numpy as np
7from numpy.testing import assert_
8import pytest
9
10import scipy.special as sc
11
# Public testing helpers exported by this module (``from ... import *``).
__all__ = ['with_special_errors', 'assert_func_equal', 'FuncData']
13
14
15#------------------------------------------------------------------------------
16# Check if a module is present to be used in tests
17#------------------------------------------------------------------------------
18
class MissingModule:
    """
    Stand-in object for an optional module that could not be imported.

    Parameters
    ----------
    name : str
        Name of the missing module, kept so that skip messages can
        mention it.
    """

    def __init__(self, name):
        self.name = name
22
23
def check_version(module, min_ver):
    """
    Return a pytest mark that skips a test unless `module` is new enough.

    Parameters
    ----------
    module : module or MissingModule
        The imported module, or a `MissingModule` placeholder when the
        import failed.
    min_ver : str
        Minimum required version, compared with ``LooseVersion``.

    Returns
    -------
    pytest mark
        An unconditional ``skip`` mark if the module is missing, otherwise
        a ``skipif`` mark that triggers when ``module.__version__`` is
        older than `min_ver`.
    """
    if isinstance(module, MissingModule):
        return pytest.mark.skip(reason="{} is not installed".format(module.name))
    # NOTE(review): distutils (and thus LooseVersion) is deprecated and was
    # removed in Python 3.12; consider a non-distutils version comparison.
    too_old = LooseVersion(module.__version__) < LooseVersion(min_ver)
    return pytest.mark.skipif(
        too_old,
        reason="{} version >= {} required".format(module.__name__, min_ver))
29
30
31#------------------------------------------------------------------------------
32# Enable convergence and loss of precision warnings -- turn off one by one
33#------------------------------------------------------------------------------
34
def with_special_errors(func):
    """
    Decorate `func` so that it runs with every ``scipy.special`` error
    category (underflow, overflow, loss of precision, ...) set to raise.

    The returned wrapper keeps `func`'s metadata (``functools.wraps``),
    forwards all arguments unchanged, and returns `func`'s result. The
    previous error state is restored once the call finishes.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with sc.errstate(all='raise'):
            return func(*args, **kwargs)
    return wrapper
46
47
48#------------------------------------------------------------------------------
49# Comparing function values at many data points at once, with helpful
50# error reports
51#------------------------------------------------------------------------------
52
def assert_func_equal(func, results, points, rtol=None, atol=None,
                      param_filter=None, knownfailure=None,
                      vectorized=True, dtype=None, nan_ok=False,
                      ignore_inf_sign=False, distinguish_nan_and_inf=True):
    """
    Assert that `func` reproduces `results` at the given `points`.

    Parameters
    ----------
    func : callable
        Function under test.
    results : callable or array_like
        Either a reference function evaluated at the same points, or the
        expected values (one column per output) for each point.
    points : array_like or iterator
        Parameter values, one row per point and one column per parameter.
        A 1-D input is treated as a single parameter. Iterators such as
        generators are materialized into a list first.
    dtype : optional
        Accepted for backward compatibility; currently unused.
    rtol, atol, param_filter, knownfailure, vectorized, nan_ok, ignore_inf_sign, distinguish_nan_and_inf
        Forwarded unchanged to `FuncData`.
    """
    if hasattr(points, '__next__'):
        # An iterator/generator (Python 3 protocol): np.asarray would wrap
        # it in a 0-d object array instead of consuming it.
        points = list(points)

    points = np.asarray(points)
    if points.ndim == 1:
        points = points[:, None]
    nparams = points.shape[1]

    if callable(results):
        # Reference function: evaluated at the same parameter columns.
        data = points
        result_columns = None
        result_func = results
    else:
        # Expected values: appended as extra columns after the parameters.
        data = np.c_[points, results]
        result_columns = list(range(nparams, data.shape[1]))
        result_func = None

    fdata = FuncData(func, data, list(range(nparams)),
                     result_columns=result_columns, result_func=result_func,
                     rtol=rtol, atol=atol, param_filter=param_filter,
                     knownfailure=knownfailure, nan_ok=nan_ok, vectorized=vectorized,
                     ignore_inf_sign=ignore_inf_sign,
                     distinguish_nan_and_inf=distinguish_nan_and_inf)
    fdata.check()
84
85
class FuncData:
    """
    Data set for checking a special function.

    Parameters
    ----------
    func : function
        Function to test
    data : numpy array
        columnar data to use for testing
    param_columns : int or tuple of ints
        Column indices in which the parameters to `func` lie.
        Can be imaginary integers to indicate that the parameter
        should be cast to complex.
    result_columns : int or tuple of ints, optional
        Column indices for expected results from `func`.
    result_func : callable, optional
        Function to call to obtain results.
    rtol : float, optional
        Required relative tolerance. Default is 5*eps.
    atol : float, optional
        Required absolute tolerance. Default is 5*tiny.
    param_filter : function, or tuple of functions/Nones, optional
        Filter functions to exclude some parameter ranges.
        If omitted, no filtering is done.
    knownfailure : str, optional
        If given, `check` marks the test as an expected failure
        (``pytest.xfail``) with this reason instead of running it.
    dataname : str, optional
        Name of the data set (e.g. a file path); only its base name is
        shown in ``repr``.
    nan_ok : bool, optional
        If nan is always an accepted result.
    vectorized : bool, optional
        Whether all functions passed in are vectorized.
    ignore_inf_sign : bool, optional
        Whether to ignore signs of infinities.
        (Doesn't matter for complex-valued functions.)
    distinguish_nan_and_inf : bool, optional
        If False, a nan on one side and an inf on the other are treated
        as equal. Setting it to False also sets ignore_inf_sign to True.

    """

    def __init__(self, func, data, param_columns, result_columns=None,
                 result_func=None, rtol=None, atol=None, param_filter=None,
                 knownfailure=None, dataname=None, nan_ok=False, vectorized=True,
                 ignore_inf_sign=False, distinguish_nan_and_inf=True):
        self.func = func
        self.data = data
        self.dataname = dataname
        # Normalize a single column index to a tuple.
        if not hasattr(param_columns, '__len__'):
            param_columns = (param_columns,)
        self.param_columns = tuple(param_columns)
        # Exactly one source of expected values must be provided.
        if result_columns is not None:
            if not hasattr(result_columns, '__len__'):
                result_columns = (result_columns,)
            self.result_columns = tuple(result_columns)
            if result_func is not None:
                raise ValueError("Only result_func or result_columns should be provided")
        elif result_func is not None:
            self.result_columns = None
        else:
            raise ValueError("Either result_func or result_columns should be provided")
        self.result_func = result_func
        self.rtol = rtol
        self.atol = atol
        # A single filter (or None) applies to the first parameter column.
        if not hasattr(param_filter, '__len__'):
            param_filter = (param_filter,)
        self.param_filter = param_filter
        self.knownfailure = knownfailure
        self.nan_ok = nan_ok
        self.vectorized = vectorized
        self.ignore_inf_sign = ignore_inf_sign
        self.distinguish_nan_and_inf = distinguish_nan_and_inf
        if not self.distinguish_nan_and_inf:
            self.ignore_inf_sign = True

    def get_tolerances(self, dtype):
        """
        Return ``(rtol, atol)`` for comparisons at `dtype`.

        Unset tolerances default to ``5*eps`` and ``5*tiny`` of `dtype`.
        Non-inexact (e.g. integer) dtypes fall back to float64.
        """
        if not np.issubdtype(dtype, np.inexact):
            dtype = np.dtype(float)
        info = np.finfo(dtype)
        rtol, atol = self.rtol, self.atol
        if rtol is None:
            rtol = 5*info.eps
        if atol is None:
            atol = 5*info.tiny
        return rtol, atol

    def check(self, data=None, dtype=None, dtypes=None):
        """
        Check the special function against the data.

        Parameters
        ----------
        data : numpy array, optional
            Data to check instead of ``self.data``.
        dtype : dtype, optional
            If given, the data is cast to this dtype first. Default
            tolerances are derived from it (or from the data's dtype).
        dtypes : sequence of dtypes, optional
            Per-parameter dtypes: the parameter at position ``idx`` in
            ``param_columns`` (unless marked complex) is cast to
            ``dtypes[idx]`` when ``idx < len(dtypes)``.

        Raises
        ------
        AssertionError
            If any output disagrees with the expected values; the message
            lists the offending points.
        """
        # Hide this frame from pytest tracebacks for assertion failures.
        __tracebackhide__ = operator.methodcaller(
            'errisinstance', AssertionError
        )

        if self.knownfailure:
            pytest.xfail(reason=self.knownfailure)

        if data is None:
            data = self.data

        if dtype is None:
            dtype = data.dtype
        else:
            data = data.astype(dtype)

        rtol, atol = self.get_tolerances(dtype)

        # Apply given filter functions (rows failing any filter are dropped)
        if self.param_filter:
            param_mask = np.ones((data.shape[0],), np.bool_)
            for j, filt in zip(self.param_columns, self.param_filter):
                if filt:
                    param_mask &= list(filt(data[:, j]))
            data = data[param_mask]

        # Pick parameters from the correct columns; an imaginary index such
        # as 1j selects column 1 and casts it to complex.
        params = []
        for idx, j in enumerate(self.param_columns):
            if np.iscomplexobj(j):
                j = int(j.imag)
                params.append(data[:, j].astype(complex))
            elif dtypes and idx < len(dtypes):
                params.append(data[:, j].astype(dtypes[idx]))
            else:
                params.append(data[:, j])

        # Helper for evaluating results; always returns a tuple of outputs
        # so single- and multi-output functions are compared uniformly.
        def eval_func_at_params(func, skip_mask=None):
            if self.vectorized:
                got = func(*params)
            else:
                got = []
                for j in range(len(params[0])):
                    if skip_mask is not None and skip_mask[j]:
                        got.append(np.nan)
                        continue
                    got.append(func(*tuple([p[j] for p in params])))
                got = np.asarray(got)
            if not isinstance(got, tuple):
                got = (got,)
            return got

        # Evaluate function to be tested
        got = eval_func_at_params(self.func)

        # Grab the correct results
        if self.result_columns is not None:
            # Correct results passed in with the data
            wanted = tuple([data[:, icol] for icol in self.result_columns])
        else:
            # Function producing correct results passed in
            skip_mask = None
            if self.nan_ok and len(got) == 1:
                # Don't spend time evaluating what doesn't need to be evaluated
                skip_mask = np.isnan(got[0])
            wanted = eval_func_at_params(self.result_func, skip_mask=skip_mask)

        # Check the validity of each output returned
        assert_(len(got) == len(wanted))

        def fmt_at(arr, j):
            # Fixed-width, high-precision rendering of arr[j] for reports.
            return "%30s" % np.array2string(arr[j], precision=18)

        for output_num, (x, y) in enumerate(zip(got, wanted)):
            if np.issubdtype(x.dtype, np.complexfloating) or self.ignore_inf_sign:
                # Sign of infinity is ignored: both masks reduce to isinf.
                pinf_x = np.isinf(x)
                pinf_y = np.isinf(y)
                minf_x = np.isinf(x)
                minf_y = np.isinf(y)
            else:
                pinf_x = np.isposinf(x)
                pinf_y = np.isposinf(y)
                minf_x = np.isneginf(x)
                minf_y = np.isneginf(y)
            nan_x = np.isnan(x)
            nan_y = np.isnan(y)

            with np.errstate(all='ignore'):
                # Non-finite entries are zeroed here; they are judged by
                # the inf/nan masks below instead of the tolerance test.
                abs_y = np.absolute(y)
                abs_y[~np.isfinite(abs_y)] = 0
                diff = np.absolute(x - y)
                diff[~np.isfinite(diff)] = 0

                rdiff = diff / np.absolute(y)
                rdiff[~np.isfinite(rdiff)] = 0

            tol_mask = (diff <= atol + rtol*abs_y)
            pinf_mask = (pinf_x == pinf_y)
            minf_mask = (minf_x == minf_y)

            nan_mask = (nan_x == nan_y)

            bad_j = ~(tol_mask & pinf_mask & minf_mask & nan_mask)

            point_count = bad_j.size
            if self.nan_ok:
                bad_j &= ~nan_x
                bad_j &= ~nan_y
                point_count -= (nan_x | nan_y).sum()

            if not self.distinguish_nan_and_inf and not self.nan_ok:
                # If nan's are okay we've already covered all these cases
                inf_x = np.isinf(x)
                inf_y = np.isinf(y)
                both_nonfinite = (inf_x & nan_y) | (nan_x & inf_y)
                bad_j &= ~both_nonfinite
                point_count -= both_nonfinite.sum()

            if np.any(bad_j):
                # Some bad results: inform what, where, and how bad
                msg = [""]
                msg.append("Max |adiff|: %g" % diff[bad_j].max())
                msg.append("Max |rdiff|: %g" % rdiff[bad_j].max())
                msg.append("Bad results (%d out of %d) for the following points (in output %d):"
                           % (np.sum(bad_j), point_count, output_num,))
                for j in np.nonzero(bad_j)[0]:
                    j = int(j)
                    a = "  ".join(fmt_at(p, j) for p in params)
                    b = "  ".join(fmt_at(g, j) for g in got)
                    c = "  ".join(fmt_at(w, j) for w in wanted)
                    d = fmt_at(rdiff, j)
                    msg.append("%s => %s != %s  (rdiff %s)" % (a, b, c, d))
                assert_(False, "\n".join(msg))

    def __repr__(self):
        """Pretty-printing, e.g. for test-runner output."""
        if any(np.iscomplexobj(col) for col in self.param_columns):
            is_complex = " (complex)"
        else:
            is_complex = ""
        if self.dataname:
            return "<Data for %s%s: %s>" % (self.func.__name__, is_complex,
                                            os.path.basename(self.dataname))
        else:
            return "<Data for %s%s>" % (self.func.__name__, is_complex)
317