1"""
2Various tests for synchronization primitives.
3"""
4
5import sys
6import time
7from thread import start_new_thread, get_ident
8import threading
9import unittest
10
11from test import test_support as support
12
13
14def _wait():
15    # A crude wait/yield function not relying on synchronization primitives.
16    time.sleep(0.01)
17
class Bunch(object):
    """
    A group of `n` threads that all run the same callable.

    The id of each thread is appended to `started` when it begins and to
    `finished` once the callable has returned or raised.
    """

    def __init__(self, f, n, wait_before_exit=False):
        """
        Start `n` threads, each calling `f()` with no arguments.

        If `wait_before_exit` is true, every thread keeps spinning after
        `f` completes until do_finish() is called.  If starting a thread
        fails, the threads already started are released before the
        exception propagates.
        """
        self.f = f
        self.n = n
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit

        def task():
            ident = get_ident()
            self.started.append(ident)
            try:
                f()
            finally:
                # Record completion first, then hold the thread (if asked)
                # without blocking on any primitive under test.
                self.finished.append(ident)
                while not self._can_exit:
                    _wait()

        try:
            for _ in range(n):
                start_new_thread(task, ())
        except:
            # Let the threads that did start exit, then re-raise.
            self._can_exit = True
            raise

    def _spin_until_full(self, ids):
        # Poll until `ids` has recorded one entry per thread.
        while len(ids) < self.n:
            _wait()

    def wait_for_started(self):
        """Block until every thread has begun running."""
        self._spin_until_full(self.started)

    def wait_for_finished(self):
        """Block until every thread's callable has returned or raised."""
        self._spin_until_full(self.finished)

    def do_finish(self):
        """Allow threads held by `wait_before_exit` to terminate."""
        self._can_exit = True
59
60
class BaseTestCase(unittest.TestCase):
    """
    Shared fixture for the synchronization tests.

    setUp() snapshots state via support.threading_setup(); tearDown() hands
    that snapshot back to support.threading_cleanup() and then calls
    support.reap_children().
    """

    def setUp(self):
        self._threads = support.threading_setup()

    def tearDown(self):
        snapshot = self._threads
        support.threading_cleanup(*snapshot)
        support.reap_children()
68
69
class BaseLockTests(BaseTestCase):
    """
    Tests shared by recursive and non-recursive locks.

    Subclasses must provide `locktype`, a zero-argument factory returning
    the lock under test.
    """

    def test_constructor(self):
        lk = self.locktype()
        del lk

    def test_acquire_destroy(self):
        # Dropping a lock that is still held must be harmless.
        lk = self.locktype()
        lk.acquire()
        del lk

    def test_acquire_release(self):
        lk = self.locktype()
        lk.acquire()
        lk.release()
        del lk

    def test_try_acquire(self):
        lk = self.locktype()
        self.assertTrue(lk.acquire(False))
        lk.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while we hold it.
        lk = self.locktype()
        lk.acquire()
        outcome = []

        def worker():
            outcome.append(lk.acquire(False))

        Bunch(worker, 1).wait_for_finished()
        self.assertFalse(outcome[0])
        lk.release()

    def test_acquire_contended(self):
        # Blocked acquirers can only proceed after the holder releases.
        lk = self.locktype()
        lk.acquire()
        count = 5

        def worker():
            lk.acquire()
            lk.release()

        bunch = Bunch(worker, count)
        bunch.wait_for_started()
        _wait()
        self.assertEqual(len(bunch.finished), 0)
        lk.release()
        bunch.wait_for_finished()
        self.assertEqual(len(bunch.finished), count)

    def test_with(self):
        lk = self.locktype()

        def worker():
            lk.acquire()
            lk.release()

        def use_lock(err=None):
            with lk:
                if err is not None:
                    raise err

        def check_released():
            # A worker thread must be able to take the lock, i.e. the
            # `with` statement did not leave it held.
            Bunch(worker, 1).wait_for_finished()

        use_lock()
        check_released()
        self.assertRaises(TypeError, use_lock, TypeError)
        check_released()

    def test_thread_leak(self):
        # Using the lock from foreign (non-threading) threads must not leave
        # extra entries in threading.enumerate().
        lk = self.locktype()

        def worker():
            lk.acquire()
            lk.release()

        before = len(threading.enumerate())
        # Many threads, in the hope that existing thread ids won't be
        # recycled.
        Bunch(worker, 15).wait_for_finished()
        self.assertEqual(before, len(threading.enumerate()))
149
150
class LockTests(BaseLockTests):
    """
    Tests for non-recursive locks that may be released by a thread other
    than the one that acquired them.
    """

    def test_reacquire(self):
        # A second acquire() by the same thread blocks until the lock is
        # released (here, by the main thread).
        lk = self.locktype()
        progress = []

        def worker():
            lk.acquire()
            progress.append(None)
            lk.acquire()
            progress.append(None)

        start_new_thread(worker, ())
        while not progress:
            _wait()
        _wait()
        self.assertEqual(len(progress), 1)
        lk.release()
        while len(progress) < 2:
            _wait()
        self.assertEqual(len(progress), 2)

    def test_different_thread(self):
        # A lock held by this thread may be released by another thread.
        lk = self.locktype()
        lk.acquire()

        def worker():
            lk.release()

        Bunch(worker, 1).wait_for_finished()
        lk.acquire()
        lk.release()
