1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2"""
3This module provides the tools used to internally run the astropy test suite
4from the installed astropy.  It makes use of the `pytest`_ testing framework.
5"""
6import os
7import sys
8import types
9import pickle
10import warnings
11import functools
12
13import pytest
14
15from astropy.units import allclose as quantity_allclose  # noqa
16from astropy.utils.exceptions import (AstropyDeprecationWarning,
17                                      AstropyPendingDeprecationWarning)
18
19
20# For backward-compatibility with affiliated packages
21from .runner import TestRunner  # pylint: disable=W0611  # noqa
22
__all__ = [
    'raises',
    'enable_deprecations_as_exceptions',
    'remote_data',
    'treat_deprecations_as_exceptions',
    'catch_warnings',
    'assert_follows_unicode_guidelines',
    'assert_quantity_allclose',
    'check_pickling_recovery',
    'pickle_protocol',
    'generic_recursive_equality_test',
]

# Backwards-compatible alias for the pytest marker that flags tests which
# fetch data from the web; it is simply ``pytest.mark.remote_data``.
remote_data = pytest.mark.remote_data
32
33
34# distutils expects options to be Unicode strings
def _fix_user_options(options):
    """
    Return ``options`` as a list of tuples whose entries are ``str``.

    Each option sequence is converted to a tuple in which every entry is
    passed through ``str``, except ``None`` entries, which stay ``None``.
    """
    return [
        tuple(None if entry is None else str(entry) for entry in option)
        for option in options
    ]
42
43
def _save_coverage(cov, result, rootdir, testing_path):
    """
    Remap, save and report coverage data after a coverage-mode test run.

    Nothing happens unless ``result`` is 0.  Otherwise every file path in
    the collected line data is rewritten from its location relative to
    ``testing_path`` to the corresponding location under ``rootdir``.  The
    data is then saved with ``cov.save()``, and an HTML report is written
    to ``rootdir/htmlcov``.

    Parameters
    ----------
    cov : object
        Presumably a ``coverage.Coverage`` instance; both the pre-4.0 and
        the 4.x data APIs are handled (TODO confirm against callers).
    result : int
        Exit status of the test run.
    rootdir : str
        Directory that remapped paths and ``htmlcov`` are placed under.
    testing_path : str
        Directory the tests were actually run from.
    """
    from astropy.utils.console import color_print

    if result != 0:
        return

    # The coverage data records the full path of the temporary directory
    # the tests ran in, so every path is rewritten to point at the true
    # source location.  This does not work properly for packages that
    # still rely on 2to3.
    try:
        # Coverage >= 4.0: ``_harvest_data`` was renamed to ``get_data`` and
        # the per-file lines dict became private.
        cov.get_data()
    except AttributeError:
        # Coverage < 4.0
        cov._harvest_data()
        lines = cov.data.lines
    else:
        lines = cov.data._lines

    # Snapshot the keys: the dict is re-keyed while we walk it.
    for old_path in list(lines.keys()):
        relative_path = os.path.relpath(os.path.realpath(old_path),
                                        os.path.realpath(testing_path))
        remapped_path = os.path.abspath(os.path.join(rootdir, relative_path))
        lines[remapped_path] = lines.pop(old_path)

    color_print('Saving coverage data in .coverage...', 'green')
    cov.save()

    color_print('Saving HTML coverage report in htmlcov...', 'green')
    cov.html_report(directory=os.path.join(rootdir, 'htmlcov'))
