1# Disable warning Missing docstring
2# pylint: disable=C0111
3
4# Disable warning Invalid variable name
5# pylint: disable=C0103
6
7# Suppress pylint warning about access to protected member
8# pylint: disable=W0212
9
10# Suppress no-member: Twisted's reactor methods are not easily discoverable
11# pylint: disable=E1101
12
13"""twisted adapter test"""
14import unittest
15
16import mock
17from nose.twistedtools import reactor, deferred
18from twisted.internet import defer, error as twisted_error
19from twisted.python.failure import Failure
20
21from pika.adapters.twisted_connection import (
22    ClosableDeferredQueue, ReceivedMessage, TwistedChannel,
23    _TwistedConnectionAdapter, TwistedProtocolConnection, _TimerHandle)
24from pika import spec
25from pika.exceptions import (
26    AMQPConnectionError, ConsumerCancelled, DuplicateGetOkCallback, NackError,
27    UnroutableError, ChannelClosedByBroker)
28from pika.frame import Method
29
30
class TestCase(unittest.TestCase):
    """Base test case carrying Twisted Trial's ``assertFailure`` helper.

    Only ``assertFailure`` is borrowed from ``twisted.trial.unittest.TestCase``:
    using Trial's class directly hides some assertion errors.
    """

    def assertFailure(self, d, *expectedFailures):
        """Expect the Deferred *d* to errback with one of *expectedFailures*.

        A success/failure handler pair is attached to *d* and the resulting
        Deferred is returned; the test must return it so the check is awaited.

        :param d: the Deferred under test.
        :param expectedFailures: exception classes that count as success.
        :returns: the Deferred produced by ``d.addCallbacks``; on an expected
            errback its result becomes the exception instance.
        :raises: ``self.failureException`` inside the callback chain when *d*
            succeeds, or fails with an exception not in *expectedFailures*.
        """
        def on_unexpected_success(result):
            raise self.failureException(
                "did not catch an error, instead got %r" % (result,))

        def on_failure(failure):
            # Anything outside the expected set is itself a test failure.
            if not failure.check(*expectedFailures):
                message = ('\nExpected: %r\nGot:\n%s'
                           % (expectedFailures, str(failure)))
                raise self.failureException(message)
            return failure.value

        return d.addCallbacks(on_unexpected_success, on_failure)
56
57
class ClosableDeferredQueueTestCase(TestCase):
    """Tests for the close semantics of ClosableDeferredQueue."""

    @deferred(timeout=5.0)
    def test_put_closed(self):
        # Verify that the .put() method errbacks when the queue is closed.
        q = ClosableDeferredQueue()
        q.closed = RuntimeError("testing")
        d = self.assertFailure(q.put(None), RuntimeError)
        d.addCallback(lambda e: self.assertEqual(e.args[0], "testing"))
        return d

    @deferred(timeout=5.0)
    def test_get_closed(self):
        # Verify that the .get() method errbacks when the queue is closed.
        q = ClosableDeferredQueue()
        q.closed = RuntimeError("testing")
        d = self.assertFailure(q.get(), RuntimeError)
        d.addCallback(lambda e: self.assertEqual(e.args[0], "testing"))
        return d

    def test_close(self):
        # Verify that the queue can be closed.
        q = ClosableDeferredQueue()
        q.close("testing")
        self.assertEqual(q.closed, "testing")
        self.assertEqual(q.waiting, [])
        self.assertEqual(q.pending, [])

    @deferred(timeout=5.0)
    def test_close_waiting(self):
        # Verify that the deferreds waiting for new data are errbacked when
        # the queue is closed. The @deferred decorator is required: unittest
        # ignores a test's return value, so without it a failed assertFailure
        # check on the returned Deferred would go unnoticed.
        q = ClosableDeferredQueue()
        d = q.get()
        error = RuntimeError("testing")
        q.close(error)
        # close() stores its argument as the closed reason (see test_close).
        self.assertIs(q.closed, error)
        self.assertEqual(q.waiting, [])
        self.assertEqual(q.pending, [])
        return self.assertFailure(d, RuntimeError)

    def test_close_twice(self):
        # If a queue is closed twice, it must not crash.
        q = ClosableDeferredQueue()
        q.close("testing")
        self.assertEqual(q.closed, "testing")
        q.close("testing")
        self.assertEqual(q.closed, "testing")
