1# coding: utf-8
2"""
3:class:`asynctest.TestCase` decorator which controls checks performed after
4tests.
5
6This module is separated from :mod:`asynctest.case` to avoid circular imports
7in modules registering new checks.
8
9To implement new checks:
10
    * its name must be added to the ``DEFAULTS`` dict,

    * a static method of the same name must be added to the :class:`_fail_on`
      class,

    * an optional static method named ``before_test_[name of the check]`` can
      be added to :class:`_fail_on` to implement some set-up before the test
      runs.
18
A check may be available only on some platforms and activated by a conditional
import. In this case, the module providing the check updates ``DEFAULTS`` and
:class:`_fail_on` itself. There is an example in the :mod:`asynctest.selector`
module.
22"""
23from asyncio import TimerHandle
24
25
# Attribute in which @fail_on, @strict and @lenient store their checker on the
# decorated test method or test case class.
_FAIL_ON_ATTR = "_asynctest_fail_on"


#: Default value of each argument accepted by @fail_on. Every key is also the
#: name of the :class:`_fail_on` static method performing that check; ``True``
#: means the check runs by default, ``False`` that it is disabled by default.
DEFAULTS = dict(
    unused_loop=False,
    active_handles=False,
)
36
37
class _fail_on:
    """
    Decorator holding the post-test checks of a test method or test case.

    ``checks`` maps the name of a check (a key of ``DEFAULTS``) to ``True``
    (enabled) or ``False`` (disabled). Checks missing from this mapping fall
    back to the checker attached to the test case class, then to
    ``DEFAULTS`` (see :meth:`get_checks`).

    Each check is implemented by a method of this class with the same name,
    called with the test case after the test ran. A check may also provide an
    optional ``before_test_<name>`` hook, called with the test case before the
    test runs; since it is looked up on the instance, it must accept the case
    as its only argument (e.g. a static method).
    """

    def __init__(self, checks=None):
        self.checks = checks or {}
        # Effective checks, cached per test case class. The checker attached
        # to a test method is shared by every class inheriting that method,
        # so a single cached value would apply the class-level settings of
        # the first case class that ran the test to all the others.
        self._computed_checks = {}

    def __call__(self, func):
        """
        Attach a copy of this checker to ``func`` and return ``func``.

        When ``func`` already carries a checker (stacked decorators, or one
        inherited from a decorated base class), that checker is copied and
        its checks are overridden by ours. Neither the existing checker nor
        ``self`` is modified.
        """
        existing = getattr(func, _FAIL_ON_ATTR, None)
        if existing is not None:
            checker = existing.copy()
            checker.update(self.checks)
        else:
            checker = self.copy()

        setattr(func, _FAIL_ON_ATTR, checker)
        return func

    def update(self, checks, override=True):
        """
        Merge the ``checks`` mapping into this checker's checks.

        With ``override=False``, only checks that are not already set are
        added; existing values are kept.
        """
        if override:
            self.checks.update(checks)
        else:
            for check, value in checks.items():
                self.checks.setdefault(check, value)

    def copy(self):
        """Return a new checker with a shallow copy of the checks."""
        return _fail_on(self.checks.copy())

    def get_checks(self, case):
        """
        Return the effective ``{check name: enabled}`` mapping for ``case``.

        Precedence, from lowest to highest: ``DEFAULTS``, the checker attached
        to the test case (usually by decorating its class), then this
        checker's own checks.

        The result is cached per test case class, so :meth:`before_test` and
        :meth:`check_test` see the same checks. The returned dict is the
        cached one: do not mutate it.
        """
        case_type = type(case)
        checks = self._computed_checks.get(case_type)
        if checks is None:
            checks = DEFAULTS.copy()

            # The case may carry no checker at all (attribute missing).
            case_checker = getattr(case, _FAIL_ON_ATTR, None)
            case_checks = getattr(case_checker, "checks", None)
            if case_checks is not None:
                checks.update(case_checks)

            checks.update(self.checks)
            self._computed_checks[case_type] = checks

        return checks

    def before_test(self, case):
        """
        Run the ``before_test_<name>`` set-up hook of each enabled check.

        Hooks are optional: a missing (or non-callable) hook is skipped.
        Exceptions raised *by* a hook propagate to the caller instead of
        being silently swallowed.
        """
        checks = self.get_checks(case)
        for name, enabled in checks.items():
            if not enabled:
                continue
            hook = getattr(self, "before_test_" + name, None)
            if callable(hook):
                hook(case)

    def check_test(self, case):
        """Run every enabled check against ``case`` after the test ran."""
        checks = self.get_checks(case)
        for name, enabled in checks.items():
            if enabled:
                getattr(self, name)(case)

    # checks

    @staticmethod
    def unused_loop(case):
        """Fail ``case`` if its loop did not run during the test."""
        # NOTE(review): ``_asynctest_ran`` is not set in this module; it is
        # presumably set by asynctest's test loop when it runs.
        if not case.loop._asynctest_ran:
            case.fail("Loop did not run during the test")

    @staticmethod
    def _is_live_timer_handle(handle):
        # Relies on the private ``_cancelled`` attribute of asyncio handles.
        return isinstance(handle, TimerHandle) and not handle._cancelled

    @classmethod
    def _live_timer_handles(cls, loop):
        # ``_scheduled`` is a private attribute of asyncio's base event loop
        # holding scheduled timer handles; cancelled ones can still be listed
        # there, hence the filter.
        return filter(cls._is_live_timer_handle, loop._scheduled)

    @classmethod
    def active_handles(cls, case):
        """
        Fail ``case`` if its loop still holds scheduled, non-cancelled timer
        handles when the test ends.
        """
        handles = tuple(cls._live_timer_handles(case.loop))
        if handles:
            case.fail("Loop contained unfinished work {!r}".format(handles))
112
113
def fail_on(**kwargs):
    """
    Return a decorator enabling or disabling post-test checks on the loop.

    Each keyword is the name of a check (a key of ``DEFAULTS``) and its value
    tells whether that check runs after the decorated test(s). An unknown
    keyword raises :exc:`TypeError`, like a regular function call would.
    """
    # documented in asynctest.case.rst
    unknown = [name for name in kwargs if name not in DEFAULTS]
    if unknown:
        raise TypeError("fail_on() got an unexpected keyword argument "
                        "'{}'".format(unknown[0]))

    return _fail_on(kwargs)
126
127
def _fail_on_all(flag, func):
    """
    Build a checker setting every check of ``DEFAULTS`` to ``flag``.

    Return the checker itself when ``func`` is None; otherwise apply it to
    ``func`` and return ``func``.
    """
    checker = _fail_on(dict.fromkeys(DEFAULTS, flag))
    if func is None:
        return checker
    return checker(func)
131
132
def strict(func=None):
    """
    Enable every post-test check listed in ``DEFAULTS``.

    Use it directly as ``@strict`` on a test method or test case class, or
    call ``strict()`` to get the checker itself.
    """
    # documented in asynctest.case.rst
    return _fail_on_all(flag=True, func=func)
139
140
def lenient(func=None):
    """
    Disable every post-test check listed in ``DEFAULTS``.

    Use it directly as ``@lenient`` on a test method or test case class, or
    call ``lenient()`` to get the checker itself.
    """
    # documented in asynctest.case.rst
    return _fail_on_all(flag=False, func=func)
147