1# Some simple queue module tests, plus some failure conditions
2# to ensure the Queue locks remain stable.
3import itertools
4import random
5import threading
6import time
7import unittest
8import weakref
9from test import support
10
# The queue module imported with the _queue extension module blocked.
py_queue = support.import_fresh_module('queue', blocked=['_queue'])
# The queue module imported together with a fresh copy of _queue.  The skip
# decorator below treats a falsy value here as "no _queue module".
c_queue = support.import_fresh_module('queue', fresh=['_queue'])
# Class decorator that skips the decorated tests when c_queue is falsy.
need_c_queue = unittest.skipUnless(c_queue, "No _queue module found")

# maxsize given to the bounded queues created by the tests below.
QUEUE_SIZE = 5
16
def qfull(q):
    """Return True when *q* is bounded (maxsize > 0) and holds maxsize items."""
    capacity = q.maxsize
    if capacity <= 0:
        # A non-positive maxsize means the queue is never reported as full.
        return False
    return q.qsize() == capacity
19
20# A thread to run a function that unclogs a blocked Queue.
21class _TriggerThread(threading.Thread):
22    def __init__(self, fn, args):
23        self.fn = fn
24        self.args = args
25        self.startedEvent = threading.Event()
26        threading.Thread.__init__(self)
27
28    def run(self):
29        # The sleep isn't necessary, but is intended to give the blocking
30        # function in the main thread a chance at actually blocking before
31        # we unclog it.  But if the sleep is longer than the timeout-based
32        # tests wait in their blocking functions, those tests will fail.
33        # So we give them much longer timeout values compared to the
34        # sleep here (I aimed at 10 seconds for blocking functions --
35        # they should never actually wait that long - they should make
36        # progress as soon as we call self.fn()).
37        time.sleep(0.1)
38        self.startedEvent.set()
39        self.fn(*self.args)
40
41
42# Execute a function that blocks, and in a separate thread, a function that
43# triggers the release.  Returns the result of the blocking function.  Caution:
44# block_func must guarantee to block until trigger_func is called, and
45# trigger_func must guarantee to change queue state so that block_func can make
46# enough progress to return.  In particular, a block_func that just raises an
47# exception regardless of whether trigger_func is called will lead to
48# timing-dependent sporadic failures, and one of those went rarely seen but
49# undiagnosed for years.  Now block_func must be unexceptional.  If block_func
50# is supposed to raise an exception, call do_exceptional_blocking_test()
51# instead.
52
class BlockingTestMixin:
    # Helpers that run a blocking call in the current thread while a
    # _TriggerThread performs the call that is expected to release it.

    def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args):
        # Returns whatever block_func returns (also stored on self.result).
        # Fails the test when block_func came back before the trigger thread
        # had even flagged that it started.
        trigger = _TriggerThread(trigger_func, trigger_args)
        trigger.start()
        try:
            outcome = block_func(*block_args)
            self.result = outcome
            released_by_trigger = trigger.startedEvent.is_set()
            if not released_by_trigger:
                self.fail("blocking function %r appeared not to block" %
                          block_func)
            return outcome
        finally:
            # Always wait for the trigger thread so it cannot outlive the test.
            support.join_thread(trigger, 10)

    # Variant for a block_func that is expected to raise.
    def do_exceptional_blocking_test(self,block_func, block_args, trigger_func,
                                   trigger_args, expected_exception_class):
        # Any exception raised by block_func (expected or not) propagates to
        # the caller; returning normally is reported as a test failure.
        trigger = _TriggerThread(trigger_func, trigger_args)
        trigger.start()
        try:
            block_func(*block_args)
            self.fail("expected exception of kind %r" %
                             expected_exception_class)
        finally:
            # Wait for the trigger thread, then make sure it really ran.
            support.join_thread(trigger, 10)
            if not trigger.startedEvent.is_set():
                self.fail("trigger thread ended but event never set")