104
105
class TwistedChannelTestCase(TestCase):
    """Tests for TwistedChannel, the Deferred-based wrapper of a pika channel.

    The wrapped pika channel is a ``mock.Mock``; broker replies are simulated
    by invoking the callbacks that TwistedChannel passed to the mock.
    """

    def setUp(self):
        self.pika_channel = mock.Mock()
        self.channel = TwistedChannel(self.pika_channel)
        # This is only needed on Python2 for functools.wraps to work.
        wrapped = (
            "basic_cancel", "basic_get", "basic_qos", "basic_recover",
            "exchange_bind", "exchange_unbind", "exchange_declare",
            "exchange_delete", "confirm_delivery", "flow",
            "queue_bind", "queue_declare", "queue_delete", "queue_purge",
            "queue_unbind", "tx_commit", "tx_rollback", "tx_select",
        )
        for meth_name in wrapped:
            getattr(self.pika_channel, meth_name).__name__ = meth_name

    def test_repr(self):
        self.pika_channel.__repr__ = lambda _s: "<TestChannel>"
        self.assertEqual(
            repr(self.channel),
            "<TwistedChannel channel=<TestChannel>>",
        )

    @deferred(timeout=5.0)
    def test_on_close(self):
        # Verify that the channel can be closed and that pending calls and
        # consumers are errbacked.
        self.pika_channel.add_on_close_callback.assert_called_with(
            self.channel._on_channel_closed)
        calls = self.channel._calls = [defer.Deferred()]
        consumers = self.channel._consumers = {
            "test-delivery-tag": mock.Mock()
        }
        error = RuntimeError("testing")
        self.channel._on_channel_closed(None, error)
        consumers["test-delivery-tag"].close.assert_called_once_with(error)
        self.assertEqual(len(self.channel._calls), 0)
        self.assertEqual(len(self.channel._consumers), 0)
        return self.assertFailure(calls[0], RuntimeError)

    @deferred(timeout=5.0)
    def test_basic_consume(self):
        # Verify that the basic_consume method works properly.
        d = self.channel.basic_consume(queue="testqueue")
        self.pika_channel.basic_consume.assert_called_once()
        kwargs = self.pika_channel.basic_consume.call_args_list[0][1]
        self.assertEqual(kwargs["queue"], "testqueue")
        on_message = kwargs["on_message_callback"]

        def check_cb(result):
            queue, _consumer_tag = result
            # Make sure the queue works
            queue_get_d = queue.get()
            queue_get_d.addCallback(
                self.assertEqual,
                (self.channel, "testmethod", "testprops", "testbody")
            )
            # Simulate reception of a message
            on_message("testchan", "testmethod", "testprops", "testbody")
            return queue_get_d
        d.addCallback(check_cb)
        # Simulate a ConsumeOk from the server
        frame = Method(1, spec.Basic.ConsumeOk(consumer_tag="testconsumertag"))
        kwargs["callback"](frame)
        return d

    @deferred(timeout=5.0)
    def test_basic_consume_while_closed(self):
        # Verify that a Failure is returned when the channel's basic_consume
        # is called and the channel is closed.
        error = RuntimeError("testing")
        self.channel._on_channel_closed(None, error)
        d = self.channel.basic_consume(queue="testqueue")
        return self.assertFailure(d, RuntimeError)

    @deferred(timeout=5.0)
    def test_basic_consume_failure(self):
        # Verify that a Failure is returned when the channel's basic_consume
        # method fails.
        self.pika_channel.basic_consume.side_effect = RuntimeError()
        d = self.channel.basic_consume(queue="testqueue")
        return self.assertFailure(d, RuntimeError)

    @deferred(timeout=5.0)
    def test_basic_consume_errback_on_close(self):
        # Verify Deferreds that haven't had their callback invoked errback when
        # the channel closes. The @deferred decorator is required: unittest
        # ignores a test's return value, so without it a failed assertFailure
        # check on the returned Deferred would go unnoticed.
        d = self.channel.basic_consume(queue="testqueue")
        self.channel._on_channel_closed(
            None, ChannelClosedByBroker(404, "NOT FOUND"))
        return self.assertFailure(d, ChannelClosedByBroker)

    @deferred(timeout=5.0)
    def test_queue_delete(self):
        # Verify that the consumers are cleared when a queue is deleted.
        queue_obj = mock.Mock()
        self.channel._consumers = {
            "test-delivery-tag": queue_obj,
        }
        self.channel._queue_name_to_consumer_tags["testqueue"] = set([
            "test-delivery-tag"
        ])
        self.channel._calls = set()
        d = self.channel.queue_delete(queue="testqueue")
        self.pika_channel.queue_delete.assert_called_once()
        call_kw = self.pika_channel.queue_delete.call_args_list[0][1]
        self.assertEqual(call_kw["queue"], "testqueue")

        def check(_):
            self.assertEqual(len(self.channel._consumers), 0)
            queue_obj.close.assert_called_once()
            close_call_args = queue_obj.close.call_args_list[0][0]
            self.assertEqual(len(close_call_args), 1)
            self.assertIsInstance(close_call_args[0], ConsumerCancelled)
        d.addCallback(check)
        # Simulate a server response
        self.assertEqual(len(self.channel._calls), 1)
        list(self.channel._calls)[0].callback(None)
        return d

    @deferred(timeout=5.0)
    def test_wrapped_method(self):
        # Verify that the wrapped method is called and the result is properly
        # transmitted via the Deferred.
        d = self.channel.queue_declare(queue="testqueue")
        self.pika_channel.queue_declare.assert_called_once()
        call_kw = self.pika_channel.queue_declare.call_args_list[0][1]
        self.assertIn("queue", call_kw)
        self.assertEqual(call_kw["queue"], "testqueue")
        self.assertIn("callback", call_kw)
        self.assertTrue(callable(call_kw["callback"]))
        call_kw["callback"]("testresult")
        d.addCallback(self.assertEqual, "testresult")
        return d

    @deferred(timeout=5.0)
    def test_wrapped_method_while_closed(self):
        # Verify that a Failure is returned when one of the channel's wrapped
        # methods is called and the channel is closed.
        error = RuntimeError("testing")
        self.channel._on_channel_closed(None, error)
        d = self.channel.queue_declare(queue="testqueue")
        return self.assertFailure(d, RuntimeError)

    @deferred(timeout=5.0)
    def test_wrapped_method_multiple_args(self):
        # Verify that multiple arguments to the callback are properly converted
        # to a tuple for the Deferred's result.
        d = self.channel.queue_declare(queue="testqueue")
        call_kw = self.pika_channel.queue_declare.call_args_list[0][1]
        call_kw["callback"]("testresult-1", "testresult-2")
        d.addCallback(self.assertEqual, ("testresult-1", "testresult-2"))
        return d

    @deferred(timeout=5.0)
    def test_wrapped_method_failure(self):
        # Verify that exceptions are properly handled in wrapped methods.
        error = RuntimeError("testing")
        self.pika_channel.queue_declare.side_effect = error
        d = self.channel.queue_declare(queue="testqueue")
        return self.assertFailure(d, RuntimeError)

    def test_method_not_wrapped(self):
        # Test that only methods that can be wrapped are wrapped.
        result = self.channel.basic_ack()
        self.assertNotIsInstance(result, defer.Deferred)
        self.pika_channel.basic_ack.assert_called_once()

    def test_passthrough(self):
        # Check the simple attribute passthroughs
        attributes = (
            "channel_number", "connection", "is_closed", "is_closing",
            "is_open", "flow_active", "consumer_tags",
        )
        for name in attributes:
            value = "testvalue-{}".format(name)
            setattr(self.pika_channel, name, value)
            self.assertEqual(getattr(self.channel, name), value)

    def test_callback_deferred(self):
        # Check that the deferred will be called back.
        d = defer.Deferred()
        replies = [spec.Basic.CancelOk]
        self.channel.callback_deferred(d, replies)
        self.pika_channel.add_callback.assert_called_with(
            d.callback, replies)

    def test_add_on_return_callback(self):
        # Check that the user callback receives the wrapping TwistedChannel
        # and the returned message as a single tuple argument.
        cb = mock.Mock()
        self.channel.add_on_return_callback(cb)
        self.pika_channel.add_on_return_callback.assert_called_once()
        self.pika_channel.add_on_return_callback.call_args[0][0](
            "testchannel", "testmethod", "testprops", "testbody")
        cb.assert_called_once()
        self.assertEqual(len(cb.call_args[0]), 1)
        self.assertEqual(
            cb.call_args[0][0],
            (self.channel, "testmethod", "testprops", "testbody")
        )

    @deferred(timeout=5.0)
    def test_basic_cancel(self):
        # Verify that basic_cancels calls clean up the consumer queue.
        queue_obj = mock.Mock()
        queue_obj_2 = mock.Mock()
        self.channel._consumers["test-consumer"] = queue_obj
        self.channel._consumers["test-consumer-2"] = queue_obj_2
        self.channel._queue_name_to_consumer_tags.update({
            "testqueue": set(["test-consumer"]),
            "testqueue-2": set(["test-consumer-2"]),
        })
        d = self.channel.basic_cancel("test-consumer")

        def check(result):
            self.assertIsInstance(result, Method)
            queue_obj.close.assert_called_once()
            self.assertIsInstance(
                queue_obj.close.call_args[0][0], ConsumerCancelled)
            self.assertEqual(len(self.channel._consumers), 1)
            queue_obj_2.close.assert_not_called()
            self.assertEqual(
                self.channel._queue_name_to_consumer_tags["testqueue"],
                set())
        d.addCallback(check)
        self.pika_channel.basic_cancel.assert_called_once()
        self.pika_channel.basic_cancel.call_args[1]["callback"](
            Method(1, spec.Basic.CancelOk(consumer_tag="test-consumer"))
        )
        return d

    @deferred(timeout=5.0)
    def test_basic_cancel_no_consumer(self):
        # Verify that basic_cancel does not crash if there is no consumer.
        d = self.channel.basic_cancel("test-consumer")

        def check(result):
            self.assertIsInstance(result, Method)
        d.addCallback(check)
        self.pika_channel.basic_cancel.assert_called_once()
        self.pika_channel.basic_cancel.call_args[1]["callback"](
            Method(1, spec.Basic.CancelOk(consumer_tag="test-consumer"))
        )
        return d

    def test_consumer_cancelled_by_broker(self):
        # Verify that server-originating cancels are handled.
        self.pika_channel.add_on_cancel_callback.assert_called_with(
            self.channel._on_consumer_cancelled_by_broker)
        queue_obj = mock.Mock()
        self.channel._consumers["test-consumer"] = queue_obj
        self.channel._queue_name_to_consumer_tags["testqueue"] = set([
            "test-consumer"])
        self.channel._on_consumer_cancelled_by_broker(
            Method(1, spec.Basic.Cancel(consumer_tag="test-consumer"))
        )
        queue_obj.close.assert_called_once()
        self.assertIsInstance(
            queue_obj.close.call_args[0][0], ConsumerCancelled)
        self.assertEqual(self.channel._consumers, {})
        self.assertEqual(
            self.channel._queue_name_to_consumer_tags["testqueue"],
            set())

    @deferred(timeout=5.0)
    def test_basic_get(self):
        # Verify that the basic_get method works properly.
        d = self.channel.basic_get(queue="testqueue")
        self.pika_channel.basic_get.assert_called_once()
        kwargs = self.pika_channel.basic_get.call_args_list[0][1]
        self.assertEqual(kwargs["queue"], "testqueue")

        def check_cb(result):
            self.assertEqual(
                result,
                (self.channel, "testmethod", "testprops", "testbody")
            )
        d.addCallback(check_cb)
        # Simulate reception of a message
        kwargs["callback"](
            "testchannel", "testmethod", "testprops", "testbody")
        return d

    def test_basic_get_twice(self):
        # Verify that the basic_get method raises the proper exception when
        # called twice.
        self.channel.basic_get(queue="testqueue")
        self.assertRaises(
            DuplicateGetOkCallback, self.channel.basic_get, "testqueue")

    @deferred(timeout=5.0)
    def test_basic_get_empty(self):
        # Verify that the basic_get method works when the queue is empty.
        self.pika_channel.add_callback.assert_called_with(
            self.channel._on_getempty, [spec.Basic.GetEmpty], False)
        d = self.channel.basic_get(queue="testqueue")
        self.channel._on_getempty("testmethod")
        d.addCallback(self.assertIsNone)
        return d

    def test_basic_nack(self):
        # Verify that basic_nack is transmitted properly.
        self.channel.basic_nack("testdeliverytag")
        self.pika_channel.basic_nack.assert_called_once_with(
            delivery_tag="testdeliverytag",
            multiple=False, requeue=True)

    @deferred(timeout=5.0)
    def test_basic_publish(self):
        # Verify that basic_publish wraps properly.
        args = [object()]
        kwargs = {"routing_key": object(), "body": object()}
        d = self.channel.basic_publish(*args, **kwargs)
        kwargs.update(dict(
            # Args are converted to kwargs
            exchange=args[0],
            # Defaults
            mandatory=False, properties=None,
        ))
        self.pika_channel.basic_publish.assert_called_once_with(
            **kwargs)
        return d

    @deferred(timeout=5.0)
    def test_basic_publish_closed(self):
        # Verify that a Failure is returned when the channel's basic_publish
        # is called and the channel is closed.
        self.channel._on_channel_closed(None, RuntimeError("testing"))
        d = self.channel.basic_publish(None, None, None)
        self.pika_channel.basic_publish.assert_not_called()
        d = self.assertFailure(d, RuntimeError)
        d.addCallback(lambda e: self.assertEqual(e.args[0], "testing"))
        return d

    def _test_wrapped_func(self, func, kwargs, do_callback=False):
        """Assert the mocked *func* was called once with *kwargs*.

        :param func: the mocked pika channel method.
        :param kwargs: expected keyword arguments, excluding ``callback``.
        :param do_callback: when truthy, the captured ``callback`` keyword is
            invoked with this value to simulate the broker reply.
        """
        func.assert_called_once()
        call_kw = dict(
            (key, value) for key, value in
            func.call_args[1].items()
            if key != "callback"
        )
        self.assertEqual(kwargs, call_kw)
        if do_callback:
            func.call_args[1]["callback"](do_callback)

    @deferred(timeout=5.0)
    def test_basic_qos(self):
        # Verify that basic_qos wraps properly.
        kwargs = {"prefetch_size": 2}
        d = self.channel.basic_qos(**kwargs)
        # Defaults
        kwargs.update(dict(prefetch_count=0, global_qos=False))
        self._test_wrapped_func(self.pika_channel.basic_qos, kwargs, True)
        return d

    def test_basic_reject(self):
        # Verify that basic_reject is transmitted properly.
        self.channel.basic_reject("testdeliverytag")
        self.pika_channel.basic_reject.assert_called_once_with(
            delivery_tag="testdeliverytag", requeue=True)

    @deferred(timeout=5.0)
    def test_basic_recover(self):
        # Verify that basic_recover wraps properly.
        d = self.channel.basic_recover()
        self._test_wrapped_func(
            self.pika_channel.basic_recover, {"requeue": False}, True)
        return d

    def test_close(self):
        # Verify that close wraps properly.
        self.channel.close()
        self.pika_channel.close.assert_called_once_with(
            reply_code=0, reply_text="Normal shutdown")

    @deferred(timeout=5.0)
    def test_confirm_delivery(self):
        # Verify that confirm_delivery works
        d = self.channel.confirm_delivery()
        self.pika_channel.confirm_delivery.assert_called_once()
        self.assertEqual(
            self.pika_channel.confirm_delivery.call_args[1][
                "ack_nack_callback"],
            self.channel._on_delivery_confirmation)

        def send_message(_result):
            d = self.channel.basic_publish("testexch", "testrk", "testbody")
            frame = Method(1, spec.Basic.Ack(delivery_tag=1))
            self.channel._on_delivery_confirmation(frame)
            return d

        def check_response(frame_method):
            self.assertIsInstance(frame_method, spec.Basic.Ack)
        d.addCallback(send_message)
        d.addCallback(check_response)
        # Simulate Confirm.SelectOk
        self.pika_channel.confirm_delivery.call_args[1]["callback"](None)
        return d

    @deferred(timeout=5.0)
    def test_confirm_delivery_nacked(self):
        # Verify that messages can be nacked when delivery
        # confirmation is on.
        d = self.channel.confirm_delivery()

        def send_message(_result):
            d = self.channel.basic_publish("testexch", "testrk", "testbody")
            frame = Method(1, spec.Basic.Nack(delivery_tag=1))
            self.channel._on_delivery_confirmation(frame)
            return d

        def check_response(error):
            self.assertIsInstance(error.value, NackError)
            self.assertEqual(len(error.value.messages), 0)
        d.addCallback(send_message)
        d.addCallbacks(self.fail, check_response)
        # Simulate Confirm.SelectOk
        self.pika_channel.confirm_delivery.call_args[1]["callback"](None)
        return d

    @deferred(timeout=5.0)
    def test_confirm_delivery_returned(self):
        # Verify handling of unroutable messages.
        d = self.channel.confirm_delivery()
        self.pika_channel.add_on_return_callback.assert_called_once()
        return_cb = self.pika_channel.add_on_return_callback.call_args[0][0]

        def send_message(_result):
            d = self.channel.basic_publish("testexch", "testrk", "testbody")
            # Send the Basic.Return frame
            method = spec.Basic.Return(
                exchange="testexch", routing_key="testrk")
            return_cb(self.channel, method,
                      spec.BasicProperties(), "testbody")
            # Send the Basic.Ack frame
            frame = Method(1, spec.Basic.Ack(delivery_tag=1))
            self.channel._on_delivery_confirmation(frame)
            return d

        def check_response(error):
            self.assertIsInstance(error.value, UnroutableError)
            self.assertEqual(len(error.value.messages), 1)
            msg = error.value.messages[0]
            self.assertEqual(msg.body, "testbody")
        d.addCallbacks(send_message, self.fail)
        d.addCallbacks(self.fail, check_response)
        # Simulate Confirm.SelectOk
        self.pika_channel.confirm_delivery.call_args[1]["callback"](None)
        return d

    @deferred(timeout=5.0)
    def test_confirm_delivery_returned_nacked(self):
        # Verify that a message which is both returned and nacked errbacks
        # with a NackError carrying the returned message.
        d = self.channel.confirm_delivery()
        self.pika_channel.add_on_return_callback.assert_called_once()
        return_cb = self.pika_channel.add_on_return_callback.call_args[0][0]

        def send_message(_result):
            d = self.channel.basic_publish("testexch", "testrk", "testbody")
            # Send the Basic.Return frame
            method = spec.Basic.Return(
                exchange="testexch", routing_key="testrk")
            return_cb(self.channel, method,
                      spec.BasicProperties(), "testbody")
            # Send the Basic.Nack frame
            frame = Method(1, spec.Basic.Nack(delivery_tag=1))
            self.channel._on_delivery_confirmation(frame)
            return d

        def check_response(error):
            self.assertIsInstance(error.value, NackError)
            self.assertEqual(len(error.value.messages), 1)
            msg = error.value.messages[0]
            self.assertEqual(msg.body, "testbody")
        d.addCallback(send_message)
        d.addCallbacks(self.fail, check_response)
        # Simulate Confirm.SelectOk
        self.pika_channel.confirm_delivery.call_args[1]["callback"](None)
        return d

    @deferred(timeout=5.0)
    def test_confirm_delivery_multiple(self):
        # Verify that multiple messages can be acked at once when
        # delivery confirmation is on.
        d = self.channel.confirm_delivery()

        def send_message(_result):
            d1 = self.channel.basic_publish("testexch", "testrk", "testbody1")
            d2 = self.channel.basic_publish("testexch", "testrk", "testbody2")
            frame = Method(1, spec.Basic.Ack(delivery_tag=2, multiple=True))
            self.channel._on_delivery_confirmation(frame)
            return defer.DeferredList([d1, d2])

        def check_response(results):
            # assertEqual, not assertTrue: assertTrue(len(results), 2) treats
            # 2 as the failure message and passes for any non-empty result.
            self.assertEqual(len(results), 2)
            for is_ok, result in results:
                self.assertTrue(is_ok)
                self.assertIsInstance(result, spec.Basic.Ack)
        d.addCallback(send_message)
        d.addCallback(check_response)
        # Simulate Confirm.SelectOk
        self.pika_channel.confirm_delivery.call_args[1]["callback"](None)
        return d