82
83
84# TODO: Plan a roadmap of deprecation as pytest.raises has matured over the years.
85# See https://github.com/astropy/astropy/issues/6761
class raises:
    """
    A decorator to mark that a test should raise a given exception.
    Use as follows::

        @raises(ZeroDivisionError)
        def test_foo():
            x = 1/0

    This can also be used as a context manager, in which case it is just
    an alias for the ``pytest.raises`` context manager (because the
    two have the same name, this helps avoid confusion by being
    flexible).

    .. note:: Usage of ``pytest.raises`` is preferred.

    """

    # The lower-case class name is a deliberate PEP 8 exception: instances
    # are used like functions, as a decorator or as a context manager.
    def __init__(self, exc):
        self._exc = exc
        # The ``pytest.raises`` context created by ``__enter__``; None until then.
        self._ctx = None

    def __call__(self, func):
        @functools.wraps(func)
        def run_raises_test(*args, **kwargs):
            # Run ``func`` under ``pytest.raises`` with the expected exception;
            # whatever pytest.raises returns is discarded.
            pytest.raises(self._exc, func, *args, **kwargs)

        return run_raises_test

    def __enter__(self):
        ctx = pytest.raises(self._exc)
        self._ctx = ctx
        return ctx.__enter__()

    def __exit__(self, *exc_details):
        return self._ctx.__exit__(*exc_details)
121
122
# Module-level switches read by ``treat_deprecations_as_exceptions`` and
# updated by ``enable_deprecations_as_exceptions``.
_deprecations_as_exceptions = False
_include_astropy_deprecations = True
# Modules imported up front, with DeprecationWarning silenced, so that
# deprecated usage they trigger only at import time is not turned into an
# error later.
_modules_to_ignore_on_import = {
    r'compiler',  # A deprecated stdlib module used by pytest
    r'scipy',
    r'pygments',
    r'ipykernel',
    r'IPython',   # deprecation warnings for async and await
    r'setuptools',
}
# Modules whose escalated warnings are ignored everywhere, not just on import.
_warnings_to_ignore_entire_module = set()
# ``(message_regex, warning_class)`` pairs to ignore, keyed by the
# ``(major, minor)`` Python version; the ``None`` key applies to all versions.
_warnings_to_ignore_by_pyver = {
    None: {  # Python version agnostic
        # https://github.com/astropy/astropy/pull/7372
        (r"Importing from numpy\.testing\.decorators is deprecated, "
         r"import from numpy\.testing instead\.", DeprecationWarning),
        # inspect raises this slightly different warning on Python 3.7.
        # Keeping it since e.g. lxml as of 3.8.0 is still calling getargspec()
        (r"inspect\.getargspec\(\) is deprecated, use "
         r"inspect\.signature\(\) or inspect\.getfullargspec\(\)",
         DeprecationWarning),
        # https://github.com/astropy/pytest-doctestplus/issues/29
        (r"split\(\) requires a non-empty pattern match", FutureWarning),
        # Package resolution warning that we can do nothing about
        (r"can't resolve package from __spec__ or __package__, "
         r"falling back on __name__ and __path__", ImportWarning),
    },
    (3, 7): {
        # Deprecation warning for collections.abc, fixed in Astropy but still
        # used in lxml, and maybe others
        (r"Using or importing the ABCs from 'collections'",
         DeprecationWarning),
    },
}
154
155
def enable_deprecations_as_exceptions(include_astropy_deprecations=True,
                                      modules_to_ignore_on_import=None,
                                      warnings_to_ignore_entire_module=None,
                                      warnings_to_ignore_by_pyver=None):
    """
    Turn on the feature that turns deprecations into exceptions.

    This only records settings in module-level state; they are applied by
    ``treat_deprecations_as_exceptions``, which reads that state.

    Parameters
    ----------
    include_astropy_deprecations : bool
        If set to `True`, ``AstropyDeprecationWarning`` and
        ``AstropyPendingDeprecationWarning`` are also turned into exceptions.

    modules_to_ignore_on_import : list of str, optional
        List of additional modules that generate deprecation warnings
        on import, which are to be ignored. By default, these are already
        included: ``compiler``, ``scipy``, ``pygments``, ``ipykernel``,
        ``IPython``, and ``setuptools``.

    warnings_to_ignore_entire_module : list of str, optional
        List of modules with deprecation warnings to ignore completely,
        not just during import. If ``include_astropy_deprecations=True``
        is given, ``AstropyDeprecationWarning`` and
        ``AstropyPendingDeprecationWarning`` are also ignored for the modules.

    warnings_to_ignore_by_pyver : dict, optional
        Dictionary mapping tuple of ``(major, minor)`` Python version to
        a list of ``(warning_message, warning_class)`` to ignore.
        Python version-agnostic warnings should be mapped to `None` key.
        This is in addition to those already ignored by default
        (see ``_warnings_to_ignore_by_pyver`` values).

    """
    # Only these two names are rebound; the containers below are updated in
    # place and so need no ``global`` declaration.
    global _deprecations_as_exceptions, _include_astropy_deprecations
    _deprecations_as_exceptions = True
    _include_astropy_deprecations = include_astropy_deprecations

    # ``None`` defaults avoid shared mutable default arguments.
    if modules_to_ignore_on_import is not None:
        _modules_to_ignore_on_import.update(modules_to_ignore_on_import)

    if warnings_to_ignore_entire_module is not None:
        _warnings_to_ignore_entire_module.update(
            warnings_to_ignore_entire_module)

    if warnings_to_ignore_by_pyver is not None:
        for pyver, to_ignore in warnings_to_ignore_by_pyver.items():
            # Merge into an existing version entry, or start a new set.
            _warnings_to_ignore_by_pyver.setdefault(pyver, set()).update(
                to_ignore)