85
86
class BaseQueueTestMixin(BlockingTestMixin):
    # Tests shared by Queue, LifoQueue and PriorityQueue.  Concrete test
    # classes supply ``queue`` (the module under test) and assign
    # ``self.type2test`` (the queue class) in their setUp().

    def setUp(self):
        # Running total accumulated by worker(); cumlock guards the updates.
        self.cum = 0
        self.cumlock = threading.Lock()

    def basic_queue_test(self, q):
        # Exercise ordering, full/empty detection and blocking put()/get().
        # q must be empty and is expected to become full after QUEUE_SIZE
        # items have been put.
        if q.qsize():
            raise RuntimeError("Call this function with an empty queue")
        self.assertTrue(q.empty())
        self.assertFalse(q.full())
        # I guess we better check things actually queue correctly a little :)
        q.put(111)
        q.put(333)
        q.put(222)
        target_order = dict(Queue = [111, 333, 222],
                            LifoQueue = [222, 333, 111],
                            PriorityQueue = [111, 222, 333])
        actual_order = [q.get(), q.get(), q.get()]
        self.assertEqual(actual_order, target_order[q.__class__.__name__],
                         "Didn't seem to queue the correct data!")
        for i in range(QUEUE_SIZE-1):
            q.put(i)
            self.assertTrue(q.qsize(), "Queue should not be empty")
        self.assertTrue(not qfull(q), "Queue should not be full")
        last = 2 * QUEUE_SIZE
        full = 3 * 2 * QUEUE_SIZE
        q.put(last)
        self.assertTrue(qfull(q), "Queue should be full")
        self.assertFalse(q.empty())
        self.assertTrue(q.full())
        with self.assertRaises(
                self.queue.Full,
                msg="Didn't appear to block with a full queue"):
            q.put(full, block=0)
        with self.assertRaises(
                self.queue.Full,
                msg="Didn't appear to time-out with a full queue"):
            q.put(full, timeout=0.01)
        # Test a blocking put
        self.do_blocking_test(q.put, (full,), q.get, ())
        self.do_blocking_test(q.put, (full, True, 10), q.get, ())
        # Empty it
        for i in range(QUEUE_SIZE):
            q.get()
        self.assertTrue(not q.qsize(), "Queue should be empty")
        with self.assertRaises(
                self.queue.Empty,
                msg="Didn't appear to block with an empty queue"):
            q.get(block=0)
        with self.assertRaises(
                self.queue.Empty,
                msg="Didn't appear to time-out with an empty queue"):
            q.get(timeout=0.01)
        # Test a blocking get
        self.do_blocking_test(q.get, (), q.put, ('empty',))
        self.do_blocking_test(q.get, (True, 10), q.put, ('empty',))


    def worker(self, q):
        # Thread body: add every non-negative item to self.cum and mark it
        # done; a negative item is marked done and ends the thread.
        while True:
            x = q.get()
            if x < 0:
                q.task_done()
                return
            with self.cumlock:
                self.cum += x
            q.task_done()

    def queue_join_test(self, q):
        # Feed 0..99 to two worker threads and check that q.join() only
        # returns once every item has been processed.
        self.cum = 0
        threads = []
        try:
            for i in (0,1):
                thread = threading.Thread(target=self.worker, args=(q,))
                thread.start()
                threads.append(thread)
            for i in range(100):
                q.put(i)
            q.join()
            self.assertEqual(self.cum, sum(range(100)),
                             "q.join() did not block until all tasks were done")
        finally:
            # Always shut the workers down, even when the check above failed:
            # they are non-daemon threads blocked in q.get() and would
            # otherwise keep the test run from ever finishing.
            for thread in threads:
                q.put(-1)         # instruct the threads to close
            q.join()                # verify that you can join twice
            for thread in threads:
                thread.join()

    def test_queue_task_done(self):
        # Test to make sure a queue task completed successfully.
        q = self.type2test()
        with self.assertRaises(
                ValueError,
                msg="Did not detect task count going negative"):
            q.task_done()

    def test_queue_join(self):
        # Test that a queue join()s successfully, and before anything else
        # (done twice for insurance).
        q = self.type2test()
        self.queue_join_test(q)
        self.queue_join_test(q)
        with self.assertRaises(
                ValueError,
                msg="Did not detect task count going negative"):
            q.task_done()

    def test_basic(self):
        # Do it a couple of times on the same queue.
        # Done twice to make sure works with same instance reused.
        q = self.type2test(QUEUE_SIZE)
        self.basic_queue_test(q)
        self.basic_queue_test(q)

    def test_negative_timeout_raises_exception(self):
        # A negative timeout is rejected by both put() and get().
        q = self.type2test(QUEUE_SIZE)
        with self.assertRaises(ValueError):
            q.put(1, timeout=-1)
        with self.assertRaises(ValueError):
            q.get(1, timeout=-1)

    def test_nowait(self):
        # put_nowait()/get_nowait() raise Full/Empty instead of blocking.
        q = self.type2test(QUEUE_SIZE)
        for i in range(QUEUE_SIZE):
            q.put_nowait(1)
        with self.assertRaises(self.queue.Full):
            q.put_nowait(1)

        for i in range(QUEUE_SIZE):
            q.get_nowait()
        with self.assertRaises(self.queue.Empty):
            q.get_nowait()

    def test_shrinking_queue(self):
        # issue 10110
        q = self.type2test(3)
        q.put(1)
        q.put(2)
        q.put(3)
        with self.assertRaises(self.queue.Full):
            q.put_nowait(4)
        self.assertEqual(q.qsize(), 3)
        q.maxsize = 2                       # shrink the queue
        # The queue now holds more items than maxsize and must still be full.
        with self.assertRaises(self.queue.Full):
            q.put_nowait(4)
