1import _thread
2import contextlib
3import functools
4import sys
5import threading
6import time
7
8from test import support
9
10
11#=======================================================================
12# Threading support to prevent reporting refleaks when running regrtest.py -R
13
# NOTE: we use _thread._count() rather than threading.enumerate() (or the
# moral equivalent thereof) because a threading.Thread object is still alive
# until its _bootstrap() method has returned, even after it has been
# unregistered from the threading module.
# _thread._count(), on the other hand, only gets decremented *after* the
# _bootstrap() method has returned, which gives us reliable reference counts
# at the end of a test run.
21
22
def threading_setup():
    """Take a snapshot of the thread state for a later threading_cleanup().

    Returns a ``(count, dangling)`` pair: the value of ``_thread._count()``
    and a copy of the ``threading._dangling`` set.
    """
    thread_count = _thread._count()
    dangling_snapshot = threading._dangling.copy()
    return thread_count, dangling_snapshot
25
26
def threading_cleanup(*original_values):
    """Wait for the thread state to return to a threading_setup() snapshot.

    *original_values* is the ``(count, dangling)`` pair returned by
    threading_setup().  The current ``(_thread._count(), threading._dangling)``
    pair is compared against it up to ``_MAX_COUNT`` times.  Between attempts
    the function sleeps 10 ms and runs a garbage collection.

    If the first comparison fails, ``support.environment_altered`` is set and
    a warning listing the dangling threads is printed.  After the last attempt
    the function simply returns; it never raises for leftover threads.
    """
    _MAX_COUNT = 100

    for count in range(_MAX_COUNT):
        values = _thread._count(), threading._dangling
        if values == original_values:
            break

        if not count:
            # Display a warning at the first iteration
            support.environment_altered = True
            dangling_threads = values[1]
            support.print_warning(f"threading_cleanup() failed to cleanup "
                                  f"{values[0] - original_values[0]} threads "
                                  f"(count: {values[0]}, "
                                  f"dangling: {len(dangling_threads)})")
            for thread in dangling_threads:
                support.print_warning(f"Dangling thread: {thread!r}")

            # Don't hold references to threads.  threading._dangling is a
            # weak set: any strong reference left in a local -- including the
            # loop variable above -- keeps a finished thread in it, and the
            # comparison at the top of the loop could then never succeed.
            dangling_threads = None
            thread = None
        values = None

        time.sleep(0.01)
        support.gc_collect()
52
53
def reap_threads(func):
    """Use this function when threads are being used.  This will
    ensure that the threads are cleaned up even when the test fails.

    Decorator: take a threading_setup() snapshot before calling *func*, and
    always hand it to threading_cleanup() afterwards, whether *func* returns
    or raises.  Both positional and keyword arguments are forwarded to
    *func*, and its return value is passed through.
    """
    @functools.wraps(func)
    def decorator(*args, **kwargs):
        key = threading_setup()
        try:
            return func(*args, **kwargs)
        finally:
            threading_cleanup(*key)
    return decorator
66
67
@contextlib.contextmanager
def wait_threads_exit(timeout=None):
    """
    bpo-31234: Context manager to wait until all threads created in the with
    statement exit.

    Use _thread._count() to check if threads exited. Indirectly, wait until
    threads exit the internal t_bootstrap() C function of the _thread module.

    threading_setup() and threading_cleanup() are designed to emit a warning
    if a test leaves running threads in the background. This context manager
    is designed to cleanup threads started by the _thread.start_new_thread()
    which doesn't allow to wait for thread exit, whereas threading.Thread has
    a join() method.

    *timeout* is in seconds and defaults to support.SHORT_TIMEOUT.  On exit,
    the thread count is polled every 10 ms (with a garbage collection between
    polls) until it drops back to the value recorded on entry.  If that does
    not happen before the deadline, AssertionError is raised.
    """
    if timeout is None:
        timeout = support.SHORT_TIMEOUT
    old_count = _thread._count()
    try:
        yield
    finally:
        start_time = time.monotonic()
        deadline = start_time + timeout
        while True:
            count = _thread._count()
            if count <= old_count:
                break
            # Read the clock once, so the reported duration is the one
            # that actually tripped the deadline.
            now = time.monotonic()
            if now > deadline:
                dt = now - start_time
                msg = (f"wait_threads_exit() failed to cleanup "
                       f"{count - old_count} threads "
                       f"after {dt:.1f} seconds "
                       f"(count: {count}, old count: {old_count})")
                raise AssertionError(msg)
            time.sleep(0.010)
            support.gc_collect()