612
613
class TwistedProtocolConnectionTestCase(TestCase):
    """Tests for TwistedProtocolConnection with a mocked pika implementation."""

    def setUp(self):
        # The underlying connection implementation is replaced by a mock so
        # the tests can observe which calls the protocol forwards to it.
        self.conn = TwistedProtocolConnection()
        self.conn._impl = mock.Mock()

    @deferred(timeout=5.0)
    def test_connection(self):
        # makeConnection() must hand the transport to the implementation and
        # invoke connectionMade(); "ready" then resolves on connection ready.
        fake_transport = mock.Mock()
        self.conn.connectionMade = mock.Mock()
        self.conn.makeConnection(fake_transport)
        self.conn._impl.connection_made.assert_called_once_with(
            fake_transport)
        self.conn.connectionMade.assert_called_once()
        ready = self.conn.ready
        self.conn._on_connection_ready(None)
        return ready

    @deferred(timeout=5.0)
    def test_channel(self):
        # Requesting a channel must yield a TwistedChannel wrapper.
        pika_channel = mock.Mock()

        def open_channel(channel_number, on_open):
            on_open(pika_channel)

        self.conn._impl.channel.side_effect = open_channel
        pending_channel = self.conn.channel()
        self.conn._impl.channel.assert_called_once()
        pending_channel.addCallback(
            lambda result: self.assertIsInstance(result, TwistedChannel))
        return pending_channel

    def test_dataReceived(self):
        # Incoming bytes are forwarded verbatim to the implementation.
        payload = "testdata"
        self.conn.dataReceived(payload)
        self.conn._impl.data_received.assert_called_once_with(payload)

    @deferred(timeout=5.0)
    def test_connectionLost(self):
        # connectionLost() notifies the implementation, clears "ready" and
        # errbacks the Deferred that "ready" previously held.
        pending_ready = self.conn.ready
        reason = RuntimeError("testreason")
        self.conn.connectionLost(reason)
        self.conn._impl.connection_lost.assert_called_with(reason)
        self.assertIsNone(self.conn.ready)
        return self.assertFailure(pending_ready, RuntimeError)

    def test_connectionLost_twice(self):
        # A repeated connectionLost() must not raise AlreadyCalled.
        pending_ready = self.conn.ready
        reason = RuntimeError("testreason")
        self.conn.connectionLost(reason)
        self.assertTrue(pending_ready.called)
        # Consume the failure so it is not reported as unhandled.
        pending_ready.addErrback(lambda _failure: None)
        self.assertIsNone(self.conn.ready)
        self.conn.connectionLost(reason)

    @deferred(timeout=5.0)
    def test_on_connection_ready(self):
        # _on_connection_ready() fires "ready" with None.
        ready = self.conn.ready
        self.conn._on_connection_ready("testresult")
        self.assertTrue(ready.called)
        ready.addCallback(self.assertIsNone)
        return ready

    def test_on_connection_ready_twice(self):
        # A repeated _on_connection_ready() must not raise AlreadyCalled.
        ready = self.conn.ready
        self.conn._on_connection_ready("testresult")
        self.assertTrue(ready.called)
        self.conn._on_connection_ready("testresult")

    @deferred(timeout=5.0)
    def test_on_connection_ready_method(self):
        # Resolving "ready" triggers the connectionReady() hook.
        ready = self.conn.ready
        self.conn.connectionReady = mock.Mock()
        self.conn._on_connection_ready("testresult")
        self.conn.connectionReady.assert_called_once()
        return ready

    @deferred(timeout=5.0)
    def test_on_connection_failed(self):
        # _on_connection_failed() errbacks "ready" with AMQPConnectionError.
        ready = self.conn.ready
        self.conn._on_connection_failed(None)
        return self.assertFailure(ready, AMQPConnectionError)

    def test_on_connection_failed_twice(self):
        # A repeated _on_connection_failed() must not raise AlreadyCalled.
        ready = self.conn.ready
        self.conn._on_connection_failed(None)
        self.assertTrue(ready.called)
        # Consume the failure so it is not reported as unhandled.
        ready.addErrback(lambda _failure: None)
        self.conn._on_connection_failed(None)

    @deferred(timeout=5.0)
    def test_on_connection_closed(self):
        # Once ready, _on_connection_closed() fires "closed" with the reason.
        self.conn._on_connection_ready("dummy")
        closed = self.conn.closed
        self.conn._on_connection_closed("test conn", "test reason")
        self.assertTrue(closed.called)
        closed.addCallback(self.assertEqual, "test reason")
        return closed

    def test_on_connection_closed_twice(self):
        # A repeated _on_connection_closed() must not raise AlreadyCalled.
        self.conn._on_connection_ready("dummy")
        closed = self.conn.closed
        self.conn._on_connection_closed("test conn", "test reason")
        self.assertTrue(closed.called)
        self.conn._on_connection_closed("test conn", "test reason")

    @deferred(timeout=5.0)
    def test_on_connection_closed_Failure(self):
        # A Failure passed as the close reason is delivered through the
        # callback path (as the wrapped exception), never the errback path.
        self.conn._on_connection_ready("dummy")
        error = RuntimeError()
        closed = self.conn.closed
        self.conn._on_connection_closed("test conn", Failure(error))
        self.assertTrue(closed.called)

        def on_result(result):
            self.assertEqual(result, error)

        def on_error(_failure):
            self.fail("The errback path should not have been triggered")

        closed.addCallbacks(on_result, on_error)
        return closed

    def test_close(self):
        # close() forwards a normal shutdown and returns the "closed" value.
        self.conn._impl.is_closed = False
        self.conn.closed = "TESTING"
        self.assertEqual(self.conn.close(), "TESTING")
        self.conn._impl.close.assert_called_once_with(200, "Normal shutdown")

    def test_close_twice(self):
        # close() is not forwarded when the implementation is already closed.
        self.conn._impl.is_closed = True
        self.conn.close()
        self.conn._impl.close.assert_not_called()
