1"""
2Tests for the threading module.
3"""
4
5import test.support
6from test.support import (verbose, import_module, cpython_only,
7                          requires_type_collecting)
8from test.support.script_helper import assert_python_ok, assert_python_failure
9
10import random
11import sys
12import _thread
13import threading
14import time
15import unittest
16import weakref
17import os
18import subprocess
19import signal
20
21from gevent.tests import lock_tests # gevent: use our local copy
22from test import support
23
24
# Only async-safe functions may be called between fork() and exec() (issues
# #12316 and #11870), and calling fork() from a worker thread is known to
# cause problems on some operating systems (issue #3863).  Tests that do
# either are skipped on the platforms listed here.
platforms_to_skip = (
    'netbsd5',
    'hp-ux11',
)
30
31
class Counter:
    """A trivial mutable integer counter with no internal locking."""

    def __init__(self):
        self.value = 0

    def inc(self):
        """Add one to the current count."""
        self.value = self.value + 1

    def dec(self):
        """Subtract one from the current count."""
        self.value = self.value - 1

    def get(self):
        """Return the current count."""
        return self.value
42
class TestThread(threading.Thread):
    """Worker thread driven by ThreadTests.test_various_ops.

    While holding *sema* it sleeps for a random delay below 100 usec.  On the
    way in and out it updates the shared *nrunning* counter under *mutex*.
    It uses *testcase* to assert that the count never exceeds three and
    never drops below zero.
    """

    def __init__(self, name, testcase, sema, mutex, nrunning):
        threading.Thread.__init__(self, name=name)
        self.testcase, self.sema = testcase, sema
        self.mutex, self.nrunning = mutex, nrunning

    def run(self):
        pause = random.random() / 10000.0
        if verbose:
            print('task %s will run for %.1f usec' %
                  (self.name, pause * 1e6))

        with self.sema:
            self._register_running()
            time.sleep(pause)
            if verbose:
                print('task', self.name, 'done')
            self._unregister_running()

    def _register_running(self):
        # Count this worker in; at most three may be inside sema at once.
        with self.mutex:
            counter = self.nrunning
            counter.inc()
            if verbose:
                print(counter.get(), 'tasks are running')
            self.testcase.assertLessEqual(counter.get(), 3)

    def _unregister_running(self):
        # Count this worker out; the running tally must never go negative.
        with self.mutex:
            counter = self.nrunning
            counter.dec()
            self.testcase.assertGreaterEqual(counter.get(), 0)
            if verbose:
                print('%s is finished. %d tasks are running' %
                      (self.name, counter.get()))
74
75
class BaseTestCase(unittest.TestCase):
    """Shared fixture for the threading tests.

    setUp stores the value returned by test.support.threading_setup().
    tearDown passes that value to test.support.threading_cleanup() and then
    calls test.support.reap_children().
    """

    def setUp(self):
        self._threads = test.support.threading_setup()

    def tearDown(self):
        helpers = test.support
        helpers.threading_cleanup(*self._threads)
        helpers.reap_children()
