1import math
2import pprint
3from collections.abc import Iterable
4from collections.abc import Mapping
5from collections.abc import Sized
6from decimal import Decimal
7from numbers import Number
8from types import TracebackType
9from typing import Any
10from typing import Callable
11from typing import cast
12from typing import Generic
13from typing import Optional
14from typing import Pattern
15from typing import Tuple
16from typing import TypeVar
17from typing import Union
18
19import _pytest._code
20from _pytest.compat import final
21from _pytest.compat import overload
22from _pytest.compat import STRING_TYPES
23from _pytest.compat import TYPE_CHECKING
24from _pytest.outcomes import fail
25
26if TYPE_CHECKING:
27    from typing import Type
28
29
30def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
31    at_str = " at {}".format(at) if at else ""
32    return TypeError(
33        "cannot make approximate comparisons to non-numeric values: {!r} {}".format(
34            value, at_str
35        )
36    )
37
38
39# builtin pytest.approx helper
40
41
class ApproxBase:
    """Shared plumbing for the approximate-comparison helpers.

    Concrete subclasses describe *which* ``(actual, expected)`` pairs must be
    compared by implementing ``_yield_comparisons``; this base class then
    compares every pair through an ``ApproxScalar`` built with the same
    ``rel`` / ``abs`` / ``nan_ok`` settings.
    """

    # Tell numpy to use our `__eq__` operator instead of its.
    __array_ufunc__ = None
    __array_priority__ = 100

    # Instances define __eq__, so they are deliberately unhashable.
    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
        __tracebackhide__ = True
        self.expected = expected
        self.rel = rel
        self.abs = abs
        self.nan_ok = nan_ok
        # Validate eagerly so a bad ``expected`` fails at construction time.
        self._check_type()

    def __repr__(self) -> str:
        raise NotImplementedError

    def __eq__(self, actual) -> bool:
        # Stop at the first pair that is not approximately equal.
        for actual_item, expected_item in self._yield_comparisons(actual):
            if not (actual_item == self._approx_scalar(expected_item)):
                return False
        return True

    def __ne__(self, actual) -> bool:
        is_equal = actual == self
        return not is_equal

    def _approx_scalar(self, x) -> "ApproxScalar":
        """Wrap a single expected number using this object's tolerances."""
        return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)

    def _yield_comparisons(self, actual):
        """Yield every ``(actual_number, expected_number)`` pair to compare.

        Only ``__eq__`` consumes this; subclasses must override it.
        """
        raise NotImplementedError

    def _check_type(self) -> None:
        """Raise a TypeError if the expected value is not a valid type.

        approx() already guarantees a numeric expected value for scalars, so
        the default accepts everything; container subclasses override this to
        reject non-numeric elements.
        """
90
91
92def _recursive_list_map(f, x):
93    if isinstance(x, list):
94        return list(_recursive_list_map(f, xi) for xi in x)
95    else:
96        return f(x)
97
98
class ApproxNumpy(ApproxBase):
    """Approximate comparison against an expected value that is a numpy array.

    ``actual`` may be a numpy scalar / Python scalar (compared against every
    element) or anything ``numpy.asarray`` accepts with the same shape.
    """

    def __repr__(self) -> str:
        nested = self.expected.tolist()
        scalars = _recursive_list_map(self._approx_scalar, nested)
        return "approx({!r})".format(scalars)

    def __eq__(self, actual) -> bool:
        import numpy as np

        # self.expected is supposed to always be an array here.
        if not np.isscalar(actual):
            try:
                actual = np.asarray(actual)
            except Exception as e:
                raise TypeError(
                    "cannot compare '{}' to numpy.ndarray".format(actual)
                ) from e
            # After asarray ``actual`` is an ndarray, so only shapes can differ.
            if actual.shape != self.expected.shape:
                return False

        return ApproxBase.__eq__(self, actual)

    def _yield_comparisons(self, actual):
        import numpy as np

        # ``actual`` was normalised by ``__eq__``: it is either a scalar or an
        # ndarray with the same shape as ``self.expected``.
        actual_is_scalar = np.isscalar(actual)
        for index in np.ndindex(self.expected.shape):
            if actual_is_scalar:
                yield actual, self.expected[index].item()
            else:
                yield actual[index].item(), self.expected[index].item()
