1from __future__ import absolute_import, unicode_literals
2
3import pickle
4from contextlib import contextmanager
5from datetime import timedelta
6from pickle import dumps, loads
7
8import pytest
9from billiard.einfo import ExceptionInfo
10from case import Mock, mock
11
12from celery import states, uuid
13from celery.app.task import Context
14from celery.backends.amqp import AMQPBackend
15from celery.five import Empty, Queue, range
16from celery.result import AsyncResult
17
18
class SomeClass(object):
    """Plain user-defined object used as a non-JSON result payload.

    ``test_is_pickled`` stores an instance through the pickle serializer and
    reads ``.data`` back to check that it survived the round trip.
    """

    def __init__(self, data):
        # Keep the payload exactly as given.
        payload = data
        self.data = payload
23
24
class test_AMQPBackend:
    """Tests for :class:`celery.backends.amqp.AMQPBackend`.

    NOTE(review): ``self.app`` is not defined in this class; it is presumably
    injected by Celery's pytest fixtures (conftest) -- confirm.
    """

    def setup(self):
        self.app.conf.result_cache_max = 100

    def create_backend(self, **opts):
        """Return an ``AMQPBackend`` bound to ``self.app``.

        ``serializer='pickle'`` and ``persistent=True`` are the defaults;
        any keyword in *opts* overrides them and is forwarded to the
        backend constructor.
        """
        opts = dict({'serializer': 'pickle', 'persistent': True}, **opts)
        return AMQPBackend(self.app, **opts)

    def test_destination_for(self):
        b = self.create_backend()
        request = Mock()
        assert b.destination_for('id', request) == (
            b.rkey('id'), request.correlation_id,
        )

    def test_store_result__no_routing_key(self):
        # With no routing key available, store_result must complete
        # without raising.
        b = self.create_backend()
        b.destination_for = Mock()
        b.destination_for.return_value = None, None
        b.store_result('id', None, states.SUCCESS)

    def test_mark_as_done(self):
        # The result written through one backend instance must be visible
        # through a second, independent instance.
        tb1 = self.create_backend(max_cached_results=1)
        tb2 = self.create_backend(max_cached_results=1)

        tid = uuid()

        tb1.mark_as_done(tid, 42)
        assert tb2.get_state(tid) == states.SUCCESS
        assert tb2.get_result(tid) == 42
        assert tb2._cache.get(tid)
        # A repeated read must return the same value; compare explicitly
        # (``assert x, 42`` would only check that x is truthy).
        assert tb2.get_result(tid) == 42

    @pytest.mark.usefixtures('depends_on_current_app')
    def test_pickleable(self):
        assert loads(dumps(self.create_backend()))

    def test_revive(self):
        tb = self.create_backend()
        tb.revive(None)

    def test_is_pickled(self):
        tb1 = self.create_backend()
        tb2 = self.create_backend()

        tid2 = uuid()
        result = {'foo': 'baz', 'bar': SomeClass(12345)}
        tb1.mark_as_done(tid2, result)
        # is serialized properly.
        rindb = tb2.get_result(tid2)
        assert rindb.get('foo') == 'baz'
        assert rindb.get('bar').data == 12345

    def test_mark_as_failure(self):
        tb1 = self.create_backend()
        tb2 = self.create_backend()

        tid3 = uuid()
        try:
            raise KeyError('foo')
        except KeyError as exception:
            # Built inside the handler; presumably captures the exception
            # currently being handled.
            einfo = ExceptionInfo()
            tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
            assert tb2.get_state(tid3) == states.FAILURE
            assert isinstance(tb2.get_result(tid3), KeyError)
            assert tb2.get_traceback(tid3) == einfo.traceback

    def test_repair_uuid(self):
        from celery.backends.amqp import repair_uuid
        for _ in range(10):
            tid = uuid()
            assert repair_uuid(tid.replace('-', '')) == tid

    def test_expires_is_int(self):
        b = self.create_backend(expires=48)
        q = b._create_binding('x1y2z3')
        assert q.expires == 48

    def test_expires_is_float(self):
        b = self.create_backend(expires=48.3)
        q = b._create_binding('x1y2z3')
        assert q.expires == 48.3

    def test_expires_is_timedelta(self):
        b = self.create_backend(expires=timedelta(minutes=1))
        q = b._create_binding('x1y2z3')
        assert q.expires == 60

    @mock.sleepdeprived()
    def test_store_result_retries(self):
        iterations = [0]
        stop_raising_at = [5]

        def publish(*args, **kwargs):
            # Stand-in for Producer.publish: raises KeyError on each of the
            # first ``stop_raising_at[0] + 1`` calls, then returns None.
            if iterations[0] > stop_raising_at[0]:
                return
            iterations[0] += 1
            raise KeyError('foo')

        backend = AMQPBackend(self.app)
        from celery.app.amqp import Producer
        prod, Producer.publish = Producer.publish, publish
        try:
            # Configure the policy outside ``pytest.raises`` so that only
            # store_result itself can satisfy the expected KeyError.
            backend.retry_policy['max_retries'] = None
            with pytest.raises(KeyError):
                backend.store_result('foo', 'bar', 'STARTED')

            backend.retry_policy['max_retries'] = 10
            with pytest.raises(KeyError):
                backend.store_result('foo', 'bar', 'STARTED')
        finally:
            # Always restore the real publish method for other tests.
            Producer.publish = prod

    def test_poll_no_messages(self):
        b = self.create_backend()
        assert b.get_task_meta(uuid())['status'] == states.PENDING

    @contextmanager
    def _result_context(self):
        """Yield ``(results, backend, Message)`` for polling tests.

        *results* is a queue that the yielded backend's bindings read from:
        each ``binding.get()`` pops one queued ``Message`` (or returns None
        when the queue is empty). ``Message`` builds a pickled state
        payload and counts ``ack``/``requeue`` calls.
        """
        results = Queue()

        class Message(object):
            acked = 0
            requeued = 0

            def __init__(self, **merge):
                self.payload = dict({'status': states.STARTED,
                                     'result': None}, **merge)
                self.properties = {'correlation_id': merge.get('task_id')}
                self.body = pickle.dumps(self.payload)
                self.content_type = 'application/x-python-serialize'
                self.content_encoding = 'binary'

            def ack(self, *args, **kwargs):
                self.acked += 1

            def requeue(self, *args, **kwargs):
                self.requeued += 1

        class MockBinding(object):

            def __init__(self, *args, **kwargs):
                self.channel = Mock()

            def __call__(self, *args, **kwargs):
                # Binding to a channel just returns the same mock binding.
                return self

            def declare(self):
                pass

            def get(self, no_ack=False, accept=None):
                try:
                    m = results.get(block=False)
                except Empty:
                    # No queued message: behave like an empty AMQP queue.
                    return None
                if m:
                    m.accept = accept
                return m

            def is_bound(self):
                return True

        class MockBackend(AMQPBackend):
            Queue = MockBinding

        backend = MockBackend(self.app, max_cached_results=100)
        backend._republish = Mock()

        yield results, backend, Message

    def test_backlog_limit_exceeded(self):
        with self._result_context() as (results, backend, Message):
            # NOTE(review): 1001 is presumably one more than the default
            # backlog limit of get_task_meta -- confirm.
            for _ in range(1001):
                results.put(Message(task_id='id', status=states.RECEIVED))
            with pytest.raises(backend.BacklogLimitExceeded):
                backend.get_task_meta('id')

    def test_poll_result(self):
        with self._result_context() as (results, backend, Message):
            tid = uuid()
            # FFWD's to the latest state.
            state_messages = [
                Message(task_id=tid, status=states.RECEIVED, seq=1),
                Message(task_id=tid, status=states.STARTED, seq=2),
                Message(task_id=tid, status=states.FAILURE, seq=3),
            ]
            for state_message in state_messages:
                results.put(state_message)
            r1 = backend.get_task_meta(tid)
            # FFWDs to the last state.
            assert r1['status'] == states.FAILURE
            assert r1['seq'] == 3

            # Caches last known state.
            tid = uuid()
            results.put(Message(task_id=tid))
            backend.get_task_meta(tid)
            assert tid in backend._cache, 'Caches last known state'

            assert state_messages[-1].requeued

            # Returns cache if no new states.
            results.queue.clear()
            assert not results.qsize()
            backend._cache[tid] = 'hello'
            # returns cache if no new states.
            assert backend.get_task_meta(tid) == 'hello'

    def test_drain_events_decodes_exceptions_in_meta(self):
        tid = uuid()
        b = self.create_backend(serializer='json')
        b.store_result(tid, RuntimeError('aap'), states.FAILURE)
        result = AsyncResult(tid, backend=b)

        with pytest.raises(Exception) as excinfo:
            result.get()

        assert excinfo.value.__class__.__name__ == 'RuntimeError'
        assert str(excinfo.value) == 'aap'

    def test_no_expires(self):
        # Clear the app-level default first; the backend is created only
        # after that, so neither source of an expiry applies.
        self.app.conf.result_expires = None
        b = self.create_backend(expires=None)
        q = b._create_binding('foo')
        assert q.expires is None

    def test_process_cleanup(self):
        self.create_backend().process_cleanup()

    def test_reload_task_result(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().reload_task_result('x')

    def test_reload_group_result(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().reload_group_result('x')

    def test_save_group(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().save_group('x', 'x')

    def test_restore_group(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().restore_group('x')

    def test_delete_group(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().delete_group('x')
276
277
class test_AMQPBackend_result_extended:
    """Checks the metadata stored by AMQPBackend with ``result_extended``.

    NOTE(review): ``self.app`` is presumably injected by Celery's pytest
    fixtures -- confirm.
    """

    def setup(self):
        # Enable extended results so request details land in the meta.
        self.app.conf.result_extended = True

    def test_store_result(self):
        task_id = uuid()
        backend = AMQPBackend(self.app)

        ctx = Context(
            args=(1, 2, 3),
            kwargs={'foo': 'bar'},
            task_name='mytask',
            retries=2,
            hostname='celery@worker_1',
            delivery_info={'routing_key': 'celery'},
        )
        backend.store_result(
            task_id, {'fizz': 'buzz'}, states.SUCCESS, request=ctx,
        )

        # Every request field above should be reflected in the stored meta.
        expected = dict(
            args=[1, 2, 3],
            children=[],
            kwargs={'foo': 'bar'},
            name='mytask',
            queue='celery',
            result={'fizz': 'buzz'},
            retries=2,
            status='SUCCESS',
            task_id=task_id,
            traceback=None,
            worker='celery@worker_1',
        )
        assert backend.get_task_meta(task_id) == expected
307