1from __future__ import absolute_import, unicode_literals
2
3import pytest
4from case import Mock, patch
5
6from celery import chord, group
7from celery._state import _task_stack
8from celery.backends.rpc import RPCBackend
9
10
class test_RPCResultConsumer:
    """Tests for the ``result_consumer`` exposed by :class:`RPCBackend`."""

    def get_backend(self):
        """Build a new RPC backend bound to ``self.app``."""
        backend = RPCBackend(app=self.app)
        return backend

    def get_consumer(self):
        """Return the result consumer of a freshly built backend."""
        backend = self.get_backend()
        return backend.result_consumer

    def test_drain_events_before_start(self):
        # The consumer has not been started yet; draining events with a
        # tiny timeout must still complete without raising.
        timeout = 0.001
        self.get_consumer().drain_events(timeout)
22
23
class test_RPCBackend:
    """Unit tests for :class:`celery.backends.rpc.RPCBackend`."""

    def setup(self):
        # Fresh backend for every test. ``self.app`` is not assigned in this
        # class; presumably it is injected by the test harness.
        self.b = RPCBackend(app=self.app)

    def test_oid(self):
        # ``oid`` must be stable across accesses and equal to the app's oid.
        oid = self.b.oid
        oid2 = self.b.oid
        assert oid == oid2
        assert oid == self.app.oid

    def test_interface(self):
        # Smoke test: on_reply_declare accepts a task id without raising.
        self.b.on_reply_declare('task_id')

    def test_ensure_chords_allowed(self):
        # The RPC backend reports chords as unsupported.
        with pytest.raises(NotImplementedError):
            self.b.ensure_chords_allowed()

    def test_apply_chord(self):
        # Applying a chord directly on the backend is unsupported.
        with pytest.raises(NotImplementedError):
            self.b.apply_chord(self.app.GroupResult(), None)

    @pytest.mark.celery(result_backend='rpc')
    def test_chord_raises_error(self):
        # Calling a chord on an app configured with the rpc result backend
        # must raise NotImplementedError.
        with pytest.raises(NotImplementedError):
            chord(self.add.s(i, i) for i in range(10))(self.add.s([2]))

    @pytest.mark.celery(result_backend='rpc')
    def test_chain_with_chord_raises_error(self):
        # NOTE(review): presumably the group followed by another task in this
        # chain is upgraded to a chord, which the rpc backend rejects.
        with pytest.raises(NotImplementedError):
            (self.add.s(2, 2) |
             group(self.add.s(2, 2),
                   self.add.s(5, 6)) | self.add.s()).delay()

    def test_destination_for(self):
        # 1) An explicit request supplies reply_to / correlation_id.
        req = Mock(name='request')
        req.reply_to = 'reply_to'
        req.correlation_id = 'corid'
        assert self.b.destination_for('task_id', req) == ('reply_to', 'corid')

        # 2) With no request, the values come from the task pushed onto
        #    the task stack.
        task = Mock()
        _task_stack.push(task)
        try:
            task.request.reply_to = 'reply_to'
            task.request.correlation_id = 'corid'
            assert self.b.destination_for('task_id', None) == (
                'reply_to', 'corid',
            )
        finally:
            # Always pop so a failing assertion does not leave the mock task
            # on the stack for later tests.
            _task_stack.pop()

        # 3) With no request and the mock task popped, a RuntimeError is
        #    expected.
        with pytest.raises(RuntimeError):
            self.b.destination_for('task_id', None)

    def test_binding(self):
        # The reply binding is a non-durable, auto-deleting queue named and
        # routed by the backend's oid on the backend's exchange.
        queue = self.b.binding
        assert queue.name == self.b.oid
        assert queue.exchange == self.b.exchange
        assert queue.routing_key == self.b.oid
        assert not queue.durable
        assert queue.auto_delete

    def test_create_binding(self):
        # Every task id maps to the same shared binding.
        assert self.b._create_binding('id') == self.b.binding

    def test_on_task_call(self):
        # on_task_call must declare the binding (bound to the producer's
        # channel) through maybe_declare, with retry enabled.
        with patch('celery.backends.rpc.maybe_declare') as md:
            with self.app.amqp.producer_pool.acquire() as prod:
                self.b.on_task_call(prod, 'task_id')
                md.assert_called_with(
                    self.b.binding(prod.channel),
                    retry=True,
                )

    def test_create_exchange(self):
        # _create_exchange ignores the requested name and returns an
        # exchange whose name is the empty string.
        ex = self.b._create_exchange('name')
        assert isinstance(ex, self.b.Exchange)
        assert ex.name == ''
101