137
138
class ApproxMapping(ApproxBase):
    """Perform approximate comparisons where the expected value is a mapping
    with numeric values (the keys can be anything)."""

    def __repr__(self) -> str:
        return "approx({!r})".format(
            {k: self._approx_scalar(v) for k, v in self.expected.items()}
        )

    def __eq__(self, actual) -> bool:
        """Return True if ``actual`` has exactly the expected keys and every
        value is approximately equal to the expected one.

        Objects without a ``keys()`` method (i.e. not mapping-like) simply
        compare unequal instead of raising ``AttributeError``.
        """
        try:
            if set(actual.keys()) != set(self.expected.keys()):
                return False
        except AttributeError:
            return False

        return ApproxBase.__eq__(self, actual)

    def _yield_comparisons(self, actual):
        # Keys were verified to match in __eq__, so indexing is safe.
        for k in self.expected.keys():
            yield actual[k], self.expected[k]

    def _check_type(self) -> None:
        """Reject nested mappings (of the same type) and non-numeric values."""
        __tracebackhide__ = True
        for key, value in self.expected.items():
            if isinstance(value, type(self.expected)):
                msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n  full mapping={}"
                raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
            elif not isinstance(value, Number):
                raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))
166
167
class ApproxSequencelike(ApproxBase):
    """Perform approximate comparisons where the expected value is a sequence of numbers.

    Elements are paired positionally (``zip(actual, expected)``), so for
    unordered iterables such as sets the result depends on iteration order.
    """

    def __repr__(self) -> str:
        # ApproxScalar instances are unhashable, so a set cannot hold them;
        # only tuple and list are preserved, everything else renders as a list.
        seq_type = type(self.expected)
        if seq_type not in (tuple, list):
            seq_type = list
        return "approx({!r})".format(
            seq_type(self._approx_scalar(x) for x in self.expected)
        )

    def __eq__(self, actual) -> bool:
        """Return True if ``actual`` has the same length and every element is
        approximately equal to the corresponding expected element.

        Objects without a length compare unequal instead of raising TypeError.
        """
        try:
            if len(actual) != len(self.expected):
                return False
        except TypeError:
            return False
        return ApproxBase.__eq__(self, actual)

    def _yield_comparisons(self, actual):
        return zip(actual, self.expected)

    def _check_type(self) -> None:
        """Reject nested sequences (of the same type) and non-numeric elements."""
        __tracebackhide__ = True
        for index, x in enumerate(self.expected):
            if isinstance(x, type(self.expected)):
                msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n  full sequence: {}"
                raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
            elif not isinstance(x, Number):
                raise _non_numeric_type_error(
                    self.expected, at="index {}".format(index)
                )
