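"""pytest helpers (raises, warns, XFAIL/XPASS, SKIP, ...) for SymPy's tests,
with fallback implementations used when the tests are not run under pytest."""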
2
3import sys
4import functools
5import os
6import contextlib
7import warnings
8
9from sympy.utilities.exceptions import SymPyDeprecationWarning
10
# The Travis CI build number, or None when the variable is not set.
ON_TRAVIS = os.environ.get('TRAVIS_BUILD_NUMBER')

# Delegate to pytest only when it is importable *and* the flag
# ``sys._running_pytest`` is set (presumably by the project's pytest
# configuration -- it is not set in this module).
try:
    import pytest
except ImportError:
    USE_PYTEST = False
else:
    USE_PYTEST = getattr(sys, '_running_pytest', False)
18
19
20if USE_PYTEST:
21    raises = pytest.raises
22    warns = pytest.warns
23    skip = pytest.skip
24    XFAIL = pytest.mark.xfail
25    SKIP = pytest.mark.skip
26    slow = pytest.mark.slow
27    nocache_fail = pytest.mark.nocache_fail
28    from _pytest.outcomes import Failed
29
30else:
31    # Not using pytest so define the things that would have been imported from
32    # there.
33
34    # _pytest._code.code.ExceptionInfo
35    class ExceptionInfo:
36        def __init__(self, value):
37            self.value = value
38
39        def __repr__(self):
40            return "<ExceptionInfo {!r}>".format(self.value)
41
42
43    def raises(expectedException, code=None):
44        """
45        Tests that ``code`` raises the exception ``expectedException``.
46
47        ``code`` may be a callable, such as a lambda expression or function
48        name.
49
50        If ``code`` is not given or None, ``raises`` will return a context
51        manager for use in ``with`` statements; the code to execute then
52        comes from the scope of the ``with``.
53
54        ``raises()`` does nothing if the callable raises the expected exception,
55        otherwise it raises an AssertionError.
56
57        Examples
58        ========
59
60        >>> from sympy.testing.pytest import raises
61
62        >>> raises(ZeroDivisionError, lambda: 1/0)
63        <ExceptionInfo ZeroDivisionError(...)>
64        >>> raises(ZeroDivisionError, lambda: 1/2)
65        Traceback (most recent call last):
66        ...
67        Failed: DID NOT RAISE
68
69        >>> with raises(ZeroDivisionError):
70        ...     n = 1/0
71        >>> with raises(ZeroDivisionError):
72        ...     n = 1/2
73        Traceback (most recent call last):
74        ...
75        Failed: DID NOT RAISE
76
77        Note that you cannot test multiple statements via
78        ``with raises``:
79
80        >>> with raises(ZeroDivisionError):
81        ...     n = 1/0    # will execute and raise, aborting the ``with``
82        ...     n = 9999/0 # never executed
83
84        This is just what ``with`` is supposed to do: abort the
85        contained statement sequence at the first exception and let
86        the context manager deal with the exception.
87
88        To test multiple statements, you'll need a separate ``with``
89        for each:
90
91        >>> with raises(ZeroDivisionError):
92        ...     n = 1/0    # will execute and raise
93        >>> with raises(ZeroDivisionError):
94        ...     n = 9999/0 # will also execute and raise
95
96        """
97        if code is None:
98            return RaisesContext(expectedException)
99        elif callable(code):
100            try:
101                code()
102            except expectedException as e:
103                return ExceptionInfo(e)
104            raise Failed("DID NOT RAISE")
105        elif isinstance(code, str):
106            raise TypeError(
107                '\'raises(xxx, "code")\' has been phased out; '
108                'change \'raises(xxx, "expression")\' '
109                'to \'raises(xxx, lambda: expression)\', '
110                '\'raises(xxx, "statement")\' '
111                'to \'with raises(xxx): statement\'')
112        else:
113            raise TypeError(
114                'raises() expects a callable for the 2nd argument.')
115
116    class RaisesContext:
117        def __init__(self, expectedException):
118            self.expectedException = expectedException
119
120        def __enter__(self):
121            return None
122
123        def __exit__(self, exc_type, exc_value, traceback):
124            if exc_type is None:
125                raise Failed("DID NOT RAISE")
126            return issubclass(exc_type, self.expectedException)
127
128    class XFail(Exception):
129        pass
130
131    class XPass(Exception):
132        pass
133
134    class Skipped(Exception):
135        pass
136
137    class Failed(Exception):  # type: ignore
138        pass
139
140    def XFAIL(func):
141        def wrapper():
142            try:
143                func()
144            except Exception as e:
145                message = str(e)
146                if message != "Timeout":
147                    raise XFail(func.__name__)
148                else:
149                    raise Skipped("Timeout")
150            raise XPass(func.__name__)
151
152        wrapper = functools.update_wrapper(wrapper, func)
153        return wrapper
154
155    def skip(str):
156        raise Skipped(str)
157
158    def SKIP(reason):
159        """Similar to ``skip()``, but this is a decorator. """
160        def wrapper(func):
161            def func_wrapper():
162                raise Skipped(reason)
163
164            func_wrapper = functools.update_wrapper(func_wrapper, func)
165            return func_wrapper
166
167        return wrapper
168
169    def slow(func):
170        func._slow = True
171
172        def func_wrapper():
173            func()
174
175        func_wrapper = functools.update_wrapper(func_wrapper, func)
176        func_wrapper.__wrapped__ = func
177        return func_wrapper
178
179    def nocache_fail(func):
180        "Dummy decorator for marking tests that fail when cache is disabled"
181        return func
182
183    @contextlib.contextmanager
184    def warns(warningcls, *, match=''):
185        '''Like raises but tests that warnings are emitted.
186
187        >>> from sympy.testing.pytest import warns
188        >>> import warnings
189
190        >>> with warns(UserWarning):
191        ...     warnings.warn('deprecated', UserWarning)
192
193        >>> with warns(UserWarning):
194        ...     pass
195        Traceback (most recent call last):
196        ...
197        Failed: DID NOT WARN. No warnings of type UserWarning\
198        was emitted. The list of emitted warnings is: [].
199        '''
200        # Absorbs all warnings in warnrec
201        with warnings.catch_warnings(record=True) as warnrec:
202            # Hide all warnings but make sure that our warning is emitted
203            warnings.simplefilter("ignore")
204            warnings.filterwarnings("always", match, warningcls)
205            # Now run the test
206            yield
207
208        # Raise if expected warning not found
209        if not any(issubclass(w.category, warningcls) for w in warnrec):
210            msg = ('Failed: DID NOT WARN.'
211                   ' No warnings of type %s was emitted.'
212                   ' The list of emitted warnings is: %s.'
213                   ) % (warningcls, [w.message for w in warnrec])
214            raise Failed(msg)
215
216
def _both_exp_pow(func):
    """
    Run the decorated zero-argument test twice: first with `e^x`
    represented as ``Pow(E, x)``, then as ``exp(x)`` (where the exponential
    object is not a power).

    This is a temporary helper for the transition that replaces the class
    ``exp`` with ``Pow(E, ...)``.
    """
    from sympy.core.parameters import _exp_is_pow

    @functools.wraps(func)
    def func_wrap():
        # True first, then False: same order as the two explicit runs.
        for exp_is_pow in (True, False):
            with _exp_is_pow(exp_is_pow):
                func()

    return func_wrap
236
237
@contextlib.contextmanager
def warns_deprecated_sympy():
    '''Shorthand for ``warns(SymPyDeprecationWarning)``

    This is the recommended way to test that ``SymPyDeprecationWarning`` is
    emitted for deprecated features in SymPy. To test for other warnings use
    ``warns``. To suppress warnings without asserting that they are emitted
    use ``ignore_warnings``.

    The value bound by ``with warns_deprecated_sympy() as rec`` is whatever
    the underlying ``warns`` yields. For ``pytest.warns`` this is its
    warnings recorder.

    >>> from sympy.testing.pytest import warns_deprecated_sympy
    >>> from sympy.utilities.exceptions import SymPyDeprecationWarning
    >>> with warns_deprecated_sympy():
    ...     SymPyDeprecationWarning("Don't use", feature="old thing",
    ...         deprecated_since_version="1.0", issue=123).warn()

    >>> with warns_deprecated_sympy():
    ...     pass
    Traceback (most recent call last):
    ...
    Failed: DID NOT WARN. No warnings of type \
    SymPyDeprecationWarning was emitted. The list of emitted warnings is: [].
    '''
    # Forward the value from ``warns`` so ``as`` is as useful here as it
    # is with ``warns`` itself.
    with warns(SymPyDeprecationWarning) as record:
        yield record
262
@contextlib.contextmanager
def ignore_warnings(warningcls):
    '''Suppress warnings of category ``warningcls`` inside a ``with`` block.

    Use this when a warning should be silenced but is not guaranteed to
    occur, e.g. when importing a module, or when it comes from third-party
    code. To *assert* that a warning is raised, use ``warns`` instead. If
    SymPy itself reliably emits the warning, ``warns`` should be preferred
    over ignore_warnings.

    Warnings of other categories that are recorded inside the block are
    re-emitted after it exits.

    >>> from sympy.testing.pytest import ignore_warnings
    >>> import warnings

    Here's a warning:

    >>> with warnings.catch_warnings():  # reset warnings in doctest
    ...     warnings.simplefilter('error')
    ...     warnings.warn('deprecated', UserWarning)
    Traceback (most recent call last):
      ...
    UserWarning: deprecated

    Let's suppress it with ignore_warnings:

    >>> with warnings.catch_warnings():  # reset warnings in doctest
    ...     warnings.simplefilter('error')
    ...     with ignore_warnings(UserWarning):
    ...         warnings.warn('deprecated', UserWarning)

    (No warning emitted)
    '''
    with warnings.catch_warnings(record=True) as caught:
        # Force every warning of ``warningcls`` to be recorded, and thereby
        # swallowed, whatever outer filters (even "error") say.
        warnings.simplefilter("always", warningcls)
        yield

    # Recorded warnings of other categories were not meant to be hidden,
    # so emit them again now that recording has stopped.
    unrelated = [w for w in caught if not issubclass(w.category, warningcls)]
    for w in unrelated:
        warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
308