1"""
2Testing utilities.
3"""
4
5import os
6import re
7import struct
8import threading
9import functools
10from tempfile import NamedTemporaryFile
11
12import numpy as np
13from numpy import testing
14from numpy.testing import (assert_array_equal, assert_array_almost_equal,
15                           assert_array_less, assert_array_almost_equal_nulp,
16                           assert_equal, TestCase, assert_allclose,
17                           assert_almost_equal, assert_, assert_warns,
18                           assert_no_warnings)
19
20import warnings
21
22from .. import data, io
23from ..data._fetchers import _fetch
24from ..util import img_as_uint, img_as_float, img_as_int, img_as_ubyte
25from ._warnings import expected_warnings
26
27
import pytest

# Convenience aliases so test modules can import pytest helpers from here.
skipif = pytest.mark.skipif
xfail = pytest.mark.xfail
parametrize = pytest.mark.parametrize
raises = pytest.raises
fixture = pytest.fixture

# Matches a doctest prompt line ending in a ``# skip if <expr>`` comment.
# Groups: (code before the comment, whitespace before '#', the expression).
SKIP_RE = re.compile(r"(\s*>>>.*?)(\s*)#\s*skip\s+if\s+(.*)$")

# True when the interpreter is a 32-bit build, i.e. a ``void *`` pointer
# (struct format character "P") occupies 4 bytes.
# https://docs.python.org/3/library/struct.html
arch32 = struct.calcsize("P") == 4
41
42
def _parse_strict_warnings_flag(raw_value):
    """Interpret the SKIMAGE_TEST_STRICT_WARNINGS_GLOBAL setting.

    ``'true'`` / ``'false'`` (case-insensitive) map to the matching boolean.
    Any other value is parsed with ``int`` and converted with ``bool``, so
    ``'0'`` gives False and ``'1'`` gives True; a non-integer string raises
    ``ValueError``.
    """
    normalized = raw_value.lower()
    if normalized in ('true', 'false'):
        return normalized == 'true'
    return bool(int(raw_value))


# Whether the package-level test setup turns warnings into errors.
_error_on_warnings = _parse_strict_warnings_flag(
    os.environ.get('SKIMAGE_TEST_STRICT_WARNINGS_GLOBAL', '0')
)
50
51
def assert_less(a, b, msg=None):
    """Assert that ``a < b``.

    The check raises ``AssertionError`` explicitly instead of using an
    ``assert`` statement, so it is not stripped when Python runs with ``-O``.

    Parameters
    ----------
    a, b : object
        Values compared with ``<``.
    msg : str, optional
        Extra text appended to the failure message.

    Raises
    ------
    AssertionError
        If ``a < b`` is false.
    """
    if not (a < b):
        message = f"{a!r} is not lower than {b!r}"
        if msg is not None:
            message += ": " + msg
        raise AssertionError(message)
57
58
def assert_greater(a, b, msg=None):
    """Assert that ``a > b``.

    The check raises ``AssertionError`` explicitly instead of using an
    ``assert`` statement, so it is not stripped when Python runs with ``-O``.

    Parameters
    ----------
    a, b : object
        Values compared with ``>``.
    msg : str, optional
        Extra text appended to the failure message.

    Raises
    ------
    AssertionError
        If ``a > b`` is false.
    """
    if not (a > b):
        message = f"{a!r} is not greater than {b!r}"
        if msg is not None:
            message += ": " + msg
        raise AssertionError(message)
64
65
def _doctest_skip_namespace(obj):
    """Return the globals dict in which ``skip if`` expressions are evaluated.

    Functions expose ``__globals__`` directly. Classes are resolved through
    their ``__init__``. If neither works (e.g. a class that inherits
    ``object.__init__``), the namespace of the module named by
    ``obj.__module__`` is used.

    Raises
    ------
    AttributeError
        If no namespace can be determined.
    """
    import sys

    try:
        # Works as a function decorator
        return obj.__globals__
    except AttributeError:
        pass
    try:
        # Works as a class decorator (class with a Python-level __init__)
        return obj.__init__.__globals__
    except AttributeError:
        module = sys.modules.get(getattr(obj, '__module__', None))
        if module is None:
            raise
        return vars(module)