197
198
class ApproxScalar(ApproxBase):
    """Perform approximate comparisons where the expected value is a single number."""

    # Using Real should be better than this Union, but not possible yet:
    # https://github.com/python/typeshed/pull/3108
    DEFAULT_ABSOLUTE_TOLERANCE = 1e-12  # type: Union[float, Decimal]
    DEFAULT_RELATIVE_TOLERANCE = 1e-6  # type: Union[float, Decimal]

    def __repr__(self) -> str:
        """Return a string communicating both the expected value and the
        tolerance for the comparison being made.

        For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ∠ ±180°``.
        """

        # Infinities aren't compared using tolerances, so don't show a
        # tolerance. Need to call abs to handle complex numbers, e.g. (inf + 1j).
        if math.isinf(abs(self.expected)):
            return str(self.expected)

        # If a sensible tolerance can't be calculated, self.tolerance will
        # raise a ValueError.  In this case, display '???'.
        try:
            vetted_tolerance = "{:.1e}".format(self.tolerance)
            if isinstance(self.expected, complex) and not math.isinf(self.tolerance):
                vetted_tolerance += " ∠ ±180°"
        except ValueError:
            vetted_tolerance = "???"

        return "{} ± {}".format(self.expected, vetted_tolerance)

    def __eq__(self, actual) -> bool:
        """Return whether the given value is equal to the expected value
        within the pre-specified tolerance."""
        if _is_numpy_array(actual):
            # Call ``__eq__()`` manually to prevent infinite-recursion with
            # numpy<1.13.  See #3748.
            return all(self.__eq__(a) for a in actual.flat)

        # Short-circuit exact equality.
        if actual == self.expected:
            return True

        # Allow the user to control whether NaNs are considered equal to each
        # other or not.  The abs() calls are for compatibility with complex
        # numbers.
        if math.isnan(abs(self.expected)):
            return self.nan_ok and math.isnan(abs(actual))

        # Infinity shouldn't be approximately equal to anything but itself, but
        # if there's a relative tolerance, it will be infinite and infinity
        # will seem approximately equal to everything.  The equal-to-itself
        # case would have been short circuited above, so here we can just
        # return false if the expected value is infinite.  The abs() call is
        # for compatibility with complex numbers.
        if math.isinf(abs(self.expected)):
            return False

        # Return true if the two numbers are within the tolerance.
        result = abs(self.expected - actual) <= self.tolerance  # type: bool
        return result

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @property
    def tolerance(self):
        """Return the tolerance for the comparison.

        This could be either an absolute tolerance or a relative tolerance,
        depending on what the user specified or which would be larger.

        :raises ValueError: if the effective absolute or relative tolerance is
            negative or NaN.
        """

        def set_default(x, default):
            return x if x is not None else default

        # Figure out what the absolute tolerance should be.  ``self.abs`` is
        # either None or a value specified by the user.
        absolute_tolerance = set_default(self.abs, self.DEFAULT_ABSOLUTE_TOLERANCE)

        if absolute_tolerance < 0:
            raise ValueError(
                "absolute tolerance can't be negative: {}".format(absolute_tolerance)
            )
        if math.isnan(absolute_tolerance):
            raise ValueError("absolute tolerance can't be NaN.")

        # If the user specified an absolute tolerance but not a relative one,
        # just return the absolute tolerance.
        if self.rel is None:
            if self.abs is not None:
                return absolute_tolerance

        # Figure out what the relative tolerance should be.  ``self.rel`` is
        # either None or a value specified by the user.  This is done after
        # we've made sure the user didn't ask for an absolute tolerance only,
        # because we don't want to raise errors about the relative tolerance if
        # we aren't even going to use it.
        relative_tolerance = set_default(
            self.rel, self.DEFAULT_RELATIVE_TOLERANCE
        ) * abs(self.expected)

        if relative_tolerance < 0:
            # Report the offending relative value (previously this message
            # mistakenly printed the absolute tolerance).
            raise ValueError(
                "relative tolerance can't be negative: {}".format(relative_tolerance)
            )
        if math.isnan(relative_tolerance):
            raise ValueError("relative tolerance can't be NaN.")

        # Return the larger of the relative and absolute tolerances.
        return max(relative_tolerance, absolute_tolerance)
310
311
class ApproxDecimal(ApproxScalar):
    """Approximate comparison for an expected value of type ``decimal.Decimal``.

    The default tolerances are overridden with Decimal instances so that, when
    the user supplies neither ``rel`` nor ``abs``, the tolerance arithmetic in
    ``ApproxScalar`` combines Decimals with Decimals rather than with floats.
    """

    DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
    DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