238
class QueueTest(BaseQueueTestMixin):
    # Runs the shared queue tests against the module's Queue class.

    def setUp(self):
        fifo_cls = self.queue.Queue
        self.type2test = fifo_cls
        super().setUp()
244
# Queue tests bound to the queue module imported with _queue blocked.
class PyQueueTest(QueueTest, unittest.TestCase):
    queue = py_queue
247
248
# Queue tests bound to the queue module imported with a fresh _queue;
# skipped when that import produced nothing.
@need_c_queue
class CQueueTest(QueueTest, unittest.TestCase):
    queue = c_queue
252
253
class LifoQueueTest(BaseQueueTestMixin):
    # Runs the shared queue tests against the module's LifoQueue class.

    def setUp(self):
        lifo_cls = self.queue.LifoQueue
        self.type2test = lifo_cls
        super().setUp()
259
260
# LifoQueue tests bound to the queue module imported with _queue blocked.
class PyLifoQueueTest(LifoQueueTest, unittest.TestCase):
    queue = py_queue
263
264
# LifoQueue tests bound to the queue module imported with a fresh _queue;
# skipped when that import produced nothing.
@need_c_queue
class CLifoQueueTest(LifoQueueTest, unittest.TestCase):
    queue = c_queue
268
269
class PriorityQueueTest(BaseQueueTestMixin):
    # Runs the shared queue tests against the module's PriorityQueue class.

    def setUp(self):
        priority_cls = self.queue.PriorityQueue
        self.type2test = priority_cls
        super().setUp()
275
276
# PriorityQueue tests bound to the queue module imported with _queue blocked.
class PyPriorityQueueTest(PriorityQueueTest, unittest.TestCase):
    queue = py_queue
279
280
# PriorityQueue tests bound to the queue module imported with a fresh _queue;
# skipped when that import produced nothing.
@need_c_queue
class CPriorityQueueTest(PriorityQueueTest, unittest.TestCase):
    queue = c_queue
284
285
286# A Queue subclass that can provoke failure at a moment's notice :)
class FailingQueueException(Exception):
    """Raised on purpose by the failing queue's _put()/_get() hooks."""