def doctest_skip_parser(func):
    """ Decorator replaces custom skip test markup in doctests

    Say a function has a docstring::

        >>> something, HAVE_AMODULE, HAVE_BMODULE = 0, False, False
        >>> something # skip if not HAVE_AMODULE
        0
        >>> something # skip if HAVE_BMODULE
        0

    This decorator will evaluate the expression after ``skip if``.  If this
    evaluates to True, then the comment is replaced by ``# doctest: +SKIP``. If
    False, then the comment (and the whitespace before it) is just removed.
    The expression is evaluated in the ``globals`` scope of `func` (for a
    class, the globals of its ``__init__``).

    For example, if the module global ``HAVE_AMODULE`` is False, and module
    global ``HAVE_BMODULE`` is False, the returned function will have docstring::

        >>> something # doctest: +SKIP
        0
        >>> something
        0

    Parameters
    ----------
    func : function or class
        Object whose ``__doc__`` is rewritten in place.

    Returns
    -------
    func : function or class
        The same object. It is returned unchanged when it has no docstring
        (e.g. when docstrings are stripped by ``python -OO``).
    """
    doc = func.__doc__
    if doc is None:
        return func

    # Resolved lazily, so objects without any ``skip if`` markers never need it.
    namespace = None
    new_lines = []
    for line in doc.split('\n'):
        match = SKIP_RE.match(line)
        if match is None:
            new_lines.append(line)
            continue
        code, space, expr = match.groups()
        if namespace is None:
            namespace = _doctest_skip_namespace(func)
        # The expression comes from a docstring in our own source code, so
        # eval is acceptable here; never feed this decorator untrusted text.
        if eval(expr, namespace):
            code = code + space + "# doctest: +SKIP"
        new_lines.append(code)
    func.__doc__ = "\n".join(new_lines)
    return func
111
112
def roundtrip(image, plugin, suffix):
    """Save and read an image using a specified plugin.

    Parameters
    ----------
    image : array
        Image passed to ``io.imsave``.
    plugin : str
        ``skimage.io`` plugin used for both writing and reading.
    suffix : str
        File extension, with or without a leading dot (``'png'`` or
        ``'.png'``).

    Returns
    -------
    The array returned by ``io.imread`` for the written file.
    """
    if '.' not in suffix:
        suffix = '.' + suffix
    # Only reserve a unique file name; the plugin reopens it by path.
    with NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
        fname = temp_file.name
    try:
        io.imsave(fname, image, plugin=plugin)
        return io.imread(fname, plugin=plugin)
    finally:
        # Clean up even if saving or reading failed. Removal is best-effort:
        # it can fail, e.g. while the file is still held open on Windows.
        try:
            os.remove(fname)
        except OSError:
            pass
127
128
def color_check(plugin, fmt='png'):
    """Check roundtrip behavior for color images.

    All major input types should be handled as ubytes and read
    back correctly.

    Parameters
    ----------
    plugin : str
        ``skimage.io`` plugin used to write and read the image.
    fmt : str, optional
        File extension to test, with or without a leading dot
        (e.g. ``'png'`` or ``'.tif'``).
    """
    # roundtrip() accepts suffixes with or without a dot, so normalize before
    # the TIFF test; otherwise '.tif' would silently skip the signed-int path.
    is_tiff = fmt.lower().lstrip('.') in ('tif', 'tiff')

    img = img_as_ubyte(data.chelsea())
    r1 = roundtrip(img, plugin, fmt)
    testing.assert_allclose(img, r1)

    img2 = img > 128
    r2 = roundtrip(img2, plugin, fmt)
    testing.assert_allclose(img2, r2.astype(bool))

    # Float input is expected to be read back equal to the ubyte original.
    img3 = img_as_float(img)
    r3 = roundtrip(img3, plugin, fmt)
    testing.assert_allclose(r3, img)

    img4 = img_as_int(img)
    if is_tiff:
        # For TIFF, signed values (some made negative here) are expected
        # to survive the roundtrip unchanged.
        img4 -= 100
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img4)
    else:
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img_as_ubyte(img4))

    img5 = img_as_uint(img)
    r5 = roundtrip(img5, plugin, fmt)
    testing.assert_allclose(r5, img)