317
318
def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
    """Assert that two numbers (or two sets of numbers) are equal to each other
    within some tolerance.

    Due to the `intricacies of floating-point arithmetic`__, numbers that we
    would intuitively expect to be equal are not always so::

        >>> 0.1 + 0.2 == 0.3
        False

    __ https://docs.python.org/3/tutorial/floatingpoint.html

    This problem is commonly encountered when writing tests, e.g. when making
    sure that floating-point values are what you expect them to be.  One way to
    deal with this problem is to assert that two floating-point numbers are
    equal to within some appropriate tolerance::

        >>> abs((0.1 + 0.2) - 0.3) < 1e-6
        True

    However, comparisons like this are tedious to write and difficult to
    understand.  Furthermore, absolute comparisons like the one above are
    usually discouraged because there's no tolerance that works well for all
    situations.  ``1e-6`` is good for numbers around ``1``, but too small for
    very big numbers and too big for very small ones.  It's better to express
    the tolerance as a fraction of the expected value, but relative comparisons
    like that are even more difficult to write correctly and concisely.

    The ``approx`` class performs floating-point comparisons using a syntax
    that's as intuitive as possible::

        >>> from pytest import approx
        >>> 0.1 + 0.2 == approx(0.3)
        True

    The same syntax also works for sequences of numbers::

        >>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6))
        True

    Dictionary *values*::

        >>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
        True

    ``numpy`` arrays::

        >>> import numpy as np                                                          # doctest: +SKIP
        >>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
        True

    And for a ``numpy`` array against a scalar::

        >>> import numpy as np                                         # doctest: +SKIP
        >>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
        True

    By default, ``approx`` considers numbers within a relative tolerance of
    ``1e-6`` (i.e. one part in a million) of its expected value to be equal.
    This treatment would lead to surprising results if the expected value was
    ``0.0``, because nothing but ``0.0`` itself is relatively close to ``0.0``.
    To handle this case less surprisingly, ``approx`` also considers numbers
    within an absolute tolerance of ``1e-12`` of its expected value to be
    equal.  Infinity and NaN are special cases.  Infinity is only considered
    equal to itself, regardless of the relative tolerance.  NaN is not
    considered equal to anything by default, but you can make it be equal to
    itself by setting the ``nan_ok`` argument to True.  (This is meant to
    facilitate comparing arrays that use NaN to mean "no data".)

    Both the relative and absolute tolerances can be changed by passing
    arguments to the ``approx`` constructor::

        >>> 1.0001 == approx(1)
        False
        >>> 1.0001 == approx(1, rel=1e-3)
        True
        >>> 1.0001 == approx(1, abs=1e-3)
        True

    If you specify ``abs`` but not ``rel``, the comparison will not consider
    the relative tolerance at all.  In other words, two numbers that are within
    the default relative tolerance of ``1e-6`` will still be considered unequal
    if they exceed the specified absolute tolerance.  If you specify both
    ``abs`` and ``rel``, the numbers will be considered equal if either
    tolerance is met::

        >>> 1 + 1e-8 == approx(1)
        True
        >>> 1 + 1e-8 == approx(1, abs=1e-12)
        False
        >>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
        True

    If you're thinking about using ``approx``, then you might want to know how
    it compares to other good ways of comparing floating-point numbers.  All of
    these algorithms are based on relative and absolute tolerances and should
    agree for the most part, but they do have meaningful differences:

    - ``math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)``:  True if the relative
      tolerance is met w.r.t. either ``a`` or ``b`` or if the absolute
      tolerance is met.  Because the relative tolerance is calculated w.r.t.
      both ``a`` and ``b``, this test is symmetric (i.e.  neither ``a`` nor
      ``b`` is a "reference value").  You have to specify an absolute tolerance
      if you want to compare to ``0.0`` because there is no tolerance by
      default.  Only available in python>=3.5.  `More information...`__

      __ https://docs.python.org/3/library/math.html#math.isclose

    - ``numpy.isclose(a, b, rtol=1e-5, atol=1e-8)``: True if the difference
      between ``a`` and ``b`` is less than or equal to the sum of the relative
      tolerance w.r.t. ``b`` and the absolute tolerance.  Because the relative
      tolerance is only calculated w.r.t. ``b``, this test is asymmetric and
      you can think of ``b`` as the reference value.  Support for comparing
      sequences is provided by ``numpy.allclose``.  `More information...`__

      __ https://numpy.org/doc/stable/reference/generated/numpy.isclose.html

    - ``unittest.TestCase.assertAlmostEqual(a, b)``: True if ``a`` and ``b``
      are within an absolute tolerance of ``1e-7``.  No relative tolerance is
      considered and the absolute tolerance cannot be changed, so this function
      is not appropriate for very large or very small numbers.  Also, it's only
      available in subclasses of ``unittest.TestCase`` and it's ugly because it
      doesn't follow PEP8.  `More information...`__

      __ https://docs.python.org/3/library/unittest.html#unittest.TestCase.assertAlmostEqual

    - ``a == pytest.approx(b, rel=1e-6, abs=1e-12)``: True if the relative
      tolerance is met w.r.t. ``b`` or if the absolute tolerance is met.
      Because the relative tolerance is only calculated w.r.t. ``b``, this test
      is asymmetric and you can think of ``b`` as the reference value.  In the
      special case that you explicitly specify an absolute tolerance but not a
      relative tolerance, only the absolute tolerance is considered.

    .. warning::

       .. versionchanged:: 3.2

       In order to avoid inconsistent behavior, ``TypeError`` is
       raised for ``>``, ``>=``, ``<`` and ``<=`` comparisons.
       The example below illustrates the problem::

           assert approx(0.1) > 0.1 + 1e-10  # calls approx(0.1).__gt__(0.1 + 1e-10)
           assert 0.1 + 1e-10 > approx(0.1)  # calls approx(0.1).__lt__(0.1 + 1e-10)

       In the second example one expects ``approx(0.1).__le__(0.1 + 1e-10)``
       to be called. But instead, ``approx(0.1).__lt__(0.1 + 1e-10)`` is used
       for comparison. This is because the call hierarchy of rich comparisons
       follows a fixed behavior. `More information...`__

       __ https://docs.python.org/3/reference/datamodel.html#object.__ge__
    """

    # Delegate the comparison to a class that knows how to deal with the type
    # of the expected value (e.g. int, float, list, dict, numpy.array, etc).
    #
    # The primary responsibility of these classes is to implement ``__eq__()``
    # and ``__repr__()``.  The former is used to actually check if some
    # "actual" value is equivalent to the given expected value within the
    # allowed tolerance.  The latter is used to show the user the expected
    # value and tolerance, in the case that a test failed.
    #
    # The actual logic for making approximate comparisons can be found in
    # ApproxScalar, which is used to compare individual numbers.  All of the
    # other Approx classes eventually delegate to this class.  The ApproxBase
    # class provides some convenient methods and overloads, but isn't really
    # essential.

    __tracebackhide__ = True

    # Decimal must be tested before Number: Decimal is registered as a Number,
    # but needs Decimal-typed default tolerances.
    if isinstance(expected, Decimal):
        cls = ApproxDecimal  # type: Type[ApproxBase]
    elif isinstance(expected, Number):
        cls = ApproxScalar
    elif isinstance(expected, Mapping):
        cls = ApproxMapping
    elif _is_numpy_array(expected):
        cls = ApproxNumpy
    elif (
        isinstance(expected, Iterable)
        and isinstance(expected, Sized)
        # Type ignored because the error is wrong -- not unreachable.
        and not isinstance(expected, STRING_TYPES)  # type: ignore[unreachable]
    ):
        cls = ApproxSequencelike
    else:
        raise _non_numeric_type_error(expected, at=None)

    return cls(expected, rel, abs, nan_ok)