185
186
class RLockTests(BaseLockTests):
    """
    Tests for recursive locks.

    Subclasses must provide `locktype`, a zero-argument factory returning
    the recursive lock under test.
    """

    def test_reacquire(self):
        # The owning thread may acquire the lock again without blocking.
        lock = self.locktype()
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()

    def test_release_unacquired(self):
        # Cannot release an unacquired lock, before or after a balanced
        # sequence of nested acquire/release calls.
        lock = self.locktype()
        self.assertRaises(RuntimeError, lock.release)
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()
        self.assertRaises(RuntimeError, lock.release)

    def test_different_thread(self):
        # Cannot release a lock that is owned by a different thread.
        lock = self.locktype()
        def f():
            lock.acquire()
        b = Bunch(f, 1, True)
        try:
            # Wait until the worker actually holds the lock. Without this,
            # release() could raise only because the lock was never acquired,
            # and the test would pass without exercising the ownership check.
            # Bunch records `finished` after f() returns but before the
            # thread exits, so the worker still owns the lock here.
            b.wait_for_finished()
            self.assertRaises(RuntimeError, lock.release)
        finally:
            b.do_finish()

    def test__is_owned(self):
        # _is_owned() is true only in the owning thread, and only until the
        # outermost release().
        lock = self.locktype()
        self.assertFalse(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        result = []
        def f():
            result.append(lock._is_owned())
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(result[0])
        lock.release()
        self.assertTrue(lock._is_owned())
        lock.release()
        self.assertFalse(lock._is_owned())
239
240
class EventTests(BaseTestCase):
    """
    Tests for Event objects.

    Subclasses must provide `eventtype`, a zero-argument factory returning
    the event under test.
    """

    def test_is_set(self):
        evt = self.eventtype()
        self.assertFalse(evt.is_set())
        # set() and clear() are each idempotent.
        steps = (
            (evt.set, True),
            (evt.set, True),
            (evt.clear, False),
            (evt.clear, False),
        )
        for action, expected in steps:
            action()
            check = self.assertTrue if expected else self.assertFalse
            check(evt.is_set())

    def _check_notify(self, evt):
        # A single set() wakes every waiting thread.
        count = 5
        first, second = [], []

        def waiter():
            first.append(evt.wait())
            second.append(evt.wait())

        bunch = Bunch(waiter, count)
        bunch.wait_for_started()
        _wait()
        self.assertEqual(len(first), 0)
        evt.set()
        bunch.wait_for_finished()
        self.assertEqual(first, [True] * count)
        self.assertEqual(second, [True] * count)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Repeat after an explicit clear() to check the event is reusable.
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        count = 5
        quick, timed = [], []

        def waiter():
            quick.append(evt.wait(0.0))
            start = time.time()
            flag = evt.wait(0.2)
            timed.append((flag, time.time() - start))

        # Unset event: waits time out and report False.
        Bunch(waiter, count).wait_for_finished()
        self.assertEqual(quick, [False] * count)
        for flag, elapsed in timed:
            self.assertFalse(flag)
            self.assertTrue(elapsed >= 0.2, elapsed)
        # Set event: waits report True.  Rebinding the lists is visible to
        # `waiter` through its closure.
        quick, timed = [], []
        evt.set()
        Bunch(waiter, count).wait_for_finished()
        self.assertEqual(quick, [True] * count)
        for flag, _ in timed:
            self.assertTrue(flag)

    def test_reset_internal_locks(self):
        # _reset_internal_locks() must swap in a new lock of the same type.
        evt = self.eventtype()
        before = evt._Event__cond._Condition__lock
        evt._reset_internal_locks()
        after = evt._Event__cond._Condition__lock
        self.assertIsNot(after, before)
        self.assertIs(type(after), type(before))
315
316
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables.

    Subclasses must provide `condtype`, a factory that optionally accepts
    the lock the condition should use.
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired multiple
        # times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        # With an explicit Lock, acquiring the condition acquires that lock.
        lock = threading.Lock()
        cond = self.condtype(lock)
        cond.acquire()
        self.assertFalse(lock.acquire(False))
        cond.release()
        self.assertTrue(lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        lock.release()
        with cond:
            self.assertFalse(lock.acquire(False))

    def test_unacquired_wait(self):
        # wait() without holding the condition's lock is an error.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        # notify() without holding the condition's lock is an error.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # Note that this test is sensitive to timing.  If the worker threads
        # don't execute in a timely fashion, the main thread may think they
        # are further along than they are.  The main thread therefore issues
        # _wait() statements to try to make sure that it doesn't race ahead
        # of the workers.
        # Secondly, this test assumes that condition variables are not subject
        # to spurious wakeups.  The absence of spurious wakeups is an implementation
        # detail of Condition Variables in current CPython, but in general, not
        # a guaranteed property of condition variables as a programming
        # construct.  In particular, it is possible that this can no longer
        # be conveniently guaranteed should their implementation ever change.
        N = 5
        ready = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            cond.acquire()
            ready.append(phase_num)
            cond.wait()
            cond.release()
            results1.append(phase_num)
            cond.acquire()
            ready.append(phase_num)
            cond.wait()
            cond.release()
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        # first wait, to ensure all workers settle into cond.wait() before
        # we continue. See issues #8799 and #30727.
        while len(ready) < N:
            _wait()
        # Rebinding is visible to f() through its closure.
        ready = []
        self.assertEqual(results1, [])
        # Notify 3 threads at first
        cond.acquire()
        cond.notify(3)
        _wait()
        phase_num = 1
        cond.release()
        while len(results1) < 3:
            _wait()
        self.assertEqual(results1, [1] * 3)
        self.assertEqual(results2, [])
        # make sure all awakened workers settle into cond.wait()
        while len(ready) < 3:
            _wait()
        # Notify 5 threads: they might be in their first or second wait
        cond.acquire()
        cond.notify(5)
        _wait()
        phase_num = 2
        cond.release()
        while len(results1) + len(results2) < 8:
            _wait()
        self.assertEqual(results1, [1] * 3 + [2] * 2)
        self.assertEqual(results2, [2] * 3)
        # make sure all workers settle into cond.wait()
        while len(ready) < N:
            _wait()
        # Notify all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase_num = 3
        cond.release()
        while len(results2) < N:
            _wait()
        self.assertEqual(results1, [1] * 3 + [2] * 2)
        self.assertEqual(results2, [2] * 3 + [3] * 2)
        b.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # A second time, to check internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        # An un-notified wait(timeout) returns only after the timeout elapses.
        cond = self.condtype()
        results = []
        N = 5
        def f():
            cond.acquire()
            t1 = time.time()
            cond.wait(0.2)
            t2 = time.time()
            cond.release()
            results.append(t2 - t1)
        Bunch(f, N).wait_for_finished()
        self.assertEqual(len(results), N)
        for dt in results:
            self.assertTrue(dt >= 0.2, dt)
444
445
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects.

    Subclasses must provide `semtype`, a factory taking an optional initial
    value.
    """

    def test_constructor(self):
        # Negative initial values are rejected.
        self.assertRaises(ValueError, self.semtype, value=-1)
        self.assertRaises(ValueError, self.semtype, value=-sys.maxint)

    def test_acquire(self):
        sem = self.semtype(1)
        sem.acquire()
        sem.release()
        sem = self.semtype(2)
        sem.acquire()
        sem.acquire()
        sem.release()
        sem.release()

    def test_acquire_destroy(self):
        # Dropping a semaphore that is still acquired must be harmless.
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        # 7 - 1 = 6 slots are free initially; N threads each acquire twice.
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            sem.acquire()
            results1.append(phase_num)
            sem.acquire()
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        while len(results1) + len(results2) < 6:
            _wait()
        self.assertEqual(results1 + results2, [0] * 6)
        phase_num = 1
        for i in range(7):
            sem.release()
        while len(results1) + len(results2) < 13:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
        phase_num = 2
        for i in range(6):
            sem.release()
        while len(results1) + len(results2) < 19:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        b.wait_for_finished()

    def test_try_acquire(self):
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        # 4 - 1 = 3 slots are free; 5 threads make 10 non-blocking attempts.
        sem = self.semtype(4)
        sem.acquire()
        results = []
        def f():
            results.append(sem.acquire(False))
            results.append(sem.acquire(False))
        Bunch(f, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore results will not necessarily be
        # ordered.
        self.assertEqual(sorted(results), [False] * 7 + [True] * 3)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()
        def f():
            sem.acquire()
            sem.release()
        b = Bunch(f, 1)
        b.wait_for_started()
        _wait()
        self.assertFalse(b.finished)
        sem.release()
        b.wait_for_finished()

    def test_with(self):
        # The `with` statement acquires on entry and releases on exit, even
        # when the body raises.
        sem = self.semtype(2)
        def _with(err=None):
            with sem:
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err
        _with()
        self.assertTrue(sem.acquire(False))
        sem.release()
        self.assertRaises(TypeError, _with, TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
556
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests for unbounded semaphores.
    """

    def test_release_unacquired(self):
        # An extra release() on an unbounded semaphore is allowed and raises
        # its counter: starting from 1, two acquire() calls then succeed.
        sem = self.semtype(1)
        sem.release()
        for _ in range(2):
            sem.acquire()
        sem.release()
569
570
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests for bounded semaphores.
    """

    def test_release_unacquired(self):
        # A bounded semaphore refuses to be released past its initial value.
        sem = self.semtype()

        def check_over_release():
            self.assertRaises(ValueError, sem.release)

        # Freshly created: already at its initial value.
        check_over_release()
        # A balanced acquire/release pair returns it to the initial value.
        sem.acquire()
        sem.release()
        check_over_release()
583