1# coding: utf-8 2""" 3:class:`asynctest.TestCase` decorator which controls checks performed after 4tests. 5 6This module is separated from :mod:`asynctest.case` to avoid circular imports 7in modules registering new checks. 8 9To implement new checks: 10 11 * its name must be added in the ``DEFAULTS`` dict, 12 13 * a static method of the same name must be added to the :class:`_fail_on` 14 class, 15 16 * an optional static method named ``before_[name of the check]`` can be 17 added to :class:`_fail_on` to implement some set-up before the test runs. 18 19A check may be only available on some platforms, activated by a conditional 20import. In this case, ``DEFAULT`` and :class:`_fail_on` can be updated in the 21module. There is an example in the :mod:`asynctest.selector` module. 22""" 23from asyncio import TimerHandle 24 25 26_FAIL_ON_ATTR = "_asynctest_fail_on" 27 28 29#: Default value of the arguments of @fail_on, the name of the argument matches 30#: the name of the static method performing the check in the :class:`_fail_on`. 31#: The value is True when the check is enabled by default, False otherwise. 
DEFAULTS = {
    "unused_loop": False,
    "active_handles": False,
}


class _fail_on:
    """
    Set of post-test checks, usable as a decorator.

    ``checks`` maps check names (keys of :data:`DEFAULTS`) to booleans telling
    whether each check is enabled. Applying an instance to an object (a test
    method or a test case class) stores a checker on it under
    ``_FAIL_ON_ATTR``, merged with any checker already found there.
    """

    def __init__(self, checks=None):
        # NOTE: a non-empty dict is stored as-is (not copied); copy() passes a
        # fresh copy so checkers never share their configuration.
        self.checks = checks or {}
        # Cached result of get_checks() and the case-level checker it was
        # computed against: the cache is only valid for that checker.
        self._computed_checks = None
        self._computed_for = None

    def __call__(self, func):
        """
        Attach this checker to ``func`` and return ``func`` unchanged.

        If ``func`` already carries a checker (from another fail_on decorator,
        or found through attribute lookup, e.g. inherited by a class), a copy
        of it is updated with this checker's values so that this decorator's
        settings take precedence.
        """
        checker = getattr(func, _FAIL_ON_ATTR, None)
        if checker:
            checker = checker.copy()
            checker.update(self.checks)
        else:
            checker = self.copy()

        setattr(func, _FAIL_ON_ATTR, checker)
        return func

    def update(self, checks, override=True):
        """
        Merge ``checks`` into this checker's configuration.

        With ``override`` (the default), values in ``checks`` replace existing
        ones; otherwise only checks not configured yet are added.
        """
        if override:
            self.checks.update(checks)
        else:
            for check, value in checks.items():
                self.checks.setdefault(check, value)
        # The configuration changed: any previously computed result is stale.
        self._computed_checks = None

    def copy(self):
        """Return a new checker holding a shallow copy of this one's checks."""
        return _fail_on(self.checks.copy())

    def get_checks(self, case):
        """
        Return the effective ``{check name: enabled}`` mapping for ``case``.

        :data:`DEFAULTS` are overridden by the checks of the checker found on
        ``case`` (if any, e.g. set by decorating the test case class), which
        are in turn overridden by this checker's own checks.

        The result is cached so that repeated calls for the same case (e.g.
        before_test() then check_test()) agree. The cache is keyed on the
        case-level checker because a decorated test method may be inherited
        by subclasses decorated differently; each must get its own result.
        """
        case_checker = getattr(case, _FAIL_ON_ATTR, None)
        if (self._computed_checks is None
                or self._computed_for is not case_checker):
            checks = DEFAULTS.copy()

            case_checks = getattr(case_checker, "checks", None)
            if case_checks:
                checks.update(case_checks)

            checks.update(self.checks)
            self._computed_checks = checks
            self._computed_for = case_checker

        return self._computed_checks

    def before_test(self, case):
        """Run the optional ``before_test_<check>`` hook of each enabled check."""
        checks = self.get_checks(case)
        for check in filter(checks.get, checks):
            hook = getattr(self, "before_test_" + check, None)
            # Checks without set-up simply have no hook. Exceptions raised by
            # a hook itself propagate so that broken set-up is not hidden.
            if callable(hook):
                hook(case)

    def check_test(self, case):
        """
        Run every enabled check against ``case``.

        The checks defined below report failures through ``case.fail()``.
        """
        checks = self.get_checks(case)
        for check in filter(checks.get, checks):
            getattr(self, check)(case)

    # checks

    @staticmethod
    def unused_loop(case):
        """Fail if ``case.loop`` did not run during the test."""
        # NOTE(review): ``_asynctest_ran`` is set outside this module,
        # presumably by asynctest's loop instrumentation — confirm.
        if not case.loop._asynctest_ran:
            case.fail("Loop did not run during the test")

    @staticmethod
    def _is_live_timer_handle(handle):
        """Return True for a TimerHandle which has not been cancelled."""
        return isinstance(handle, TimerHandle) and not handle._cancelled

    @classmethod
    def _live_timer_handles(cls, loop):
        """Iterate over the non-cancelled timer handles scheduled on ``loop``."""
        # Relies on the private ``_scheduled`` list of the event loop.
        return filter(cls._is_live_timer_handle, loop._scheduled)

    @classmethod
    def active_handles(cls, case):
        """Fail if timer handles are still scheduled on ``case.loop``."""
        handles = tuple(cls._live_timer_handles(case.loop))
        if handles:
            case.fail("Loop contained unfinished work {!r}".format(handles))


def fail_on(**kwargs):
    """
    Enable checks on the loop state after a test ran to help testers to
    identify common mistakes.

    Each keyword must be a key of :data:`DEFAULTS`; a :class:`TypeError` is
    raised otherwise, mimicking an unexpected keyword argument.
    """
    # documented in asynctest.case.rst
    for kwarg in kwargs:
        if kwarg not in DEFAULTS:
            raise TypeError("fail_on() got an unexpected keyword argument "
                            "'{}'".format(kwarg))

    return _fail_on(kwargs)


def _fail_on_all(flag, func):
    """
    Build a checker setting every known check to ``flag``.

    Return the checker itself when ``func`` is None, so it can be used as a
    decorator later, otherwise apply it to ``func`` and return ``func``.
    """
    checker = _fail_on(dict.fromkeys(DEFAULTS, flag))
    return checker if func is None else checker(func)


def strict(func=None):
    """
    Activate strict checking of the state of the loop after a test ran.
    """
    # documented in asynctest.case.rst
    return _fail_on_all(True, func)


def lenient(func=None):
    """
    Deactivate all checks after a test ran.
    """
    # documented in asynctest.case.rst
    return _fail_on_all(False, func)