103
104
def join_thread(thread, timeout=None):
    """Wait for *thread* to finish and fail loudly if it does not.

    *timeout* is in seconds; when omitted, support.SHORT_TIMEOUT is used.
    Raise AssertionError if the thread is still alive after the wait.
    """
    wait_for = support.SHORT_TIMEOUT if timeout is None else timeout
    thread.join(wait_for)
    if not thread.is_alive():
        return
    raise AssertionError(f"failed to join the thread in {wait_for:.1f} seconds")
115
116
@contextlib.contextmanager
def start_threads(threads, unlock=None):
    """Context manager: start *threads* on entry and join them on exit.

    *threads* is any iterable of objects with start(), join() and
    is_alive() methods (presumably threading.Thread instances -- NOTE(review):
    confirm no caller passes anything else).  The threads are started in
    order.  If a start() call raises, the exception is re-raised after the
    threads that did start have been joined.

    On exit, *unlock* is called first if it is truthy (presumably to release
    whatever the threads are blocked on -- confirm with callers).  The
    started threads are then joined in up to 15 rounds, each extending a
    shared deadline by 60 seconds.  If any thread is still alive afterwards,
    a faulthandler traceback dump is written to sys.stdout and
    AssertionError is raised.
    """
    import faulthandler
    threads = list(threads)
    started = []
    try:
        try:
            for t in threads:
                t.start()
                started.append(t)
        except:
            # Report the partial start-up, then propagate.  The outer
            # finally still joins the threads that were started.
            if support.verbose:
                print("Can't start %d threads, only %d threads started" %
                      (len(threads), len(started)))
            raise
        yield
    finally:
        try:
            if unlock:
                unlock()
            endtime = time.monotonic()
            # Each round pushes the shared deadline 60 seconds further, so
            # `timeout` counts elapsed minutes for the progress message.
            for timeout in range(1, 16):
                endtime += 60
                for t in started:
                    # Join against the shared deadline, but always give each
                    # thread at least 10 ms.
                    t.join(max(endtime - time.monotonic(), 0.01))
                started = [t for t in started if t.is_alive()]
                if not started:
                    break
                if support.verbose:
                    print('Unable to join %d threads during a period of '
                          '%d minutes' % (len(started), timeout))
        finally:
            # Runs even if unlock() or a join raised: any thread still alive
            # here is a hard failure.
            started = [t for t in started if t.is_alive()]
            if started:
                faulthandler.dump_traceback(sys.stdout)
                raise AssertionError('Unable to join %d threads' % len(started))
153
154
class catch_threading_exception:
    """
    Context manager that captures an exception raised in a threading.Thread
    by temporarily replacing threading.excepthook.

    While the context is active, the most recent uncaught thread exception
    is recorded in these attributes (all None until one is caught):

    * exc_type
    * exc_value
    * exc_traceback
    * thread

    They mirror the attributes of the argument passed to
    threading.excepthook(); see its documentation.

    On exit, the previous hook is restored and the four attributes are
    deleted, to avoid reference cycles.

    Usage:

        with threading_helper.catch_threading_exception() as cm:
            # code spawning a thread which raises an exception
            ...

            # check the thread exception, use cm attributes:
            # exc_type, exc_value, exc_traceback, thread
            ...

        # exc_type, exc_value, exc_traceback, thread attributes of cm no longer
        # exist at this point
        # (to avoid reference cycles)
    """

    def __init__(self):
        self.exc_type = self.exc_value = self.exc_traceback = None
        self.thread = None
        self._old_hook = None

    def _hook(self, args):
        # Installed as threading.excepthook: remember the details of the
        # uncaught exception (a later call overwrites an earlier one).
        (self.exc_type, self.exc_value,
         self.exc_traceback, self.thread) = (args.exc_type, args.exc_value,
                                             args.exc_traceback, args.thread)

    def __enter__(self):
        self._old_hook, threading.excepthook = threading.excepthook, self._hook
        return self

    def __exit__(self, *exc_info):
        threading.excepthook = self._old_hook
        # Drop the captured objects so the exception/traceback/thread are
        # not kept alive by this context manager.
        for attr_name in ('exc_type', 'exc_value', 'exc_traceback', 'thread'):
            delattr(self, attr_name)
210