288
class FailingQueueTest(BlockingTestMixin):
    # Checks that a Queue stays usable after its _put()/_get() hooks raise.
    # Concrete test classes supply ``queue`` (the module under test).

    def setUp(self):

        Queue = self.queue.Queue

        class FailingQueue(Queue):
            # Queue whose next _put() or _get() can be made to raise once by
            # setting fail_next_put / fail_next_get; the flag resets itself.
            def __init__(self, *args):
                self.fail_next_put = False
                self.fail_next_get = False
                Queue.__init__(self, *args)
            def _put(self, item):
                if self.fail_next_put:
                    self.fail_next_put = False
                    raise FailingQueueException("You Lose")
                return Queue._put(self, item)
            def _get(self):
                if self.fail_next_get:
                    self.fail_next_get = False
                    raise FailingQueueException("You Lose")
                return Queue._get(self)

        self.FailingQueue = FailingQueue

        super().setUp()

    def _assert_queue_fails(self, func, *args, **kwargs):
        # Call func(*args, **kwargs) and require FailingQueueException.
        with self.assertRaises(
                FailingQueueException,
                msg="The queue didn't fail when it should have"):
            func(*args, **kwargs)

    def failing_queue_test(self, q):
        # Provoke failures in non-blocking, timed and blocking put()/get()
        # calls and check after each one that q is still consistent.
        # q must be empty and become full after QUEUE_SIZE items.
        if q.qsize():
            raise RuntimeError("Call this function with an empty queue")
        for i in range(QUEUE_SIZE-1):
            q.put(i)
        # Test a failing non-blocking put.
        q.fail_next_put = True
        self._assert_queue_fails(q.put, "oops", block=0)
        q.fail_next_put = True
        self._assert_queue_fails(q.put, "oops", timeout=0.1)
        q.put("last")
        self.assertTrue(qfull(q), "Queue should be full")
        # Test a failing blocking put.  The put is expected to raise, so it
        # goes through do_exceptional_blocking_test() like the other failing
        # blocking calls below; that helper lets the exception propagate.
        q.fail_next_put = True
        self._assert_queue_fails(
            self.do_exceptional_blocking_test,
            q.put, ("full",), q.get, (), FailingQueueException)
        # Check the Queue isn't damaged.
        # put failed, but get succeeded - re-add
        q.put("last")
        # Test a failing timeout put
        q.fail_next_put = True
        self._assert_queue_fails(
            self.do_exceptional_blocking_test,
            q.put, ("full", True, 10), q.get, (), FailingQueueException)
        # Check the Queue isn't damaged.
        # put failed, but get succeeded - re-add
        q.put("last")
        self.assertTrue(qfull(q), "Queue should be full")
        q.get()
        self.assertTrue(not qfull(q), "Queue should not be full")
        q.put("last")
        self.assertTrue(qfull(q), "Queue should be full")
        # Test a blocking put
        self.do_blocking_test(q.put, ("full",), q.get, ())
        # Empty it
        for i in range(QUEUE_SIZE):
            q.get()
        self.assertTrue(not q.qsize(), "Queue should be empty")
        q.put("first")
        # Test a failing get; the item must stay in the queue.
        q.fail_next_get = True
        self._assert_queue_fails(q.get)
        self.assertTrue(q.qsize(), "Queue should not be empty")
        q.fail_next_get = True
        self._assert_queue_fails(q.get, timeout=0.1)
        self.assertTrue(q.qsize(), "Queue should not be empty")
        q.get()
        self.assertTrue(not q.qsize(), "Queue should be empty")
        # Test a failing blocking get
        q.fail_next_get = True
        self._assert_queue_fails(
            self.do_exceptional_blocking_test,
            q.get, (), q.put, ('empty',), FailingQueueException)
        # put succeeded, but get failed.
        self.assertTrue(q.qsize(), "Queue should not be empty")
        q.get()
        self.assertTrue(not q.qsize(), "Queue should be empty")

    def test_failing_queue(self):

        # Test to make sure a queue is functioning correctly.
        # Done twice to the same instance.
        q = self.FailingQueue(QUEUE_SIZE)
        self.failing_queue_test(q)
        self.failing_queue_test(q)
