1"""
2Various tests for synchronization primitives.
3"""
4
5import sys
6import time
7from thread import start_new_thread, get_ident
8import threading
9import unittest
10
11from test import test_support as support
12
13
14def _wait():
15    # A crude wait/yield function not relying on synchronization primitives.
16    time.sleep(0.01)
17
class Bunch(object):
    """
    A bunch of threads.
    """
    def __init__(self, f, n, wait_before_exit=False):
        """
        Construct a bunch of `n` threads running the same function `f`.
        If `wait_before_exit` is True, the threads won't terminate until
        do_finish() is called.

        Each thread appends its id to `started` when it begins, and to
        `finished` once `f` has returned or raised.
        """
        self.f = f
        self.n = n
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit

        def task():
            tid = get_ident()
            self.started.append(tid)
            try:
                f()
            finally:
                # Completion is recorded *before* parking, so
                # wait_for_finished() returns while threads created with
                # wait_before_exit=True are still alive.
                self.finished.append(tid)
                while not self._can_exit:
                    _wait()

        try:
            for _ in range(n):
                start_new_thread(task, ())
        except BaseException:
            # Spawning failed part-way (e.g. the thread limit was reached).
            # Let the threads that did start exit instead of spinning forever.
            self._can_exit = True
            raise

    def wait_for_started(self):
        """Poll until all `n` threads have started running."""
        while len(self.started) < self.n:
            _wait()

    def wait_for_finished(self):
        """Poll until `f` has completed in all `n` threads."""
        while len(self.finished) < self.n:
            _wait()

    def do_finish(self):
        """Allow threads created with wait_before_exit=True to exit."""
        self._can_exit = True
55
56
class BaseTestCase(unittest.TestCase):
    """Common fixture for the synchronization tests.

    Takes a support.threading_setup() snapshot before each test. After the
    test it hands that snapshot to support.threading_cleanup() and calls
    support.reap_children().
    """

    def setUp(self):
        snapshot = support.threading_setup()
        self._threads = snapshot

    def tearDown(self):
        snapshot = self._threads
        support.threading_cleanup(*snapshot)
        support.reap_children()
64
65
class BaseLockTests(BaseTestCase):
    """
    Tests shared by recursive and non-recursive locks.

    ``self.locktype`` is not defined here; concrete subclasses provide the
    lock factory under test.
    """

    def test_constructor(self):
        # Creating and discarding a lock must not fail.
        lk = self.locktype()
        del lk

    def test_acquire_destroy(self):
        # Discarding a lock while it is still held must not fail.
        lk = self.locktype()
        lk.acquire()
        del lk

    def test_acquire_release(self):
        lk = self.locktype()
        lk.acquire()
        lk.release()
        del lk

    def test_try_acquire(self):
        # A non-blocking acquire of a free lock succeeds.
        lk = self.locktype()
        self.assertTrue(lk.acquire(False))
        lk.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while we hold it.
        lk = self.locktype()
        lk.acquire()
        outcome = []
        def worker():
            outcome.append(lk.acquire(False))
        Bunch(worker, 1).wait_for_finished()
        self.assertFalse(outcome[0])
        lk.release()

    def test_acquire_contended(self):
        # Blocking acquires in other threads only proceed after we release.
        lk = self.locktype()
        lk.acquire()
        nthreads = 5
        def worker():
            lk.acquire()
            lk.release()

        bunch = Bunch(worker, nthreads)
        bunch.wait_for_started()
        _wait()
        self.assertEqual(len(bunch.finished), 0)
        lk.release()
        bunch.wait_for_finished()
        self.assertEqual(len(bunch.finished), nthreads)

    def test_with(self):
        # Leaving a ``with`` block releases the lock, whether the block
        # exits normally or by raising.
        lk = self.locktype()
        def probe():
            lk.acquire()
            lk.release()
        def use_with(err=None):
            with lk:
                if err is not None:
                    raise err
        use_with()
        # A helper thread can take the lock, so it must have been released.
        Bunch(probe, 1).wait_for_finished()
        self.assertRaises(TypeError, use_with, TypeError)
        # Likewise after the exceptional exit.
        Bunch(probe, 1).wait_for_finished()

    def test_thread_leak(self):
        # Using the lock from a foreign (non-threading) thread must not
        # change the number of Thread objects reported by threading.
        lk = self.locktype()
        def worker():
            lk.acquire()
            lk.release()
        before = len(threading.enumerate())
        # Run many threads, in the hope that existing thread ids are not
        # recycled.
        Bunch(worker, 15).wait_for_finished()
        self.assertEqual(before, len(threading.enumerate()))
