"""Unit tests for :class:`celery.backends.rpc.RPCBackend` and its consumer."""
from __future__ import absolute_import, unicode_literals

import pytest
from case import Mock, patch

from celery import chord, group
from celery._state import _task_stack
from celery.backends.rpc import RPCBackend


class test_RPCResultConsumer:
    """Tests for the ``result_consumer`` attached to an RPC backend."""

    # NOTE(review): ``self.app`` is not set in this class; presumably it is
    # injected by a test-harness fixture (conftest) -- confirm.

    def get_backend(self):
        """Return a fresh RPCBackend bound to the test app."""
        return RPCBackend(app=self.app)

    def get_consumer(self):
        """Return the result consumer of a fresh RPCBackend."""
        return self.get_backend().result_consumer

    def test_drain_events_before_start(self):
        """Draining events before the consumer is started must not raise."""
        consumer = self.get_consumer()
        # drain_events shouldn't crash when called before start
        consumer.drain_events(0.001)


class test_RPCBackend:
    """Tests for RPCBackend's identity, chord support and reply routing."""

    # NOTE(review): ``self.app`` and ``self.add`` (a task) are presumably
    # provided by test-harness fixtures; they are not defined here.

    def setup(self):
        """Create the backend under test for every test method."""
        self.b = RPCBackend(app=self.app)

    def test_oid(self):
        """``oid`` is stable across accesses and equals the app's oid."""
        oid = self.b.oid
        oid2 = self.b.oid
        assert oid == oid2
        assert oid == self.app.oid

    def test_interface(self):
        """Smoke test: ``on_reply_declare`` accepts a task id without raising."""
        self.b.on_reply_declare('task_id')

    def test_ensure_chords_allowed(self):
        """The RPC backend rejects chords up front."""
        with pytest.raises(NotImplementedError):
            self.b.ensure_chords_allowed()

    def test_apply_chord(self):
        """Applying a chord directly on the backend is not implemented."""
        with pytest.raises(NotImplementedError):
            self.b.apply_chord(self.app.GroupResult(), None)

    @pytest.mark.celery(result_backend='rpc')
    def test_chord_raises_error(self):
        """Calling a chord with the rpc result backend raises."""
        with pytest.raises(NotImplementedError):
            chord(self.add.s(i, i) for i in range(10))(self.add.s([2]))

    @pytest.mark.celery(result_backend='rpc')
    def test_chain_with_chord_raises_error(self):
        """A chain containing a group followed by a task raises.

        ``group | task`` inside a chain forms an implicit chord, which the
        rpc result backend does not support.
        """
        with pytest.raises(NotImplementedError):
            (self.add.s(2, 2) |
             group(self.add.s(2, 2),
                   self.add.s(5, 6)) | self.add.s()).delay()

    def test_destination_for(self):
        """``destination_for`` resolves ``(reply_to, correlation_id)``.

        It uses the explicit request when one is given. Otherwise it falls
        back to the request of the task on top of ``_task_stack``. It raises
        ``RuntimeError`` once that task has been popped and no request is
        passed.
        """
        req = Mock(name='request')
        req.reply_to = 'reply_to'
        req.correlation_id = 'corid'
        assert self.b.destination_for('task_id', req) == ('reply_to', 'corid')

        task = Mock()
        _task_stack.push(task)
        try:
            task.request.reply_to = 'reply_to'
            task.request.correlation_id = 'corid'
            assert self.b.destination_for('task_id', None) == (
                'reply_to', 'corid',
            )
        finally:
            # Always restore the global task stack, even if the assert fails.
            _task_stack.pop()

        with pytest.raises(RuntimeError):
            self.b.destination_for('task_id', None)

    def test_binding(self):
        """The reply binding is keyed on the backend oid.

        Its name and routing key are the backend oid, it uses the backend's
        exchange, and it is non-durable and auto-delete.
        """
        queue = self.b.binding
        assert queue.name == self.b.oid
        assert queue.exchange == self.b.exchange
        assert queue.routing_key == self.b.oid
        assert not queue.durable
        assert queue.auto_delete

    def test_create_binding(self):
        """``_create_binding`` returns the backend's shared binding for any id."""
        assert self.b._create_binding('id') == self.b.binding

    def test_on_task_call(self):
        """``on_task_call`` declares the binding with ``retry=True``.

        The declaration goes through ``maybe_declare``, with the binding
        bound to the producer's channel.
        """
        with patch('celery.backends.rpc.maybe_declare') as md:
            with self.app.amqp.producer_pool.acquire() as prod:
                # Plain call statement (the original's stray trailing comma
                # built and discarded a one-element tuple).
                self.b.on_task_call(prod, 'task_id')
                md.assert_called_with(
                    self.b.binding(prod.channel),
                    retry=True,
                )

    def test_create_exchange(self):
        """``_create_exchange`` ignores the requested name.

        It returns an instance of the backend's ``Exchange`` class whose name
        is ``''`` (presumably the AMQP default exchange).
        """
        ex = self.b._create_exchange('name')
        assert isinstance(ex, self.b.Exchange)
        assert ex.name == ''