83
84
class ThreadTests(BaseTestCase):
    """Core tests for threading.Thread and the module-level helpers."""

    # Create a bunch of threads, let each do some work, wait until all are
    # done.
    def test_various_ops(self):
        # Each TestThread sleeps for less than 0.1 ms while holding the
        # semaphore, so run time is dominated by thread start/join overhead.
        NUMTASKS = 10

        # no more than 3 of the 10 can run at once
        sema = threading.BoundedSemaphore(value=3)
        mutex = threading.RLock()
        numrunning = Counter()

        threads = []

        for i in range(NUMTASKS):
            t = TestThread("<thread %d>" % i, self, sema, mutex, numrunning)
            threads.append(t)
            self.assertIsNone(t.ident)
            self.assertRegex(repr(t), r'^<TestThread\(.*, initial\)>$')
            t.start()

        if hasattr(threading, 'get_native_id'):
            native_ids = {t.native_id for t in threads} | {threading.get_native_id()}
            self.assertNotIn(None, native_ids)
            self.assertEqual(len(native_ids), NUMTASKS + 1)

        if verbose:
            print('waiting for all tasks to complete')
        for t in threads:
            t.join()
            self.assertFalse(t.is_alive())
            self.assertNotEqual(t.ident, 0)
            self.assertIsNotNone(t.ident)
            self.assertRegex(repr(t), r'^<TestThread\(.*, stopped -?\d+\)>$')
        if verbose:
            print('all tasks done')
        self.assertEqual(numrunning.get(), 0)

    def test_ident_of_no_threading_threads(self):
        # The ident still must work for the main thread and dummy threads.
        self.assertIsNotNone(threading.currentThread().ident)
        def f():
            ident.append(threading.currentThread().ident)
            done.set()
        done = threading.Event()
        ident = []
        with support.wait_threads_exit():
            tid = _thread.start_new_thread(f, ())
            done.wait()
            self.assertEqual(ident[0], tid)
        # Kill the "immortal" _DummyThread
        del threading._active[ident[0]]

    # run with a small(ish) thread stack size (256 KiB)
    def test_various_ops_small_stack(self):
        if verbose:
            print('with 256 KiB thread stack size...')
        try:
            threading.stack_size(262144)
        except _thread.error:
            raise unittest.SkipTest(
                'platform does not support changing thread stack size')
        # Always restore the default so a failure here does not leave later
        # tests running with a modified thread stack size.
        try:
            self.test_various_ops()
        finally:
            threading.stack_size(0)

    # run with a large thread stack size (1 MiB)
    def test_various_ops_large_stack(self):
        if verbose:
            print('with 1 MiB thread stack size...')
        try:
            threading.stack_size(0x100000)
        except _thread.error:
            raise unittest.SkipTest(
                'platform does not support changing thread stack size')
        # Always restore the default so a failure here does not leave later
        # tests running with a modified thread stack size.
        try:
            self.test_various_ops()
        finally:
            threading.stack_size(0)

    def test_foreign_thread(self):
        # Check that a "foreign" thread can use the threading module.
        def f(mutex):
            # Calling current_thread() forces an entry for the foreign
            # thread to get made in the threading._active map.
            threading.current_thread()
            mutex.release()

        mutex = threading.Lock()
        mutex.acquire()
        with support.wait_threads_exit():
            tid = _thread.start_new_thread(f, (mutex,))
            # Wait for the thread to finish.
            mutex.acquire()
        self.assertIn(tid, threading._active)
        self.assertIsInstance(threading._active[tid], threading._DummyThread)
        #Issue 29376
        self.assertTrue(threading._active[tid].is_alive())
        self.assertRegex(repr(threading._active[tid]), '_DummyThread')
        del threading._active[tid]

    # PyThreadState_SetAsyncExc() is a CPython-only gimmick, not (currently)
    # exposed at the Python level.  This test relies on ctypes to get at it.
    def test_PyThreadState_SetAsyncExc(self):
        ctypes = import_module("ctypes")

        set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc
        set_async_exc.argtypes = (ctypes.c_ulong, ctypes.py_object)

        class AsyncExc(Exception):
            pass

        exception = ctypes.py_object(AsyncExc)

        # First check it works when setting the exception from the same thread.
        tid = threading.get_ident()
        self.assertIsInstance(tid, int)
        self.assertGreater(tid, 0)

        try:
            result = set_async_exc(tid, exception)
            # The exception is async, so we might have to keep the VM busy until
            # it notices.
            while True:
                pass
        except AsyncExc:
            pass
        else:
            # This code is unreachable but it reflects the intent. If we wanted
            # to be smarter the above loop wouldn't be infinite.
            self.fail("AsyncExc not raised")
        try:
            self.assertEqual(result, 1) # one thread state modified
        except UnboundLocalError:
            # The exception was raised too quickly for us to get the result.
            pass

        # `worker_started` is set by the thread when it's inside a try/except
        # block waiting to catch the asynchronously set AsyncExc exception.
        # `worker_saw_exception` is set by the thread upon catching that
        # exception.
        worker_started = threading.Event()
        worker_saw_exception = threading.Event()

        class Worker(threading.Thread):
            def run(self):
                self.id = threading.get_ident()
                self.finished = False

                try:
                    while True:
                        worker_started.set()
                        time.sleep(0.1)
                except AsyncExc:
                    self.finished = True
                    worker_saw_exception.set()

        t = Worker()
        t.daemon = True # so if this fails, we don't hang Python at shutdown
        t.start()
        if verbose:
            print("    started worker thread")

        # Try a thread id that doesn't make sense.
        if verbose:
            print("    trying nonsensical thread id")
        result = set_async_exc(-1, exception)
        self.assertEqual(result, 0)  # no thread states modified

        # Now raise an exception in the worker thread.
        if verbose:
            print("    waiting for worker thread to get started")
        ret = worker_started.wait()
        self.assertTrue(ret)
        if verbose:
            print("    verifying worker hasn't exited")
        self.assertFalse(t.finished)
        if verbose:
            print("    attempting to raise asynch exception in worker")
        result = set_async_exc(t.id, exception)
        self.assertEqual(result, 1) # one thread state modified
        if verbose:
            print("    waiting for worker to say it caught the exception")
        worker_saw_exception.wait(timeout=10)
        self.assertTrue(t.finished)
        if verbose:
            print("    all OK -- joining worker")
        if t.finished:
            t.join()
        # else the thread is still running, and we have no way to kill it

    def test_limbo_cleanup(self):
        # Issue 7481: Failure to start thread should cleanup the limbo map.
        def fail_new_thread(*args):
            raise threading.ThreadError()
        _start_new_thread = threading._start_new_thread
        threading._start_new_thread = fail_new_thread
        try:
            t = threading.Thread(target=lambda: None)
            self.assertRaises(threading.ThreadError, t.start)
            self.assertNotIn(
                t, threading._limbo,
                "Failed to cleanup _limbo map on failure of Thread.start().")
        finally:
            threading._start_new_thread = _start_new_thread

    def test_finalize_runnning_thread(self):
        # Issue 1402: the PyGILState_Ensure / _Release functions may be called
        # very late on python exit: on deallocation of a running thread for
        # example.
        import_module("ctypes")

        rc, out, err = assert_python_failure("-c", """if 1:
            import ctypes, sys, time, _thread

            # This lock is used as a simple event variable.
            ready = _thread.allocate_lock()
            ready.acquire()

            # Module globals are cleared before __del__ is run
            # So we save the functions in class dict
            class C:
                ensure = ctypes.pythonapi.PyGILState_Ensure
                release = ctypes.pythonapi.PyGILState_Release
                def __del__(self):
                    state = self.ensure()
                    self.release(state)

            def waitingThread():
                x = C()
                ready.release()
                time.sleep(100)

            _thread.start_new_thread(waitingThread, ())
            ready.acquire()  # Be sure the other thread is waiting.
            sys.exit(42)
            """)
        self.assertEqual(rc, 42)

    def test_finalize_with_trace(self):
        # Issue1733757
        # Avoid a deadlock when sys.settrace steps into threading._shutdown
        assert_python_ok("-c", """if 1:
            import sys, threading

            # A deadlock-killer, to prevent the
            # testsuite to hang forever
            def killer():
                import os, time
                time.sleep(2)
                print('program blocked; aborting')
                os._exit(2)
            t = threading.Thread(target=killer)
            t.daemon = True
            t.start()

            # This is the trace function
            def func(frame, event, arg):
                threading.current_thread()
                return func

            sys.settrace(func)
            """)

    def test_join_nondaemon_on_shutdown(self):
        # Issue 1722344
        # Raising SystemExit skipped threading._shutdown
        rc, out, err = assert_python_ok("-c", """if 1:
                import threading
                from time import sleep

                def child():
                    sleep(1)
                    # As a non-daemon thread we SHOULD wake up and nothing
                    # should be torn down yet
                    print("Woke up, sleep function is:", sleep)

                threading.Thread(target=child).start()
                raise SystemExit
            """)
        self.assertEqual(out.strip(),
            b"Woke up, sleep function is: <built-in function sleep>")
        self.assertEqual(err, b"")

    def test_enumerate_after_join(self):
        # Try hard to trigger #1703448: a thread is still returned in
        # threading.enumerate() after it has been join()ed.
        enum = threading.enumerate
        old_interval = sys.getswitchinterval()
        try:
            for i in range(1, 100):
                sys.setswitchinterval(i * 0.0002)
                t = threading.Thread(target=lambda: None)
                t.start()
                t.join()
                l = enum()
                self.assertNotIn(t, l,
                    "#1703448 triggered after %d trials: %s" % (i, l))
        finally:
            sys.setswitchinterval(old_interval)

    def test_no_refcycle_through_target(self):
        class RunSelfFunction(object):
            def __init__(self, should_raise):
                # The links in this refcycle from Thread back to self
                # should be cleaned up when the thread completes.
                self.should_raise = should_raise
                self.thread = threading.Thread(target=self._run,
                                               args=(self,),
                                               kwargs={'yet_another':self})
                self.thread.start()

            def _run(self, other_ref, yet_another):
                if self.should_raise:
                    raise SystemExit

        cyclic_object = RunSelfFunction(should_raise=False)
        weak_cyclic_object = weakref.ref(cyclic_object)
        cyclic_object.thread.join()
        del cyclic_object
        self.assertIsNone(weak_cyclic_object(),
                         msg=('%d references still around' %
                              sys.getrefcount(weak_cyclic_object())))

        raising_cyclic_object = RunSelfFunction(should_raise=True)
        weak_raising_cyclic_object = weakref.ref(raising_cyclic_object)
        raising_cyclic_object.thread.join()
        del raising_cyclic_object
        self.assertIsNone(weak_raising_cyclic_object(),
                         msg=('%d references still around' %
                              sys.getrefcount(weak_raising_cyclic_object())))

    def test_old_threading_api(self):
        # Just a quick sanity check to make sure the old method names are
        # still present
        t = threading.Thread()
        t.isDaemon()
        t.setDaemon(True)
        t.getName()
        t.setName("name")
        with self.assertWarnsRegex(DeprecationWarning, 'use is_alive()'):
            t.isAlive()
        e = threading.Event()
        e.isSet()
        threading.activeCount()

    def test_repr_daemon(self):
        t = threading.Thread()
        self.assertNotIn('daemon', repr(t))
        t.daemon = True
        self.assertIn('daemon', repr(t))

    def test_daemon_param(self):
        t = threading.Thread()
        self.assertFalse(t.daemon)
        t = threading.Thread(daemon=False)
        self.assertFalse(t.daemon)
        t = threading.Thread(daemon=True)
        self.assertTrue(t.daemon)

    @unittest.skipUnless(hasattr(os, 'fork'), 'test needs fork()')
    def test_dummy_thread_after_fork(self):
        # Issue #14308: a dummy thread in the active list doesn't mess up
        # the after-fork mechanism.
        code = """if 1:
            import _thread, threading, os, time

            def background_thread(evt):
                # Creates and registers the _DummyThread instance
                threading.current_thread()
                evt.set()
                time.sleep(10)

            evt = threading.Event()
            _thread.start_new_thread(background_thread, (evt,))
            evt.wait()
            assert threading.active_count() == 2, threading.active_count()
            if os.fork() == 0:
                assert threading.active_count() == 1, threading.active_count()
                os._exit(0)
            else:
                os.wait()
        """
        _, out, err = assert_python_ok("-c", code)
        self.assertEqual(out, b'')
        self.assertEqual(err, b'')

    @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
    def test_is_alive_after_fork(self):
        # Try hard to trigger #18418: is_alive() could sometimes be True on
        # threads that vanished after a fork.
        old_interval = sys.getswitchinterval()
        self.addCleanup(sys.setswitchinterval, old_interval)

        # Make the bug more likely to manifest.
        test.support.setswitchinterval(1e-6)

        for i in range(20):
            t = threading.Thread(target=lambda: None)
            t.start()
            pid = os.fork()
            if pid == 0:
                os._exit(11 if t.is_alive() else 10)
            else:
                t.join()

                pid, status = os.waitpid(pid, 0)
                self.assertTrue(os.WIFEXITED(status))
                self.assertEqual(10, os.WEXITSTATUS(status))

    def test_main_thread(self):
        main = threading.main_thread()
        self.assertEqual(main.name, 'MainThread')
        self.assertEqual(main.ident, threading.current_thread().ident)
        self.assertEqual(main.ident, threading.get_ident())

        def f():
            self.assertNotEqual(threading.main_thread().ident,
                                threading.current_thread().ident)
        th = threading.Thread(target=f)
        th.start()
        th.join()

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()")
    def test_main_thread_after_fork(self):
        code = """if 1:
            import os, threading

            pid = os.fork()
            if pid == 0:
                main = threading.main_thread()
                print(main.name)
                print(main.ident == threading.current_thread().ident)
                print(main.ident == threading.get_ident())
            else:
                os.waitpid(pid, 0)
        """
        _, out, err = assert_python_ok("-c", code)
        data = out.decode().replace('\r', '')
        self.assertEqual(err, b"")
        self.assertEqual(data, "MainThread\nTrue\nTrue\n")

    @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()")
    def test_main_thread_after_fork_from_nonmain_thread(self):
        code = """if 1:
            import os, threading, sys

            def f():
                pid = os.fork()
                if pid == 0:
                    main = threading.main_thread()
                    print(main.name)
                    print(main.ident == threading.current_thread().ident)
                    print(main.ident == threading.get_ident())
                    # stdout is fully buffered because not a tty,
                    # we have to flush before exit.
                    sys.stdout.flush()
                else:
                    os.waitpid(pid, 0)

            th = threading.Thread(target=f)
            th.start()
            th.join()
        """
        _, out, err = assert_python_ok("-c", code)
        data = out.decode().replace('\r', '')
        self.assertEqual(err, b"")
        self.assertEqual(data, "Thread-1\nTrue\nTrue\n")

    @requires_type_collecting
    def test_main_thread_during_shutdown(self):
        # bpo-31516: current_thread() should still point to the main thread
        # at shutdown
        code = """if 1:
            import gc, threading

            main_thread = threading.current_thread()
            assert main_thread is threading.main_thread()  # sanity check

            class RefCycle:
                def __init__(self):
                    self.cycle = self

                def __del__(self):
                    print("GC:",
                          threading.current_thread() is main_thread,
                          threading.main_thread() is main_thread,
                          threading.enumerate() == [main_thread])

            RefCycle()
            gc.collect()  # sanity check
            x = RefCycle()
        """
        _, out, err = assert_python_ok("-c", code)
        data = out.decode()
        self.assertEqual(err, b"")
        self.assertEqual(data.splitlines(),
                         ["GC: True True True"] * 2)

    def test_finalization_shutdown(self):
        # bpo-36402: Py_Finalize() calls threading._shutdown() which must wait
        # until Python thread states of all non-daemon threads get deleted.
        #
        # Test similar to SubinterpThreadingTests.test_threads_join_2(), but
        # test the finalization of the main interpreter.
        code = """if 1:
            import os
            import threading
            import time
            import random

            def random_sleep():
                seconds = random.random() * 0.010
                time.sleep(seconds)

            class Sleeper:
                def __del__(self):
                    random_sleep()

            tls = threading.local()

            def f():
                # Sleep a bit so that the thread is still running when
                # Py_Finalize() is called.
                random_sleep()
                tls.x = Sleeper()
                random_sleep()

            threading.Thread(target=f).start()
            random_sleep()
        """
        rc, out, err = assert_python_ok("-c", code)
        self.assertEqual(err, b"")

    def test_tstate_lock(self):
        # Test an implementation detail of Thread objects.
        started = _thread.allocate_lock()
        finish = _thread.allocate_lock()
        started.acquire()
        finish.acquire()
        def f():
            started.release()
            finish.acquire()
            time.sleep(0.01)
        # The tstate lock is None until the thread is started
        t = threading.Thread(target=f)
        self.assertIs(t._tstate_lock, None)
        t.start()
        started.acquire()
        self.assertTrue(t.is_alive())
        # The tstate lock can't be acquired when the thread is running
        # (or suspended).
        tstate_lock = t._tstate_lock
        self.assertFalse(tstate_lock.acquire(timeout=0))
        finish.release()
        # When the thread ends, the state_lock can be successfully
        # acquired.
        self.assertTrue(tstate_lock.acquire(timeout=5))
        # But is_alive() is still True:  we hold _tstate_lock now, which
        # prevents is_alive() from knowing the thread's end-of-life C code
        # is done.
        self.assertTrue(t.is_alive())
        # Let is_alive() find out the C code is done.
        tstate_lock.release()
        self.assertFalse(t.is_alive())
        # And verify the thread disposed of _tstate_lock.
        self.assertIsNone(t._tstate_lock)
        t.join()

    def test_repr_stopped(self):
        # Verify that "stopped" shows up in repr(Thread) appropriately.
        started = _thread.allocate_lock()
        finish = _thread.allocate_lock()
        started.acquire()
        finish.acquire()
        def f():
            started.release()
            finish.acquire()
        t = threading.Thread(target=f)
        t.start()
        started.acquire()
        self.assertIn("started", repr(t))
        finish.release()
        # "stopped" should appear in the repr in a reasonable amount of time.
        # Implementation detail:  as of this writing, that's trivially true
        # if .join() is called, and almost trivially true if .is_alive() is
        # called.  The detail we're testing here is that "stopped" shows up
        # "all on its own".
        LOOKING_FOR = "stopped"
        for i in range(500):
            if LOOKING_FOR in repr(t):
                break
            time.sleep(0.01)
        self.assertIn(LOOKING_FOR, repr(t)) # we waited at least 5 seconds
        t.join()

    def test_BoundedSemaphore_limit(self):
        # BoundedSemaphore should raise ValueError if released too often.
        for limit in range(1, 10):
            bs = threading.BoundedSemaphore(limit)
            threads = [threading.Thread(target=bs.acquire)
                       for _ in range(limit)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            threads = [threading.Thread(target=bs.release)
                       for _ in range(limit)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            self.assertRaises(ValueError, bs.release)

    @cpython_only
    def test_frame_tstate_tracing(self):
        # Issue #14432: Crash when a generator is created in a C thread that is
        # destroyed while the generator is still used. The issue was that a
        # generator contains a frame, and the frame kept a reference to the
        # Python state of the destroyed C thread. The crash occurs when a trace
        # function is setup.

        def noop_trace(frame, event, arg):
            # no operation
            return noop_trace

        def generator():
            while 1:
                yield "generator"

        def callback():
            if callback.gen is None:
                callback.gen = generator()
            return next(callback.gen)
        callback.gen = None

        old_trace = sys.gettrace()
        # threading.gettrace() only exists from Python 3.10; read the private
        # hook so the previous per-thread trace function can be restored.
        old_threading_trace = getattr(threading, '_trace_hook', None)
        sys.settrace(noop_trace)
        try:
            # Install a trace function
            threading.settrace(noop_trace)

            # Create a generator in a C thread which exits after the call
            import _testcapi
            _testcapi.call_in_temporary_c_thread(callback)

            # Call the generator in a different Python thread, check that the
            # generator didn't keep a reference to the destroyed thread state
            for _ in range(3):
                # The trace function is still called here
                callback()
        finally:
            sys.settrace(old_trace)
            # Otherwise every thread started by later tests inherits
            # noop_trace as its trace function.
            threading.settrace(old_threading_trace)

    @cpython_only
    def test_shutdown_locks(self):
        for daemon in (False, True):
            with self.subTest(daemon=daemon):
                event = threading.Event()
                thread = threading.Thread(target=event.wait, daemon=daemon)

                # Thread.start() must add lock to _shutdown_locks,
                # but only for non-daemon thread
                thread.start()
                tstate_lock = thread._tstate_lock
                if not daemon:
                    self.assertIn(tstate_lock, threading._shutdown_locks)
                else:
                    self.assertNotIn(tstate_lock, threading._shutdown_locks)

                # unblock the thread and join it
                event.set()
                thread.join()

                # Thread._stop() must remove tstate_lock from _shutdown_locks.
                # Daemon threads must never add it to _shutdown_locks.
                self.assertNotIn(tstate_lock, threading._shutdown_locks)
764
765
class ThreadJoinOnShutdown(BaseTestCase):
    """Check how threads are joined when an interpreter (or a fork) exits.

    Most tests run a small script in a fresh interpreter through
    ``assert_python_ok``. The ``_run_and_join`` based tests expect the
    child to print ``end of main`` followed by ``end of thread``.
    """

    def _run_and_join(self, script):
        # Prepend a prelude defining joiningfunc(mainthread) to *script*.
        # joiningfunc() joins the thread it is given, then prints
        # 'end of thread'. The combined script runs with ``python -c`` and
        # stdout (with CRs stripped) must be exactly
        # "end of main\nend of thread\n".
        script = """if 1:
            import sys, os, time, threading

            # a thread, which waits for the main program to terminate
            def joiningfunc(mainthread):
                mainthread.join()
                print('end of thread')
                # stdout is fully buffered because not a tty, we have to flush
                # before exit.
                sys.stdout.flush()
        \n""" + script

        rc, out, err = assert_python_ok("-c", script)
        # Normalise Windows line endings before comparing.
        data = out.decode().replace('\r', '')
        self.assertEqual(data, "end of main\nend of thread\n")

    def test_1_join_on_shutdown(self):
        # The usual case: on exit, wait for a non-daemon thread
        script = """if 1:
            import os
            t = threading.Thread(target=joiningfunc,
                                 args=(threading.current_thread(),))
            t.start()
            time.sleep(0.1)
            print('end of main')
            """
        self._run_and_join(script)

    @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
    @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
    def test_2_join_in_forked_process(self):
        # Like the test above, but from a forked interpreter
        # (the parent only waits for the child and exits with status 0;
        # the expected output comes from the child).
        script = """if 1:
            childpid = os.fork()
            if childpid != 0:
                os.waitpid(childpid, 0)
                sys.exit(0)

            t = threading.Thread(target=joiningfunc,
                                 args=(threading.current_thread(),))
            t.start()
            print('end of main')
            """
        self._run_and_join(script)

    @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
    @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
    def test_3_join_in_forked_from_thread(self):
        # Like the test above, but fork() was called from a worker thread
        # In the forked process, the main Thread object must be marked as stopped.

        script = """if 1:
            main_thread = threading.current_thread()
            def worker():
                childpid = os.fork()
                if childpid != 0:
                    os.waitpid(childpid, 0)
                    sys.exit(0)

                t = threading.Thread(target=joiningfunc,
                                     args=(main_thread,))
                print('end of main')
                t.start()
                t.join() # Should not block: main_thread is already stopped

            w = threading.Thread(target=worker)
            w.start()
            """
        self._run_and_join(script)

    @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
    def test_4_daemon_threads(self):
        # Check that a daemon thread cannot crash the interpreter on shutdown
        # by manipulating internal structures that are being disposed of in
        # the main thread.
        # The child starts 40 daemon threads doing file I/O in a loop, waits
        # until each has completed at least one iteration, then exits; the
        # test requires a zero exit status and empty stderr.
        script = """if True:
            import os
            import random
            import sys
            import time
            import threading

            thread_has_run = set()

            def random_io():
                '''Loop for a while sleeping random tiny amounts and doing some I/O.'''
                while True:
                    with open(os.__file__, 'rb') as in_f:
                        stuff = in_f.read(200)
                        with open(os.devnull, 'wb') as null_f:
                            null_f.write(stuff)
                            time.sleep(random.random() / 1995)
                    thread_has_run.add(threading.current_thread())

            def main():
                count = 0
                for _ in range(40):
                    new_thread = threading.Thread(target=random_io)
                    new_thread.daemon = True
                    new_thread.start()
                    count += 1
                while len(thread_has_run) < count:
                    time.sleep(0.001)
                # Trigger process shutdown
                sys.exit(0)

            main()
            """
        rc, out, err = assert_python_ok('-c', script)
        self.assertFalse(err)

    @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
    @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
    def test_reinit_tls_after_fork(self):
        # Issue #13817: fork() would deadlock in a multithreaded program with
        # the ad-hoc TLS implementation.
        # The test passes if all 16 forking threads can be joined.

        def do_fork_and_wait():
            # just fork a child process and wait it
            pid = os.fork()
            if pid > 0:
                os.waitpid(pid, 0)
            else:
                # Child: leave immediately without running cleanup handlers.
                os._exit(0)

        # start a bunch of threads that will fork() child processes
        threads = []
        for i in range(16):
            t = threading.Thread(target=do_fork_and_wait)
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

    @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
    def test_clear_threads_states_after_fork(self):
        # Issue #17094: check that threads states are cleared after fork()

        # start a bunch of threads
        threads = []
        for i in range(16):
            t = threading.Thread(target=lambda : time.sleep(0.3))
            threads.append(t)
            t.start()

        pid = os.fork()
        if pid == 0:
            # check that threads states have been cleared
            # (only the forking thread may still have a frame in the child;
            # the child reports the result through its exit code)
            if len(sys._current_frames()) == 1:
                os._exit(0)
            else:
                os._exit(1)
        else:
            # A raw wait status of 0 means the child exited normally with 0.
            _, status = os.waitpid(pid, 0)
            self.assertEqual(0, status)

        for t in threads:
            t.join()
928
929
class SubinterpThreadingTests(BaseTestCase):
    """Thread handling when a subinterpreter is torn down.

    The first two tests pass a pipe's write end into the subinterpreter
    code. A thread there writes ``b"x"`` to it, and the parent reads the
    byte back after ``run_in_subinterp`` returns.
    """

    def test_threads_join(self):
        # Non-daemon threads should be joined at subinterpreter shutdown
        # (issue #18808)
        r, w = os.pipe()
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        code = r"""if 1:
            import os
            import random
            import threading
            import time

            def random_sleep():
                seconds = random.random() * 0.010
                time.sleep(seconds)

            def f():
                # Sleep a bit so that the thread is still running when
                # Py_EndInterpreter is called.
                random_sleep()
                os.write(%d, b"x")

            threading.Thread(target=f).start()
            random_sleep()
            """ % (w,)
        ret = test.support.run_in_subinterp(code)
        self.assertEqual(ret, 0)
        # The thread was joined properly.
        self.assertEqual(os.read(r, 1), b"x")

    def test_threads_join_2(self):
        # Same as above, but a delay gets introduced after the thread's
        # Python code returned but before the thread state is deleted.
        # To achieve this, we register a thread-local object which sleeps
        # a bit when deallocated.
        r, w = os.pipe()
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        code = r"""if 1:
            import os
            import random
            import threading
            import time

            def random_sleep():
                seconds = random.random() * 0.010
                time.sleep(seconds)

            class Sleeper:
                def __del__(self):
                    random_sleep()

            tls = threading.local()

            def f():
                # Sleep a bit so that the thread is still running when
                # Py_EndInterpreter is called.
                random_sleep()
                tls.x = Sleeper()
                os.write(%d, b"x")

            threading.Thread(target=f).start()
            random_sleep()
            """ % (w,)
        ret = test.support.run_in_subinterp(code)
        self.assertEqual(ret, 0)
        # The thread was joined properly.
        self.assertEqual(os.read(r, 1), b"x")

    @cpython_only
    def test_daemon_threads_fatal_error(self):
        # A daemon thread that is still running when the subinterpreter
        # ends must make the child process fail with the
        # "Py_EndInterpreter: not the last thread" fatal error.
        subinterp_code = r"""if 1:
            import os
            import threading
            import time

            def f():
                # Make sure the daemon thread is still running when
                # Py_EndInterpreter is called.
                time.sleep(10)
            threading.Thread(target=f, daemon=True).start()
            """
        script = r"""if 1:
            import _testcapi

            _testcapi.run_in_subinterp(%r)
            """ % (subinterp_code,)
        # The crash is expected; keep it from producing an OS crash report.
        with test.support.SuppressCrashReport():
            rc, out, err = assert_python_failure("-c", script)
        self.assertIn("Fatal Python error: Py_EndInterpreter: "
                      "not the last thread", err.decode())
1023
1024
class ThreadingExceptionTests(BaseTestCase):
    """Misuse of the threading API raises, and thread errors are reported.

    Covers RuntimeError on invalid Thread/Lock operations, deep recursion
    inside a worker thread, and how an uncaught exception in a thread is
    printed (including when ``sys.stderr`` is ``None``).
    """

    # A RuntimeError should be raised if Thread.start() is called
    # multiple times.
    def test_start_thread_again(self):
        thread = threading.Thread()
        thread.start()
        self.assertRaises(RuntimeError, thread.start)
        thread.join()

    def test_joining_current_thread(self):
        # A thread cannot join itself.
        current_thread = threading.current_thread()
        self.assertRaises(RuntimeError, current_thread.join)

    def test_joining_inactive_thread(self):
        # Joining a thread that was never started is an error.
        thread = threading.Thread()
        self.assertRaises(RuntimeError, thread.join)

    def test_daemonize_active_thread(self):
        # The daemon flag cannot be changed once the thread has been started.
        thread = threading.Thread()
        thread.start()
        self.assertRaises(RuntimeError, setattr, thread, "daemon", True)
        thread.join()

    def test_releasing_unacquired_lock(self):
        # Releasing a Lock that is not held is an error.
        lock = threading.Lock()
        self.assertRaises(RuntimeError, lock.release)

    def test_recursion_limit(self):
        # Issue 9670
        # test that excessive recursion within a non-main thread causes
        # an exception rather than crashing the interpreter on platforms
        # like Mac OS X or FreeBSD which have small default stack sizes
        # for threads
        script = """if True:
            import threading

            def recurse():
                return recurse()

            def outer():
                try:
                    recurse()
                except RecursionError:
                    pass

            w = threading.Thread(target=outer)
            w.start()
            w.join()
            print('end of main thread')
            """
        expected_output = "end of main thread\n"
        p = subprocess.Popen([sys.executable, "-c", script],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = p.communicate()
        # Normalise Windows line endings before comparing.
        data = stdout.decode().replace('\r', '')
        self.assertEqual(p.returncode, 0, "Unexpected error: " + stderr.decode())
        self.assertEqual(data, expected_output)

    def test_print_exception(self):
        # An uncaught ZeroDivisionError in a worker thread must be reported
        # on the child's stderr with a full traceback, while the process
        # itself still exits successfully.
        script = r"""if True:
            import threading
            import time

            running = False
            def run():
                global running
                running = True
                while running:
                    time.sleep(0.01)
                1/0
            t = threading.Thread(target=run)
            t.start()
            while not running:
                time.sleep(0.01)
            running = False
            t.join()
            """
        rc, out, err = assert_python_ok("-c", script)
        self.assertEqual(out, b'')
        err = err.decode()
        self.assertIn("Exception in thread", err)
        self.assertIn("Traceback (most recent call last):", err)
        self.assertIn("ZeroDivisionError", err)
        self.assertNotIn("Unhandled exception", err)

    @requires_type_collecting
    def test_print_exception_stderr_is_none_1(self):
        # Same as test_print_exception, but sys.stderr is set to None after
        # the thread has started; the report must still reach the child
        # process's stderr.
        script = r"""if True:
            import sys
            import threading
            import time

            running = False
            def run():
                global running
                running = True
                while running:
                    time.sleep(0.01)
                1/0
            t = threading.Thread(target=run)
            t.start()
            while not running:
                time.sleep(0.01)
            sys.stderr = None
            running = False
            t.join()
            """
        rc, out, err = assert_python_ok("-c", script)
        self.assertEqual(out, b'')
        err = err.decode()
        self.assertIn("Exception in thread", err)
        self.assertIn("Traceback (most recent call last):", err)
        self.assertIn("ZeroDivisionError", err)
        self.assertNotIn("Unhandled exception", err)

    def test_print_exception_stderr_is_none_2(self):
        # sys.stderr is None before the thread is even created; only a clean
        # exit, empty stdout and no "Unhandled exception" are required.
        script = r"""if True:
            import sys
            import threading
            import time

            running = False
            def run():
                global running
                running = True
                while running:
                    time.sleep(0.01)
                1/0
            sys.stderr = None
            t = threading.Thread(target=run)
            t.start()
            while not running:
                time.sleep(0.01)
            running = False
            t.join()
            """
        rc, out, err = assert_python_ok("-c", script)
        self.assertEqual(out, b'')
        self.assertNotIn("Unhandled exception", err.decode())

    def test_bare_raise_in_brand_new_thread(self):
        # Issue 27558: a bare ``raise`` with no active exception, executed in
        # a fresh thread, must raise a RuntimeError that run() can catch.
        def bare_raise():
            raise

        class Issue27558(threading.Thread):
            exc = None

            def run(self):
                try:
                    bare_raise()
                except Exception as exc:
                    self.exc = exc

        thread = Issue27558()
        thread.start()
        thread.join()
        self.assertIsNotNone(thread.exc)
        self.assertIsInstance(thread.exc, RuntimeError)
        # explicitly break the reference cycle to not leak a dangling thread
        thread.exc = None
1185
1186
class ThreadRunFail(threading.Thread):
    """Thread whose run() always raises ``ValueError("run failed")``.

    ExceptHookTests uses it to produce an uncaught exception in a thread.
    Keep the text of the ``raise`` line unchanged: test_excepthook looks
    for that exact source line in the printed traceback.
    """
    def run(self):
        raise ValueError("run failed")
1190
1191
class ExceptHookTests(BaseTestCase):
    """Tests for ``threading.excepthook`` and how threads report failures."""

    def test_excepthook(self):
        # The default hook prints the thread name, the traceback (including
        # the failing source line) and the exception on stderr.
        with support.captured_output("stderr") as stderr:
            thread = ThreadRunFail(name="excepthook thread")
            thread.start()
            thread.join()

        stderr = stderr.getvalue().strip()
        self.assertIn(f'Exception in thread {thread.name}:\n', stderr)
        self.assertIn('Traceback (most recent call last):\n', stderr)
        self.assertIn('  raise ValueError("run failed")', stderr)
        self.assertIn('ValueError: run failed', stderr)

    @support.cpython_only
    def test_excepthook_thread_None(self):
        # threading.excepthook called with thread=None: log the thread
        # identifier in this case.
        with support.captured_output("stderr") as stderr:
            try:
                raise ValueError("bug")
            except Exception:
                args = threading.ExceptHookArgs([*sys.exc_info(), None])
                try:
                    threading.excepthook(args)
                finally:
                    # Explicitly break a reference cycle
                    args = None

        stderr = stderr.getvalue().strip()
        self.assertIn(f'Exception in thread {threading.get_ident()}:\n', stderr)
        self.assertIn('Traceback (most recent call last):\n', stderr)
        self.assertIn('  raise ValueError("bug")', stderr)
        self.assertIn('ValueError: bug', stderr)

    def test_system_exit(self):
        class ThreadExit(threading.Thread):
            def run(self):
                sys.exit(1)

        # threading.excepthook() silently ignores SystemExit
        with support.captured_output("stderr") as stderr:
            thread = ThreadExit()
            thread.start()
            thread.join()

        self.assertEqual(stderr.getvalue(), '')

    def test_custom_excepthook(self):
        # A replacement threading.excepthook receives an ExceptHookArgs
        # carrying the exception type, value, traceback and the thread.
        args = None

        def hook(hook_args):
            nonlocal args
            args = hook_args

        try:
            with support.swap_attr(threading, 'excepthook', hook):
                thread = ThreadRunFail()
                thread.start()
                thread.join()

            self.assertEqual(args.exc_type, ValueError)
            self.assertEqual(str(args.exc_value), 'run failed')
            self.assertEqual(args.exc_traceback, args.exc_value.__traceback__)
            self.assertIs(args.thread, thread)
        finally:
            # Break reference cycle
            args = None

    def test_custom_excepthook_fail(self):
        # If threading.excepthook itself raises, the header below is written
        # to stderr and the hook's own exception is passed to sys.excepthook.
        def threading_hook(args):
            raise ValueError("threading_hook failed")

        err_str = None

        def sys_hook(exc_type, exc_value, exc_traceback):
            nonlocal err_str
            err_str = str(exc_value)

        with support.swap_attr(threading, 'excepthook', threading_hook), \
             support.swap_attr(sys, 'excepthook', sys_hook), \
             support.captured_output('stderr') as stderr:
            thread = ThreadRunFail()
            thread.start()
            thread.join()

        self.assertEqual(stderr.getvalue(),
                         'Exception in threading.excepthook:\n')
        self.assertEqual(err_str, 'threading_hook failed')
1280
1281
class TimerTests(BaseTestCase):
    """Tests for ``threading.Timer``."""

    def setUp(self):
        super().setUp()
        # Each call of the spy callback is recorded here as an
        # ``(args, kwargs)`` pair; the event is set after every call.
        self.callback_args = []
        self.callback_event = threading.Event()

    def test_init_immutable_default_args(self):
        # Issue 17435: constructor defaults were mutable objects, they could be
        # mutated via the object attributes and affect other Timer objects.
        first = threading.Timer(0.01, self._callback_spy)
        first.start()
        self.callback_event.wait()

        # Mutate the first timer's default args/kwargs; a Timer created
        # afterwards must still be called with no arguments.
        first.args.append("blah")
        first.kwargs["foo"] = "bar"
        self.callback_event.clear()

        second = threading.Timer(0.01, self._callback_spy)
        second.start()
        self.callback_event.wait()

        self.assertEqual(len(self.callback_args), 2)
        self.assertEqual(self.callback_args, [((), {}), ((), {})])
        for timer in (first, second):
            timer.join()

    def _callback_spy(self, *args, **kwargs):
        # Record copies so later mutation of the originals cannot change
        # what was captured, then signal the waiting test.
        snapshot = (tuple(args), dict(kwargs))
        self.callback_args.append(snapshot)
        self.callback_event.set()
1309
class LockTests(lock_tests.LockTests):
    """Shared ``lock_tests.LockTests`` suite run against ``threading.Lock``."""
    # staticmethod keeps a plain-function factory from being bound as a
    # method when looked up through the test instance.
    locktype = staticmethod(threading.Lock)
1312
class PyRLockTests(lock_tests.RLockTests):
    """Shared RLock suite run against the pure-Python ``threading._PyRLock``."""
    locktype = staticmethod(threading._PyRLock)
1315
@unittest.skipIf(threading._CRLock is None, 'RLock not implemented in C')
class CRLockTests(lock_tests.RLockTests):
    """Shared RLock suite run against the C ``threading._CRLock``.

    Skipped when ``threading._CRLock`` is None (no C implementation).
    """
    locktype = staticmethod(threading._CRLock)
1319
class EventTests(lock_tests.EventTests):
    """Shared ``lock_tests.EventTests`` suite run against ``threading.Event``."""
    eventtype = staticmethod(threading.Event)
1322
class ConditionAsRLockTests(lock_tests.RLockTests):
    """Shared RLock suite run against ``threading.Condition`` used as a lock."""
    # Condition uses an RLock by default and exports its API.
    locktype = staticmethod(threading.Condition)
1326
class ConditionTests(lock_tests.ConditionTests):
    """Shared ``lock_tests.ConditionTests`` suite for ``threading.Condition``."""
    condtype = staticmethod(threading.Condition)
1329
class SemaphoreTests(lock_tests.SemaphoreTests):
    """Shared ``lock_tests.SemaphoreTests`` suite for ``threading.Semaphore``."""
    semtype = staticmethod(threading.Semaphore)
1332
class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests):
    """Shared bounded-semaphore suite for ``threading.BoundedSemaphore``."""
    semtype = staticmethod(threading.BoundedSemaphore)
1335
class BarrierTests(lock_tests.BarrierTests):
    """Shared ``lock_tests.BarrierTests`` suite run against ``threading.Barrier``."""
    barriertype = staticmethod(threading.Barrier)
1338
1339
class MiscTestCase(unittest.TestCase):
    """Checks on the public API surface of ``threading``."""

    def test__all__(self):
        # ThreadError is expected in __all__ in addition to the detected
        # public names; the two camelCase aliases are excluded from the names
        # check__all__ expects to find there.
        support.check__all__(self, threading, ('threading', '_thread'),
                             extra={"ThreadError"},
                             blacklist={'currentThread', 'activeCount'})
1346
1347
class InterruptMainTests(unittest.TestCase):
    """Tests for ``_thread.interrupt_main()``."""

    def test_interrupt_main_subthread(self):
        # interrupt_main() called from a worker thread must surface as a
        # KeyboardInterrupt in the main thread once the worker has run.
        worker = threading.Thread(target=lambda: _thread.interrupt_main())
        with self.assertRaises(KeyboardInterrupt):
            worker.start()
            worker.join()
        worker.join()

    def test_interrupt_main_mainthread(self):
        # Called from the main thread itself, the KeyboardInterrupt is
        # raised right away.
        with self.assertRaises(KeyboardInterrupt):
            _thread.interrupt_main()

    def test_interrupt_main_noerror(self):
        saved_handler = signal.getsignal(signal.SIGINT)
        try:
            # interrupt_main() must not raise while SIGINT is ignored, nor
            # while it has its default disposition.
            for disposition in (signal.SIG_IGN, signal.SIG_DFL):
                signal.signal(signal.SIGINT, disposition)
                _thread.interrupt_main()
        finally:
            # Put back whatever handler was installed before the test.
            signal.signal(signal.SIGINT, saved_handler)
1378
1379
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
1382