207
208
def treat_deprecations_as_exceptions():
    """
    Escalate deprecation-style warnings to exceptions so deprecated usage
    (of Python itself, Numpy, and optionally Astropy) is caught early.

    Every module's ``__warningregistry__`` is always dropped, forgetting
    which warnings were already shown.  If the feature has been switched on
    (see ``enable_deprecations_as_exceptions``), the warning filters are
    then completely reset.  ``DeprecationWarning``, ``FutureWarning`` and
    ``ImportWarning`` are made errors, as are the Astropy deprecation
    classes unless they were excluded.  The configured per-module and
    per-Python-version exemptions still apply.
    """
    # Forget "already seen" warnings everywhere.  sys.modules is copied
    # first because it may change while we iterate; see
    # https://github.com/astropy/astropy/pull/5513.  Failures are ignored
    # on purpose: this cleanup is best-effort.
    for mod in list(sys.modules.values()):
        try:
            del mod.__warningregistry__
        except Exception:
            pass

    if not _deprecations_as_exceptions:
        return

    # Some dependencies (e.g. pytest and scipy) still use deprecated
    # features, but only at import time.  Importing them now, while
    # DeprecationWarning is silenced, keeps those warnings from tripping
    # the error filters installed below.
    warnings.resetwarnings()
    warnings.simplefilter('ignore', DeprecationWarning)
    for module_name in _modules_to_ignore_on_import:
        try:
            __import__(module_name)
        except ImportError:
            pass

    # Start again from an empty filter list and escalate the chosen classes.
    warnings.resetwarnings()
    escalated = [DeprecationWarning, FutureWarning, ImportWarning]
    if _include_astropy_deprecations:
        escalated.extend([AstropyDeprecationWarning,
                          AstropyPendingDeprecationWarning])
    for category in escalated:
        warnings.filterwarnings("error", ".*", category)

    # Filters added later take precedence, so the exemptions below win over
    # the "error" filters above.  First, whole modules (not just on import),
    # for use by Astropy affiliated packages.
    for module_name in _warnings_to_ignore_entire_module:
        for category in escalated:
            warnings.filterwarnings('ignore', category=category,
                                    module=module_name)

    # Then specific messages, for every Python version or the running one.
    for pyver, exemptions in _warnings_to_ignore_by_pyver.items():
        if pyver is not None and sys.version_info[:2] != pyver:
            continue
        for entry in exemptions:
            warnings.filterwarnings("ignore", entry[0], entry[1])