145
146
class LockTests(BaseLockTests):
    """
    Tests for non-recursive locks that a thread other than the acquirer
    may release.
    """
    def test_reacquire(self):
        # A second acquire() by the holder blocks until the lock is released.
        lk = self.locktype()
        steps = []
        def worker():
            lk.acquire()
            steps.append(None)
            lk.acquire()
            steps.append(None)
        start_new_thread(worker, ())
        while not steps:
            _wait()
        _wait()
        self.assertEqual(len(steps), 1)
        lk.release()
        while len(steps) == 1:
            _wait()
        self.assertEqual(len(steps), 2)

    def test_different_thread(self):
        # A lock held by this thread may be released by another thread.
        lk = self.locktype()
        lk.acquire()
        def releaser():
            lk.release()
        Bunch(releaser, 1).wait_for_finished()
        lk.acquire()
        lk.release()
181
182
class RLockTests(BaseLockTests):
    """
    Tests for recursive locks.
    """
    def test_reacquire(self):
        # The owner may acquire the lock again without blocking. Every
        # acquire() needs a matching release().
        lock = self.locktype()
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()

    def test_release_unacquired(self):
        # Cannot release an unacquired lock
        lock = self.locktype()
        self.assertRaises(RuntimeError, lock.release)
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()
        self.assertRaises(RuntimeError, lock.release)

    def test_different_thread(self):
        # Cannot release a lock owned by a different thread.
        lock = self.locktype()
        def f():
            lock.acquire()
        b = Bunch(f, 1, True)
        try:
            # Make sure the helper really owns the lock first. Bunch records
            # `finished` before parking the thread, so the lock is still held
            # here. Without this wait, the RuntimeError below would only show
            # that an *unacquired* lock can't be released.
            b.wait_for_finished()
            self.assertRaises(RuntimeError, lock.release)
        finally:
            b.do_finish()

    def test__is_owned(self):
        # _is_owned() is true only in the owning thread, and until the last
        # matching release().
        lock = self.locktype()
        self.assertFalse(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        result = []
        def f():
            result.append(lock._is_owned())
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(result[0])
        lock.release()
        self.assertTrue(lock._is_owned())
        lock.release()
        self.assertFalse(lock._is_owned())
235
236
class EventTests(BaseTestCase):
    """
    Tests for Event objects.

    ``self.eventtype`` is not defined here; concrete subclasses provide the
    event factory under test.
    """

    def test_is_set(self):
        # set() and clear() are idempotent and are reflected by is_set().
        evt = self.eventtype()
        self.assertFalse(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())

    def _check_notify(self, evt):
        # All threads get notified
        N = 5
        results1 = []
        results2 = []
        def f():
            results1.append(evt.wait())
            results2.append(evt.wait())
        b = Bunch(f, N)
        b.wait_for_started()
        _wait()
        self.assertEqual(len(results1), 0)
        evt.set()
        b.wait_for_finished()
        self.assertEqual(results1, [True] * N)
        self.assertEqual(results2, [True] * N)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Another time, after an explicit clear()
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        results1 = []
        results2 = []
        N = 5
        timeout = 0.2
        def f():
            results1.append(evt.wait(0.0))
            t1 = time.time()
            r = evt.wait(timeout)
            t2 = time.time()
            results2.append((r, t2 - t1))
        Bunch(f, N).wait_for_finished()
        self.assertEqual(results1, [False] * N)
        for r, dt in results2:
            self.assertFalse(r)
            # time.time() can be coarse (e.g. on Windows), and float rounding
            # of timestamps can make the measured interval come out a hair
            # under the requested timeout. Allow some slack rather than
            # requiring dt >= timeout exactly.
            self.assertTrue(dt >= timeout * 0.6, dt)
        # The event is set
        results1 = []
        results2 = []
        evt.set()
        Bunch(f, N).wait_for_finished()
        self.assertEqual(results1, [True] * N)
        for r, _ in results2:
            self.assertTrue(r)
303
304
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables.

    ``self.condtype`` is not defined here; concrete subclasses provide the
    condition factory under test.
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired multiple
        # times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        # With an explicit Lock, acquiring the condition acquires that lock.
        lock = threading.Lock()
        cond = self.condtype(lock)
        cond.acquire()
        self.assertFalse(lock.acquire(False))
        cond.release()
        self.assertTrue(lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        lock.release()
        with cond:
            self.assertFalse(lock.acquire(False))

    def test_unacquired_wait(self):
        # wait() without holding the underlying lock raises RuntimeError.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        # notify() without holding the underlying lock raises RuntimeError.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # N workers each wait on `cond` twice. The main thread changes
        # `phase_num` while it holds `cond`, so a worker woken by a given
        # notify() records that notify's phase.
        #
        # Workers append to `ready` while holding `cond`, just before
        # wait(). Once the main thread can acquire `cond`, every worker
        # counted in `ready` is therefore really blocked in wait(). Without
        # this handshake, notify() could run before a worker reached wait(),
        # the wakeup would be lost, and the test could hang or fail.
        N = 5
        ready = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            cond.acquire()
            ready.append(phase_num)
            cond.wait()
            cond.release()
            results1.append(phase_num)
            cond.acquire()
            ready.append(phase_num)
            cond.wait()
            cond.release()
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        # Wait until every worker has settled into its first wait().
        while len(ready) < N:
            _wait()
        del ready[:]
        self.assertEqual(results1, [])
        # Notify 3 threads at first
        cond.acquire()
        cond.notify(3)
        _wait()
        phase_num = 1
        cond.release()
        while len(results1) < 3:
            _wait()
        self.assertEqual(results1, [1] * 3)
        self.assertEqual(results2, [])
        # Wait until the 3 woken workers have settled into their second wait().
        while len(ready) < 3:
            _wait()
        # Notify 5 threads: 2 are in their first wait, 3 in their second
        cond.acquire()
        cond.notify(5)
        _wait()
        phase_num = 2
        cond.release()
        while len(results1) + len(results2) < 8:
            _wait()
        self.assertEqual(results1, [1] * 3 + [2] * 2)
        self.assertEqual(results2, [2] * 3)
        # Wait until every worker is in its second wait().
        while len(ready) < N:
            _wait()
        # Notify all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase_num = 3
        cond.release()
        while len(results2) < N:
            _wait()
        self.assertEqual(results1, [1] * 3 + [2] * 2)
        self.assertEqual(results2, [2] * 3 + [3] * 2)
        b.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # A second time, to check internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        cond = self.condtype()
        results = []
        N = 5
        timeout = 0.2
        def f():
            cond.acquire()
            t1 = time.time()
            cond.wait(timeout)
            t2 = time.time()
            cond.release()
            results.append(t2 - t1)
        Bunch(f, N).wait_for_finished()
        self.assertEqual(len(results), N)
        for dt in results:
            # time.time() can be coarse (e.g. on Windows), and float rounding
            # of timestamps can make the measured interval come out a hair
            # under the requested timeout. Allow some slack.
            self.assertTrue(dt >= timeout * 0.6, dt)
408
409
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects.

    ``self.semtype`` is not defined here; concrete subclasses provide the
    semaphore factory under test.
    """

    def test_constructor(self):
        # Negative initial values are rejected.
        self.assertRaises(ValueError, self.semtype, value = -1)
        self.assertRaises(ValueError, self.semtype, value = -sys.maxint)

    def test_acquire(self):
        sem = self.semtype(1)
        sem.acquire()
        sem.release()
        sem = self.semtype(2)
        sem.acquire()
        sem.acquire()
        sem.release()
        sem.release()

    def test_acquire_destroy(self):
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        # Start with 6 free slots (7 - 1). N threads each acquire twice, so
        # the releases below admit exactly 6, then 7, then 6 more
        # acquisitions. The final release lets the last thread finish.
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            sem.acquire()
            results1.append(phase_num)
            sem.acquire()
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        while len(results1) + len(results2) < 6:
            _wait()
        self.assertEqual(results1 + results2, [0] * 6)
        phase_num = 1
        for _ in range(7):
            sem.release()
        while len(results1) + len(results2) < 13:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
        phase_num = 2
        for _ in range(6):
            sem.release()
        while len(results1) + len(results2) < 19:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        b.wait_for_finished()

    def test_try_acquire(self):
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        # 3 free slots (4 - 1) shared among 10 non-blocking attempts.
        sem = self.semtype(4)
        sem.acquire()
        results = []
        def f():
            results.append(sem.acquire(False))
            results.append(sem.acquire(False))
        Bunch(f, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore results will not necessarily be
        # ordered.
        self.assertEqual(sorted(results), [False] * 7 + [True] * 3)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()
        def f():
            sem.acquire()
            sem.release()
        b = Bunch(f, 1)
        b.wait_for_started()
        _wait()
        self.assertFalse(b.finished)
        sem.release()
        b.wait_for_finished()

    def test_with(self):
        # ``with sem`` takes one slot and gives it back on exit, including
        # when the body raises.
        sem = self.semtype(2)
        def _with(err=None):
            with sem:
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err
        _with()
        self.assertTrue(sem.acquire(False))
        sem.release()
        self.assertRaises(TypeError, _with, TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
520
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests specific to unbounded semaphores.
    """

    def test_release_unacquired(self):
        # release() beyond the initial value is allowed and raises the
        # counter, so two acquires succeed after one extra release.
        s = self.semtype(1)
        s.release()
        for _ in range(2):
            s.acquire()
        s.release()
533
534
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests specific to bounded semaphores.
    """

    def test_release_unacquired(self):
        # Releasing past the initial value raises ValueError. This holds
        # both on a fresh semaphore and after a balanced acquire/release pair.
        s = self.semtype()
        self.assertRaises(ValueError, s.release)
        s.acquire()
        s.release()
        self.assertRaises(ValueError, s.release)
547