772
773
class TwistedConnectionAdapterTestCase(TestCase):
    # Unit tests for _TwistedConnectionAdapter, covering its transport
    # callbacks (connection_made / connection_lost / data_received) and its
    # reactor-backed timer and thread-safe callback hooks.

    def setUp(self):
        self.conn = _TwistedConnectionAdapter(
            None, None, None, None, None
        )

    def tearDown(self):
        # Give close() a transport to act on when the test left none
        # attached (never connected, or connection_lost cleared it).
        if self.conn._transport is None:
            self.conn._transport = mock.Mock()
        self.conn.close()

    def test_adapter_disconnect_stream(self):
        # Verify that the underlying transport is closed via
        # loseConnection().
        transport = mock.Mock()
        self.conn.connection_made(transport)
        self.conn._adapter_disconnect_stream()
        transport.loseConnection.assert_called_once()

    def test_adapter_emit_data(self):
        # Verify that the data is transmitted to the underlying transport.
        transport = mock.Mock()
        self.conn.connection_made(transport)
        self.conn._adapter_emit_data("testdata")
        transport.write.assert_called_with("testdata")

    def test_timeout(self):
        # Verify that timeouts are registered and cancelled properly.
        # The reactor is shared by every test in the run, so compare with the
        # number of delayed calls pending beforehand. Asserting absolute
        # counts would break if anything else has a call scheduled.
        callback = mock.Mock()
        pending_before = len(reactor.getDelayedCalls())
        timer_id = self.conn._adapter_call_later(5, callback)
        self.assertEqual(len(reactor.getDelayedCalls()), pending_before + 1)
        self.conn._adapter_remove_timeout(timer_id)
        self.assertEqual(len(reactor.getDelayedCalls()), pending_before)
        callback.assert_not_called()

    @deferred(timeout=5.0)
    def test_call_threadsafe(self):
        # Verify that a callback handed to _adapter_add_callback_threadsafe
        # is run by the reactor shortly afterwards.
        callback = mock.Mock()
        self.conn._adapter_add_callback_threadsafe(callback)
        d = defer.Deferred()

        def check():
            # An exception escaping a callLater callback is only logged by
            # the reactor. The Deferred would then never fire, and the test
            # would fail with a generic timeout. Route the assertion failure
            # into the Deferred so the real cause is reported.
            try:
                callback.assert_called_once()
            except AssertionError:
                d.errback()
            else:
                d.callback(None)

        # Give time to run the callFromThread call
        reactor.callLater(0.1, check)
        return d

    def test_connection_made(self):
        # Verify the connection callback
        transport = mock.Mock()
        self.conn.connection_made(transport)
        self.assertEqual(self.conn._transport, transport)
        self.assertEqual(
            self.conn.connection_state, self.conn.CONNECTION_PROTOCOL)

    def test_connection_lost(self):
        # Verify that the correct callback is called and that the
        # attributes are reinitialized.
        self.conn._on_stream_terminated = mock.Mock()
        error = Failure(RuntimeError("testreason"))
        self.conn.connection_lost(error)
        self.conn._on_stream_terminated.assert_called_with(error.value)
        self.assertIsNone(self.conn._transport)

    def test_connection_lost_connectiondone(self):
        # When the ConnectionDone is transmitted, consider it an expected
        # disconnection.
        self.conn._on_stream_terminated = mock.Mock()
        error = Failure(twisted_error.ConnectionDone())
        self.conn.connection_lost(error)
        self.assertEqual(self.conn._error, error.value)
        self.conn._on_stream_terminated.assert_called_with(None)
        self.assertIsNone(self.conn._transport)

    def test_data_received(self):
        # Verify that the received data is forwarded to the Connection.
        data = b"test data"
        self.conn._on_data_available = mock.Mock()
        self.conn.data_received(data)
        self.conn._on_data_available.assert_called_once_with(data)
857
858
class TimerHandleTestCase(TestCase):
    # Unit tests for _TimerHandle. A mock stands in for the wrapped handle
    # whose cancel() it forwards to.

    def setUp(self):
        self.handle = mock.Mock()
        self.timer = _TimerHandle(self.handle)

    def _cancel_despite(self, exc):
        # Make the wrapped handle's cancel() raise `exc`. Cancelling the timer
        # must not propagate it, and the handle must be asked exactly once.
        self.handle.cancel.side_effect = exc
        self.timer.cancel()
        self.handle.cancel.assert_called_once()

    def test_cancel(self):
        # cancel() forwards to the wrapped handle and then drops it.
        self.timer.cancel()
        self.handle.cancel.assert_called_once()
        self.assertIsNone(self.timer._handle)

    def test_cancel_twice(self):
        # Cancelling the same timer twice in a row must not raise.
        for _ in range(2):
            self.timer.cancel()

    def test_cancel_already_called(self):
        # AlreadyCalled raised by the underlying handle is tolerated.
        self._cancel_despite(twisted_error.AlreadyCalled())

    def test_cancel_already_cancelled(self):
        # AlreadyCancelled raised by the underlying handle is tolerated.
        self._cancel_despite(twisted_error.AlreadyCancelled())
887