403
404
405
# Failing-queue tests bound to the queue module imported with _queue blocked.
class PyFailingQueueTest(FailingQueueTest, unittest.TestCase):
    queue = py_queue
408
409
# Failing-queue tests bound to the queue module imported with a fresh _queue;
# skipped when that import produced nothing.
@need_c_queue
class CFailingQueueTest(FailingQueueTest, unittest.TestCase):
    queue = c_queue
413
414
class BaseSimpleQueueTest:
    # Tests shared by the SimpleQueue implementations.  Concrete test classes
    # supply ``queue`` (the module under test) and assign ``self.type2test``
    # (the queue class) before calling this setUp().

    def setUp(self):
        self.q = self.type2test()

    def feed(self, q, seq, rnd):
        # Producer body: pop items off the end of seq and put them on q until
        # seq is exhausted, pausing for up to a millisecond about half of the
        # time.
        while True:
            try:
                val = seq.pop()
            except IndexError:
                return
            q.put(val)
            if rnd.random() > 0.5:
                time.sleep(rnd.random() * 1e-3)

    def consume(self, q, results, sentinel):
        # Consumer body using blocking get(); stops at the sentinel object.
        while True:
            val = q.get()
            # Identity, not equality: an item with a permissive __eq__ must
            # not be mistaken for the sentinel.
            if val is sentinel:
                return
            results.append(val)

    def consume_nonblock(self, q, results, sentinel):
        # Consumer body polling with get(block=False); stops at the sentinel.
        while True:
            while True:
                try:
                    val = q.get(block=False)
                except self.queue.Empty:
                    time.sleep(1e-5)
                else:
                    break
            if val is sentinel:
                return
            results.append(val)

    def consume_timeout(self, q, results, sentinel):
        # Consumer body retrying get(timeout=...); stops at the sentinel.
        while True:
            while True:
                try:
                    val = q.get(timeout=1e-5)
                except self.queue.Empty:
                    pass
                else:
                    break
            if val is sentinel:
                return
            results.append(val)

    def run_threads(self, n_feeders, n_consumers, q, inputs,
                    feed_func, consume_func):
        # Run n_feeders producer threads and n_consumers consumer threads over
        # q and return the list of consumed items.  One sentinel per consumer
        # is queued after the inputs so that every consumer terminates.
        results = []
        sentinel = None
        seq = inputs + [sentinel] * n_consumers
        seq.reverse()   # feed() pops from the end
        rnd = random.Random(42)

        # Exceptions raised inside the threads are collected here so that the
        # test can fail on them.
        exceptions = []
        def log_exceptions(f):
            def wrapper(*args, **kwargs):
                try:
                    f(*args, **kwargs)
                except BaseException as e:
                    exceptions.append(e)
            return wrapper

        feeders = [threading.Thread(target=log_exceptions(feed_func),
                                    args=(q, seq, rnd))
                   for i in range(n_feeders)]
        consumers = [threading.Thread(target=log_exceptions(consume_func),
                                      args=(q, results, sentinel))
                     for i in range(n_consumers)]

        with support.start_threads(feeders + consumers):
            pass

        self.assertFalse(exceptions)
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)

        return results

    def test_basic(self):
        # Basic tests for get(), put() etc.
        q = self.q
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)
        q.put(1)
        self.assertFalse(q.empty())
        self.assertEqual(q.qsize(), 1)
        q.put(2)
        q.put_nowait(3)
        q.put(4)
        self.assertFalse(q.empty())
        self.assertEqual(q.qsize(), 4)

        self.assertEqual(q.get(), 1)
        self.assertEqual(q.qsize(), 3)

        self.assertEqual(q.get_nowait(), 2)
        self.assertEqual(q.qsize(), 2)

        self.assertEqual(q.get(block=False), 3)
        self.assertFalse(q.empty())
        self.assertEqual(q.qsize(), 1)

        self.assertEqual(q.get(timeout=0.1), 4)
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)

        with self.assertRaises(self.queue.Empty):
            q.get(block=False)
        with self.assertRaises(self.queue.Empty):
            q.get(timeout=1e-3)
        with self.assertRaises(self.queue.Empty):
            q.get_nowait()
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)

    def test_negative_timeout_raises_exception(self):
        # A negative timeout is rejected even when an item is available.
        q = self.q
        q.put(1)
        with self.assertRaises(ValueError):
            q.get(timeout=-1)

    def test_order(self):
        # Test a pair of concurrent put() and get()
        q = self.q
        inputs = list(range(100))
        results = self.run_threads(1, 1, q, inputs, self.feed, self.consume)

        # One producer, one consumer => results appended in well-defined order
        self.assertEqual(results, inputs)

    def test_many_threads(self):
        # Test multiple concurrent put() and get()
        N = 50
        q = self.q
        inputs = list(range(10000))
        results = self.run_threads(N, N, q, inputs, self.feed, self.consume)

        # Multiple consumers without synchronization append the
        # results in random order
        self.assertEqual(sorted(results), inputs)

    def test_many_threads_nonblock(self):
        # Test multiple concurrent put() and get(block=False)
        N = 50
        q = self.q
        inputs = list(range(10000))
        results = self.run_threads(N, N, q, inputs,
                                   self.feed, self.consume_nonblock)

        self.assertEqual(sorted(results), inputs)

    def test_many_threads_timeout(self):
        # Test multiple concurrent put() and get(timeout=...)
        N = 50
        q = self.q
        inputs = list(range(1000))
        results = self.run_threads(N, N, q, inputs,
                                   self.feed, self.consume_timeout)

        self.assertEqual(sorted(results), inputs)

    def test_references(self):
        # The queue should lose references to each item as soon as
        # it leaves the queue.
        class C:
            pass

        N = 20
        q = self.q
        for i in range(N):
            q.put(C())
        for i in range(N):
            wr = weakref.ref(q.get())
            # Interpreters without reference counting only reclaim the item
            # on a collection, so force one before checking the weak ref.
            support.gc_collect()
            self.assertIsNone(wr())