507
508
509def _is_numpy_array(obj: object) -> bool:
510    """Return true if the given object is a numpy array.
511
512    A special effort is made to avoid importing numpy unless it's really necessary.
513    """
514    import sys
515
516    np = sys.modules.get("numpy")  # type: Any
517    if np is not None:
518        return isinstance(obj, np.ndarray)
519    return False
520
521
522# builtin pytest.raises helper
523
# Type variable for the exception type handled by raises() / RaisesContext;
# bound to BaseException so only exception classes can be substituted for it.
_E = TypeVar("_E", bound=BaseException)
525
526
# Typing overload for the context-manager form: no callable is passed, only
# the optional ``match`` keyword, and a RaisesContext is returned.
@overload
def raises(
    expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
    *,
    match: "Optional[Union[str, Pattern[str]]]" = ...
) -> "RaisesContext[_E]":
    ...


# Typing overload for the legacy call form: ``func`` is invoked with the
# remaining positional/keyword arguments and the captured ExceptionInfo is
# returned.
@overload  # noqa: F811
def raises(  # noqa: F811
    expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
    func: Callable[..., Any],
    *args: Any,
    **kwargs: Any
) -> _pytest._code.ExceptionInfo[_E]:
    ...
544
545
def raises(  # noqa: F811
    expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
    *args: Any,
    **kwargs: Any
) -> Union["RaisesContext[_E]", _pytest._code.ExceptionInfo[_E]]:
    r"""Assert that a code block/function call raises ``expected_exception``
    or raise a failure exception otherwise.

    :kwparam match:
        If specified, a string containing a regular expression,
        or a regular expression object, that is tested against the string
        representation of the exception using ``re.search``. To match a literal
        string that may contain `special characters`__, the pattern can
        first be escaped with ``re.escape``.

        (This is only used when ``pytest.raises`` is used as a context manager,
        and passed through to the function otherwise.
        When using ``pytest.raises`` as a function, you can use:
        ``pytest.raises(Exc, func, match="passed on").match("my pattern")``.)

        __ https://docs.python.org/3/library/re.html#regular-expression-syntax

    .. currentmodule:: _pytest._code

    Use ``pytest.raises`` as a context manager, which will capture the exception of the given
    type::

        >>> with raises(ZeroDivisionError):
        ...    1/0

    If the code block does not raise the expected exception (``ZeroDivisionError`` in the example
    above), or no exception at all, the check will fail instead.

    You can also use the keyword argument ``match`` to assert that the
    exception matches a text or regex::

        >>> with raises(ValueError, match='must be 0 or None'):
        ...     raise ValueError("value must be 0 or None")

        >>> with raises(ValueError, match=r'must be \d+$'):
        ...     raise ValueError("value must be 42")

    The context manager produces an :class:`ExceptionInfo` object which can be used to inspect the
    details of the captured exception::

        >>> with raises(ValueError) as exc_info:
        ...     raise ValueError("value must be 42")
        >>> assert exc_info.type is ValueError
        >>> assert exc_info.value.args[0] == "value must be 42"

    .. note::

       When using ``pytest.raises`` as a context manager, it's worthwhile to
       note that normal context manager rules apply and that the exception
       raised *must* be the final line in the scope of the context manager.
       Lines of code after that, within the scope of the context manager will
       not be executed. For example::

           >>> value = 15
           >>> with raises(ValueError) as exc_info:
           ...     if value > 10:
           ...         raise ValueError("value must be <= 10")
           ...     assert exc_info.type is ValueError  # this will not execute

       Instead, the following approach must be taken (note the difference in
       scope)::

           >>> with raises(ValueError) as exc_info:
           ...     if value > 10:
           ...         raise ValueError("value must be <= 10")
           ...
           >>> assert exc_info.type is ValueError

    **Using with** ``pytest.mark.parametrize``

    When using :ref:`pytest.mark.parametrize ref`
    it is possible to parametrize tests such that
    some runs raise an exception and others do not.

    See :ref:`parametrizing_conditional_raising` for an example.

    **Legacy form**

    It is possible to specify a callable by passing a to-be-called lambda::

        >>> raises(ZeroDivisionError, lambda: 1/0)
        <ExceptionInfo ...>

    or you can specify an arbitrary callable with arguments::

        >>> def f(x): return 1/x
        ...
        >>> raises(ZeroDivisionError, f, 0)
        <ExceptionInfo ...>
        >>> raises(ZeroDivisionError, f, x=0)
        <ExceptionInfo ...>

    The form above is fully supported but discouraged for new code because the
    context manager form is regarded as more readable and less error-prone.

    .. note::
        Similar to caught exception objects in Python, explicitly clearing
        local references to returned ``ExceptionInfo`` objects can
        help the Python interpreter speed up its garbage collection.

        Clearing those references breaks a reference cycle
        (``ExceptionInfo`` --> caught exception --> frame stack raising
        the exception --> current frame stack --> local variables -->
        ``ExceptionInfo``) which makes Python keep all objects referenced
        from that cycle (including all local variables in the current
        frame) alive until the next cyclic garbage collection run.
        More detailed information can be found in the official Python
        documentation for :ref:`the try statement <python:try>`.
    """
    __tracebackhide__ = True

    # Normalise to a tuple so a single class and a tuple of classes are
    # validated by the same loop.
    if isinstance(expected_exception, type):
        expected_exceptions = (expected_exception,)  # type: Tuple[Type[_E], ...]
    else:
        expected_exceptions = expected_exception
    for candidate in expected_exceptions:
        if not isinstance(candidate, type) or not issubclass(candidate, BaseException):  # type: ignore[unreachable]
            msg = "expected exception must be a BaseException type, not {}"  # type: ignore[unreachable]
            bad_name = (
                candidate.__name__
                if isinstance(candidate, type)
                else type(candidate).__name__
            )
            raise TypeError(msg.format(bad_name))

    message = "DID NOT RAISE {}".format(expected_exception)

    if args:
        # Legacy form: call the function right away and capture its exception.
        func = args[0]
        if not callable(func):
            raise TypeError(
                "{!r} object (type: {}) must be callable".format(func, type(func))
            )
        try:
            func(*args[1:], **kwargs)
        except expected_exception as e:
            # We just caught the exception - there is a traceback.
            assert e.__traceback__ is not None
            return _pytest._code.ExceptionInfo.from_exc_info(
                (type(e), e, e.__traceback__)
            )
        fail(message)

    # Context-manager form: ``match`` is the only accepted keyword.
    match = kwargs.pop("match", None)  # type: Optional[Union[str, Pattern[str]]]
    if kwargs:
        unexpected = ", ".join(sorted(kwargs))
        raise TypeError(
            "Unexpected keyword arguments passed to pytest.raises: "
            + unexpected
            + "\nUse context-manager form instead?"
        )
    return RaisesContext(expected_exception, message, match)