159
160
def mono_check(plugin, fmt='png'):
    """Check the roundtrip behavior for images that support most types.

    All major input types should be handled.

    Parameters
    ----------
    plugin : str
        ``skimage.io`` plugin used to write and read the image.
    fmt : str, optional
        File extension to test, with or without a leading dot
        (e.g. ``'png'`` or ``'.tiff'``).
    """
    # roundtrip() accepts suffixes with or without a dot, so normalize before
    # the TIFF test; otherwise '.tif' would silently skip the signed-int path.
    is_tiff = fmt.lower().lstrip('.') in ('tif', 'tiff')

    img = img_as_ubyte(data.moon())
    r1 = roundtrip(img, plugin, fmt)
    testing.assert_allclose(img, r1)

    img2 = img > 128
    r2 = roundtrip(img2, plugin, fmt)
    testing.assert_allclose(img2, r2.astype(bool))

    img3 = img_as_float(img)
    r3 = roundtrip(img3, plugin, fmt)
    # Formats that store floats must round-trip them; otherwise the data is
    # expected back as uint16.
    if r3.dtype.kind == 'f':
        testing.assert_allclose(img3, r3)
    else:
        testing.assert_allclose(r3, img_as_uint(img))

    img4 = img_as_int(img)
    if is_tiff:
        # For TIFF, signed values (some made negative here) are expected
        # to survive the roundtrip unchanged.
        img4 -= 100
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img4)
    else:
        r4 = roundtrip(img4, plugin, fmt)
        testing.assert_allclose(r4, img_as_uint(img4))

    img5 = img_as_uint(img)
    r5 = roundtrip(img5, plugin, fmt)
    testing.assert_allclose(r5, img5)
194
195
def setup_test():
    """Default package level setup routine for skimage tests.

    Always reset the warnings filter to the ``'default'`` action. When strict
    warnings are enabled (``_error_on_warnings``), additionally:

    * import modules known to warn at import time, before warnings become
      errors;
    * seed NumPy's global random generator with 0;
    * escalate all warnings to errors, except for a list of known,
      tolerated warnings.
    """
    warnings.simplefilter('default')

    if not _error_on_warnings:
        return

    # The imported names are unused. Importing these now means their
    # import-time warnings are not turned into errors later.
    from scipy import signal, ndimage, special, optimize, linalg
    from scipy.io import loadmat
    from skimage import viewer

    # NOTE(review): the docstring history suggests the seed is meant to be
    # set unconditionally, but it is only set in strict-warnings mode.
    np.random.seed(0)

    warnings.simplefilter('error')

    # Applied in this exact order: each filterwarnings call inserts its rule
    # ahead of the previous ones, so order determines precedence.
    tolerated_warnings = (
        # do not error on specific warnings from the skimage.io module
        # https://github.com/scikit-image/scikit-image/issues/5337
        dict(action='default', message='TiffFile:',
             category=DeprecationWarning),
        dict(action='default', message='TiffWriter:',
             category=DeprecationWarning),
        dict(action='default', message='unclosed file',
             category=ResourceWarning),
        # ignore known FutureWarnings from viewer module
        dict(action='ignore', category=FutureWarning,
             module='skimage.viewer'),
        # Ignore other warnings only seen when using older versions of
        # dependencies.
        dict(action='default',
             message='Conversion of the second argument of issubdtype',
             category=FutureWarning),
        dict(action='default',
             message='the matrix subclass is not the recommended way',
             category=PendingDeprecationWarning, module='numpy'),
        dict(action='default',
             message='Your installed pillow version',
             category=UserWarning, module='skimage.io'),
        dict(action='default', message='Viewer requires Qt',
             category=UserWarning),
        dict(action='default', message='numpy.ufunc size changed',
             category=RuntimeWarning),
    )
    for filter_spec in tolerated_warnings:
        warnings.filterwarnings(**filter_spec)