592
593
class PySimpleQueueTest(BaseSimpleQueueTest, unittest.TestCase):
    # SimpleQueue tests run against the module's _PySimpleQueue class.

    queue = py_queue

    def setUp(self):
        # type2test has to be in place before the base setUp() builds self.q.
        simple_cls = self.queue._PySimpleQueue
        self.type2test = simple_cls
        super().setUp()
600
601
# SimpleQueue tests run against the queue module imported with a fresh _queue;
# skipped when that import produced nothing.
@need_c_queue
class CSimpleQueueTest(BaseSimpleQueueTest, unittest.TestCase):

    queue = c_queue

    def setUp(self):
        # type2test has to be in place before the base setUp() builds self.q.
        self.type2test = self.queue.SimpleQueue
        super().setUp()

    def test_is_default(self):
        # The class under test is the module's public SimpleQueue, and that
        # name must not be bound to the _PySimpleQueue class here.
        self.assertIs(self.type2test, self.queue.SimpleQueue)
        self.assertIsNot(self.queue.SimpleQueue, self.queue._PySimpleQueue)

    def test_reentrancy(self):
        # bpo-14976: put() may be called reentrantly in an asynchronous
        # callback.
        q = self.q
        gen = itertools.count()
        N = 10000
        results = []

        # This test exploits the fact that __del__ in a reference cycle
        # can be called any time the GC may run.

        class Circular(object):
            def __init__(self):
                self.circular = self

            def __del__(self):
                q.put(next(gen))

        while True:
            o = Circular()
            q.put(next(gen))
            del o
            results.append(q.get())
            if results[-1] >= N:
                break

        # Every number handed out by gen must come back exactly once and in
        # order, whether it was queued by the loop or by a __del__ call.
        self.assertEqual(results, list(range(N + 1)))
642
643
# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
646