697
698
# Expose the failure exception type as ``raises.Exception``: it is the type
# raised via ``fail(...)`` when the expected exception does not occur.
# This doesn't work with mypy for now (hence the ignore); typed code should
# use fail.Exception instead.
raises.Exception = fail.Exception  # type: ignore
701
702
@final
class RaisesContext(Generic[_E]):
    """Context manager returned by ``raises()`` when no callable is given.

    On enter it hands out an unfilled ExceptionInfo; on exit it fills it with
    the captured exception (and checks ``match_expr`` if one was given).
    """

    def __init__(
        self,
        expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
        message: str,
        match_expr: Optional[Union[str, "Pattern[str]"]] = None,
    ) -> None:
        self.excinfo = None  # type: Optional[_pytest._code.ExceptionInfo[_E]]
        self.expected_exception = expected_exception
        self.message = message
        self.match_expr = match_expr

    def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
        pending = _pytest._code.ExceptionInfo.for_later()  # type: _pytest._code.ExceptionInfo[_E]
        self.excinfo = pending
        return pending

    def __exit__(
        self,
        exc_type: Optional["Type[BaseException]"],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        """Fail if nothing was raised; record and suppress (return True) an
        expected exception; let any other exception propagate (return False)."""
        __tracebackhide__ = True
        if exc_type is None:
            fail(self.message)
        assert self.excinfo is not None
        if issubclass(exc_type, self.expected_exception):
            # Cast to narrow the exception type now that it's verified.
            captured = cast(
                Tuple["Type[_E]", _E, TracebackType], (exc_type, exc_val, exc_tb)
            )
            self.excinfo.fill_unfilled(captured)
            if self.match_expr is not None:
                self.excinfo.match(self.match_expr)
            return True
        return False
740