270
271
272# TODO: Plan a roadmap of deprecation as pytest.warns has matured over the years.
273# See https://github.com/astropy/astropy/issues/6761
class catch_warnings(warnings.catch_warnings):
    """
    A high-powered version of warnings.catch_warnings to use for testing
    and to make sure that there is no dependence on the order in which
    the tests are run.

    This completely blitzes any memory of any warnings that have
    appeared before so that all warnings will be caught and displayed.

    ``*args`` is a set of warning classes to collect.  If no arguments are
    provided, all warnings are collected.

    On exit, the warning filters and display hooks that were active before
    the block are restored, and ``treat_deprecations_as_exceptions`` is
    called again.

    Use as follows::

        with catch_warnings(MyCustomWarning) as w:
            do.something.bad()
        assert len(w) > 0

    .. note:: Usage of :ref:`pytest.warns <pytest:warns>` is preferred.

    """

    def __init__(self, *classes):
        super().__init__(record=True)
        # Warning classes to record; an empty tuple means "record all".
        self.classes = classes

    def __enter__(self):
        warning_list = super().__enter__()
        # Drop every module's "already seen" registry (and, when enabled,
        # reinstall the deprecation-as-exception filters) so that warnings
        # seen by earlier tests are not suppressed here.
        treat_deprecations_as_exceptions()
        if not self.classes:
            warnings.simplefilter('always')
        else:
            # Ignore everything except the requested classes.
            warnings.simplefilter('ignore')
            for cls in self.classes:
                warnings.simplefilter('always', cls)
        return warning_list

    def __exit__(self, *exc_info):
        # The base class must run to restore the filters and the warning
        # display hooks.  Without it, warnings raised after the block keep
        # being appended to the recorded list and are silently swallowed.
        super().__exit__(*exc_info)
        treat_deprecations_as_exceptions()
313
314
class ignore_warnings(catch_warnings):
    """
    Ignore warnings raised within a block of code or a decorated function.

    Instances work both as a context manager and as a function decorator.
    By default every warning is ignored.  Pass ``category`` as a single
    warning class, or as a list of warning classes, to ignore only those.
    """

    def __init__(self, category=None):
        super().__init__()

        # Normalise a lone warning class to a one-element list so that
        # __enter__ can always iterate over it.
        single_warning_class = (isinstance(category, type)
                                and issubclass(category, Warning))
        self.category = [category] if single_warning_class else category

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Reusing ``self`` breaks when the function is called more than
            # once, so each call gets its own context manager instance.
            per_call_ctx = self.__class__(category=self.category)
            with per_call_ctx:
                return func(*args, **kwargs)

        return wrapper

    def __enter__(self):
        recorded = super().__enter__()
        if self.category is None:
            warnings.simplefilter('ignore')
        else:
            for warning_class in self.category:
                warnings.simplefilter('ignore', warning_class)
        return recorded
351
352
def assert_follows_unicode_guidelines(
        x, roundtrip=None):
    """
    Test that an object follows our Unicode policy.  See
    "Unicode guidelines" in the coding guidelines.

    The checks run twice: with astropy's ``unicode_output`` configuration
    item set to False, then to True.  In both cases ``bytes(x)`` must be
    ASCII-decodable ``bytes`` and ``repr(x)`` must be an ASCII-encodable
    ``str``.  ``str(x)`` must be a ``str``, and it must also be
    ASCII-encodable when ``unicode_output`` is False.

    Parameters
    ----------
    x : object
        The instance to test

    roundtrip : module or dict, optional
        When provided, this namespace is used as the globals for evaluating
        ``repr(x)``, which must give back an object equal to ``x``.  For a
        module, its ``__dict__`` is used.  The test also checks that
        ``x.__class__`` rebuilt from ``bytes(x)`` and from ``str(x)`` equals
        ``x``.  If not provided, no roundtrip testing is performed.
    """
    from astropy import conf

    # ``eval`` only accepts a real dict as globals, so a module (as
    # documented) is evaluated in its namespace.
    if isinstance(roundtrip, types.ModuleType):
        namespace = vars(roundtrip)
    else:
        namespace = roundtrip

    for unicode_output in (False, True):
        with conf.set_temp('unicode_output', unicode_output):
            bytes_x = bytes(x)
            unicode_x = str(x)
            repr_x = repr(x)

            assert isinstance(bytes_x, bytes)
            bytes_x.decode('ascii')
            assert isinstance(unicode_x, str)
            if not unicode_output:
                # ``str`` output only has to be plain ASCII when Unicode
                # output is disabled.
                unicode_x.encode('ascii')
            # ``repr`` always returns ``str`` on Python 3 (so the old bytes
            # branch was unreachable), and it must stay ASCII regardless of
            # the setting.
            assert isinstance(repr_x, str)
            repr_x.encode('ascii')

            if roundtrip is not None:
                assert x.__class__(bytes_x) == x
                assert x.__class__(unicode_x) == x
                assert eval(repr_x, namespace) == x