264
265
def teardown_test():
    """Default package level teardown routine for skimage tests.

    In strict-warnings mode (``_error_on_warnings``), clear the warnings
    filter list and reinstall the ``'default'`` action. Otherwise, do
    nothing.
    """
    if not _error_on_warnings:
        return
    warnings.resetwarnings()
    warnings.simplefilter('default')
274
275
def fetch(data_filename):
    """Fetch a data file through ``_fetch``, skipping the test on failure.

    Parameters
    ----------
    data_filename : str
        Name of the data file passed to ``_fetch``.

    Returns
    -------
    Whatever ``_fetch`` returns for ``data_filename``.

    If ``_fetch`` raises ``ConnectionError`` or ``ModuleNotFoundError``, the
    current test is skipped via ``pytest.skip``. ``allow_module_level=True``
    means this also works when called at test-module import time.
    """
    try:
        fetched = _fetch(data_filename)
    except (ConnectionError, ModuleNotFoundError):
        # ModuleNotFoundError presumably means an optional download
        # dependency is missing -- TODO confirm against _fetch.
        pytest.skip(f'Unable to download {data_filename}',
                    allow_module_level=True)
    else:
        return fetched
283
284
def test_parallel(num_threads=2, warnings_matching=None):
    """Decorator to run the same function multiple times in parallel.

    This decorator is useful to ensure that separate threads execute
    concurrently and correctly while releasing the GIL.

    Parameters
    ----------
    num_threads : int, optional
        The number of times the function is run in parallel.

    warnings_matching: list or None
        This parameter is passed on to `expected_warnings` so as not to have
        race conditions with the warnings filters. A single
        `expected_warnings` context manager is used for all threads.
        If None, then no warnings are checked.

    Returns
    -------
    decorator : callable
        Wraps ``func`` so that ``num_threads - 1`` background threads and the
        calling thread each run ``func(*args, **kwargs)``. The calling
        thread's result is returned. If the calling thread's run succeeds but
        a background run raised, that exception is re-raised, so worker
        failures are not silently lost.

    Raises
    ------
    ValueError
        If ``num_threads`` is smaller than 1.
    """
    if num_threads < 1:
        raise ValueError(f"num_threads must be at least 1, got {num_threads}")

    def wrapper(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            # Exceptions raised in a threading.Thread target are not
            # propagated to the starting thread, so record them here.
            thread_errors = []

            def run_in_thread():
                try:
                    func(*args, **kwargs)
                except BaseException as exc:
                    # BaseException so pytest outcomes (fail/skip) are kept too
                    thread_errors.append(exc)

            with expected_warnings(warnings_matching):
                threads = [threading.Thread(target=run_in_thread)
                           for _ in range(num_threads - 1)]
                for thread in threads:
                    thread.start()

                try:
                    result = func(*args, **kwargs)
                finally:
                    # Wait for the workers even if the main run failed.
                    for thread in threads:
                        thread.join()

                if thread_errors:
                    raise thread_errors[0]
                return result

        return inner

    return wrapper


# The name starts with ``test_`` but this is a helper, not a test. Stop pytest
# from collecting it in test modules that import it.
test_parallel.__test__ = False
328
329
if __name__ == '__main__':
    # Manual smoke test of the PIL plugin: one color roundtrip, then
    # grayscale roundtrips in several formats.
    color_check('pil')
    for mono_fmt in ('png', 'bmp', 'tiff'):
        mono_check('pil', mono_fmt)
335