410
411
@pytest.fixture(params=[0, 1, -1])
def pickle_protocol(request):
    """
    Parametrized fixture that provides the pickle protocol for a test.

    Each test using it runs three times, with protocol 0, protocol 1 and
    -1 (the most advanced protocol available).
    (Originally from astropy.table.tests.test_pickle)
    """
    protocol = request.param
    return protocol
419
420
def _pickle_comparison_state(obj):
    """Return the mapping of attribute names to values to compare for ``obj``."""
    state = obj.__getstate__() if hasattr(obj, '__getstate__') else None
    if isinstance(state, dict):
        return state
    # Since Python 3.11 every object inherits a default ``__getstate__``.
    # It returns None when the instance ``__dict__`` is empty, and a
    # ``(dict, slots)`` tuple when ``__slots__`` are used.  Neither can be
    # iterated as attribute names, so fall back to the instance ``__dict__``.
    return obj.__dict__


def generic_recursive_equality_test(a, b, class_history):
    """
    Check if the attributes of a and b are equal. Then,
    check if the attributes of the attributes are equal.

    Parameters
    ----------
    a : object
        Reference object.  Its ``__getstate__()`` result is used when that
        is a dict; otherwise its ``__dict__`` is used.
    b : object
        Object to compare against, typically ``a`` after a pickle round
        trip.  Its ``__dict__`` is used.
    class_history : list of type
        Classes already visited on the current recursion path.  Attributes
        whose class is listed are not recursed into, to prevent infinite
        recursion.

    Raises
    ------
    AssertionError
        If an attribute of ``a`` is missing from ``b`` or compares unequal.
    """
    dict_a = _pickle_comparison_state(a)
    dict_b = b.__dict__
    for key, value_a in dict_a.items():
        assert key in dict_b, f"Did not pickle {key}"
        value_b = dict_b[key]
        if hasattr(value_a, '__eq__'):
            eq = (value_a == value_b)
            if '__iter__' in dir(eq):
                # Element-wise comparisons (e.g. arrays) yield an iterable;
                # require every element to be equal.
                eq = (False not in eq)
            assert eq, f"Value of {key} changed by pickling"

        if hasattr(value_a, '__dict__'):
            # Skip classes already on the path to prevent infinite recursion.
            if value_a.__class__ not in class_history:
                generic_recursive_equality_test(
                    value_a, value_b, [value_a.__class__, *class_history])
447
448
def check_pickling_recovery(original, protocol):
    """
    Round-trip ``original`` through pickle and compare the result.

    The object is serialized with ``pickle.dumps`` using ``protocol`` and
    restored with ``pickle.loads``.  The copy is then compared attribute by
    attribute with ``generic_recursive_equality_test``, whose class history
    starts with ``original.__class__``.
    """
    restored = pickle.loads(pickle.dumps(original, protocol=protocol))
    generic_recursive_equality_test(original, restored,
                                    [original.__class__])
459
460
def assert_quantity_allclose(actual, desired, rtol=1.e-7, atol=None,
                             **kwargs):
    """
    Assert that two objects are equal up to the desired tolerance.

    This is a :class:`~astropy.units.Quantity`-aware version of
    :func:`numpy.testing.assert_allclose`.  ``actual``, ``desired``,
    ``rtol`` and ``atol`` are first passed through astropy's
    ``_unquantify_allclose_arguments``.  The result, plus any extra
    ``**kwargs``, is forwarded to `numpy.testing.assert_allclose`.
    """
    import numpy as np
    from astropy.units.quantity import _unquantify_allclose_arguments

    plain_arguments = _unquantify_allclose_arguments(actual, desired,
                                                     rtol, atol)
    np.testing.assert_allclose(*plain_arguments, **kwargs)
473