1from __future__ import absolute_import, unicode_literals
2
3import sys
4import types
5from contextlib import contextmanager
6
7import pytest
8from case import ANY, Mock, call, patch, skip, sentinel
9from kombu.serialization import prepare_accept_content
10from kombu.utils.encoding import ensure_bytes
11
12import celery
13from celery import chord, group, signature, states, uuid
14from celery.app.task import Context, Task
15from celery.backends.base import (BaseBackend, DisabledBackend,
16                                  KeyValueStoreBackend, _nulldict)
17from celery.exceptions import ChordError, TimeoutError, BackendStoreError, BackendGetMetaError
18from celery.five import bytes_if_py2, items, range
19from celery.result import result_from_tuple
20from celery.utils import serialization
21from celery.utils.functional import pass1
22from celery.utils.serialization import UnpickleableExceptionWrapper
23from celery.utils.serialization import find_pickleable_exception as fnpe
24from celery.utils.serialization import get_pickleable_exception as gpe
25from celery.utils.serialization import subclass_exception
26
27
class wrapobject(object):
    """Non-exception base class used to build the ``Lookalike`` type.

    Positional arguments are kept on ``args``; keyword arguments are
    accepted and discarded.
    """

    def __init__(self, *positional, **_ignored):
        self.args = positional
32
33
class paramexception(Exception):
    """Exception whose constructor takes exactly one required argument.

    Because ``param`` is mandatory, re-creating this type from an empty
    argument list fails with ``TypeError``.
    """

    def __init__(self, param):
        # Exception.__init__ is not invoked here; only ``param`` is stored.
        self.param = param
38
39
class objectexception(object):
    """Plain (non-exception) container for a nested exception class.

    Nesting gives ``Nested`` the qualified name ``objectexception.Nested``.
    """

    class Nested(Exception):
        """Exception type reachable only through its enclosing class."""
43
44
# Old-style classes are only built on non-PyPy Python 2; on Python 3 or
# PyPy the placeholder is simply ``None``.
if sys.version_info[0] != 3 and not getattr(sys, 'pypy_version_info', None):
    Oldstyle = types.ClassType(bytes_if_py2('Oldstyle'), (), {})
else:
    Oldstyle = None

# Exception types created dynamically under the fake module 'foo.module',
# each deriving from a different parent.
Unpickleable = subclass_exception(bytes_if_py2('Unpickleable'), KeyError,
                                  'foo.module')
Impossible = subclass_exception(bytes_if_py2('Impossible'), object,
                                'foo.module')
Lookalike = subclass_exception(bytes_if_py2('Lookalike'), wrapobject,
                               'foo.module')
58
59
class test_nulldict:
    """``_nulldict`` must accept dict mutations without storing anything."""

    def test_nulldict(self):
        x = _nulldict()
        x['foo'] = 1
        x.update(foo=1, bar=2)
        x.setdefault('foo', 3)
        # Previously the test only checked that these calls do not raise.
        # The writes above must also leave the mapping empty.
        assert not x
67
68
class test_serialization:
    """Smoke tests for ``serialization.create_exception_cls``."""

    def test_create_exception_cls(self):
        # Once with the default parent, once with an explicit KeyError parent.
        for extra in ((), (KeyError,)):
            assert serialization.create_exception_cls('FooError', 'm', *extra)
74
75
class test_Backend_interface:
    """Checks ``accept`` resolution and result-meta construction."""

    def setup(self):
        self.app.conf.accept_content = ['json']

    @staticmethod
    def _assert_accepts_only_yaml(backend):
        """Assert that *backend* accepts YAML and nothing else."""
        assert len(backend.accept) == 1
        assert list(backend.accept)[0] == 'application/x-yaml'
        assert prepare_accept_content(['yaml']) == backend.accept

    def _meta_without_request(self, result):
        """Build SUCCESS meta for *result* with no request attached."""
        backend = BaseBackend(self.app)
        meta = backend._get_result_meta(result=result,
                                        state=states.SUCCESS,
                                        traceback=None,
                                        request=None)
        assert meta['status'] == states.SUCCESS
        assert meta['traceback'] is None
        return meta

    def _check_extended_meta(self, result):
        """Enable result_extended and check request fields reach the meta."""
        self.app.conf.result_extended = True
        task_args = ['a', 'b']
        task_kwargs = {'foo': 'bar'}
        task_name = 'mytask'

        backend = BaseBackend(self.app)
        request = Context(args=task_args, kwargs=task_kwargs,
                          task=task_name,
                          delivery_info={'routing_key': 'celery'})
        meta = backend._get_result_meta(result=result,
                                        state=states.SUCCESS,
                                        traceback=None,
                                        request=request, encode=False)
        assert meta['name'] == task_name
        assert meta['args'] == task_args
        assert meta['kwargs'] == task_kwargs
        assert meta['queue'] == 'celery'

    def test_accept_precedence(self):
        # With no overrides, the backend follows app.conf.accept_content.
        default_accept = self.app.conf.accept_content
        plain = BaseBackend(self.app)
        assert prepare_accept_content(default_accept) == plain.accept

        # An explicit ``accept`` argument is honoured ...
        self._assert_accepts_only_yaml(BaseBackend(self.app, accept=['yaml']))

        # ... even when result_accept_content is configured as well.
        self.app.conf.result_accept_content = ['json']
        self._assert_accepts_only_yaml(BaseBackend(self.app, accept=['yaml']))

        # Without the argument, result_accept_content is used when set.
        self.app.conf.result_accept_content = ['yaml']
        self._assert_accepts_only_yaml(BaseBackend(self.app))

    def test_get_result_meta(self):
        meta = self._meta_without_request({'fizz': 'buzz'})
        assert meta['result'] == {'fizz': 'buzz'}
        self._check_extended_meta({'fizz': 'buzz'})

    def test_get_result_meta_encoded(self):
        self.app.conf.result_extended = True
        backend = BaseBackend(self.app)
        task_args, task_kwargs = ['a', 'b'], {'foo': 'bar'}

        request = Context(args=task_args, kwargs=task_kwargs)
        meta = backend._get_result_meta(result={'fizz': 'buzz'},
                                        state=states.SUCCESS,
                                        traceback=None,
                                        request=request, encode=True)
        # With encode=True the request args/kwargs are stored serialized.
        assert meta['args'] == ensure_bytes(backend.encode(task_args))
        assert meta['kwargs'] == ensure_bytes(backend.encode(task_kwargs))

    def test_get_result_meta_with_none(self):
        meta = self._meta_without_request(None)
        assert meta['result'] is None
        self._check_extended_meta(None)
172
173
class test_BaseBackend_interface:
    """Behaviour of the abstract ``BaseBackend`` API surface."""

    def setup(self):
        self.b = BaseBackend(self.app)

        @self.app.task(shared=False)
        def callback(result):
            pass

        self.callback = callback

    def _make_header_result(self):
        """Return a GroupResult of three AsyncResults under a fresh id."""
        members = [self.app.AsyncResult(x) for x in range(3)]
        return self.app.GroupResult(uuid(), members)

    def test__forget(self):
        with pytest.raises(NotImplementedError):
            self.b._forget('SOMExx-N0Nex1stant-IDxx-')

    def test_forget(self):
        with pytest.raises(NotImplementedError):
            self.b.forget('SOMExx-N0nex1stant-IDxx-')

    def test_on_chord_part_return(self):
        # Must not raise when called with empty arguments.
        self.b.on_chord_part_return(None, None, None)

    def test_apply_chord(self, unlock='celery.chord_unlock'):
        self.app.tasks[unlock] = Mock()
        self.b.apply_chord(self._make_header_result(), self.callback.s())
        assert self.app.tasks[unlock].apply_async.call_count

    def test_chord_unlock_queue(self, unlock='celery.chord_unlock'):
        self.app.tasks[unlock] = Mock()
        header_result = self._make_header_result()
        body = self.callback.s()

        def unlock_queue_for(sig):
            """Apply the chord and return the queue given to chord_unlock."""
            self.b.apply_chord(header_result, sig)
            return self.app.tasks[unlock].apply_async.call_args[1]['queue']

        # No queue configured anywhere: chord_unlock gets None.
        assert unlock_queue_for(body) is None
        # A queue set on the body signature is forwarded.
        assert unlock_queue_for(body.set(queue='test_queue')) == 'test_queue'

        @self.app.task(shared=False, queue='test_queue_two')
        def callback_queue(result):
            pass

        # A queue configured on the callback task itself is forwarded too.
        assert unlock_queue_for(callback_queue.s()) == 'test_queue_two'
228
229
class test_exception_pickle:
    """``find_pickleable_exception`` / ``get_pickleable_exception``."""

    @skip.if_python3(reason='does not support old style classes')
    @skip.if_pypy()
    def test_oldstyle(self):
        assert fnpe(Oldstyle())

    def test_BaseException(self):
        # fnpe returns None for a plain Exception instance.
        assert fnpe(Exception()) is None

    def test_get_pickleable_exception(self):
        original = Exception('foo')
        assert gpe(original) == original

    def test_unpickleable(self):
        # The KeyError-derived type yields a KeyError stand-in ...
        substitute = fnpe(Unpickleable())
        assert isinstance(substitute, KeyError)
        # ... while the object-derived one has none.
        assert fnpe(Impossible()) is None
247
248
class test_prepare_exception:
    """Round-trips exceptions through prepare_exception/exception_to_python."""

    def setup(self):
        self.b = BaseBackend(self.app)

    def test_unpickleable(self):
        self.b.serializer = 'pickle'
        x = self.b.prepare_exception(Unpickleable(1, 2, 'foo'))
        assert isinstance(x, KeyError)
        y = self.b.exception_to_python(x)
        assert isinstance(y, KeyError)

    def test_json_exception_arguments(self):
        self.b.serializer = 'json'
        x = self.b.prepare_exception(Exception(object))
        assert x == {
            'exc_message': serialization.ensure_serializable(
                (object,), self.b.encode),
            'exc_type': Exception.__name__,
            'exc_module': Exception.__module__}
        y = self.b.exception_to_python(x)
        assert isinstance(y, Exception)

    @pytest.mark.skipif(sys.version_info < (3, 3), reason='no qualname support')
    def test_json_exception_nested(self):
        self.b.serializer = 'json'
        x = self.b.prepare_exception(objectexception.Nested('msg'))
        # The nested class is recorded under its qualified name.
        assert x == {
            'exc_message': ('msg',),
            'exc_type': 'objectexception.Nested',
            'exc_module': objectexception.Nested.__module__}
        y = self.b.exception_to_python(x)
        assert isinstance(y, objectexception.Nested)

    def test_impossible(self):
        self.b.serializer = 'pickle'
        x = self.b.prepare_exception(Impossible())
        assert isinstance(x, UnpickleableExceptionWrapper)
        assert str(x)
        y = self.b.exception_to_python(x)
        assert y.__class__.__name__ == 'Impossible'
        # The re-created class keeps the module it was declared under.  The
        # former ``sys.version_info < (2, 5)`` branch was unreachable.
        assert y.__class__.__module__ == 'foo.module'

    def test_regular(self):
        self.b.serializer = 'pickle'
        x = self.b.prepare_exception(KeyError('baz'))
        assert isinstance(x, KeyError)
        y = self.b.exception_to_python(x)
        assert isinstance(y, KeyError)

    def test_unicode_message(self):
        # Pin the serializer, as the sibling tests do, rather than relying
        # on the app's configured default.
        self.b.serializer = 'json'
        message = u'\u03ac'
        x = self.b.prepare_exception(Exception(message))
        assert x == {'exc_message': (message,),
                     'exc_type': Exception.__name__,
                     'exc_module': Exception.__module__}
308
309
class KVBackend(KeyValueStoreBackend):
    """Minimal in-memory KeyValueStoreBackend backed by a plain dict."""

    # When true, ``mget`` returns a key -> value mapping instead of a list.
    mget_returns_dict = False

    def __init__(self, app, *args, **kwargs):
        # Create the backing store before delegating to the base initializer.
        self.db = {}
        super(KVBackend, self).__init__(app, *args, **kwargs)

    def get(self, key):
        # Missing keys read back as None.
        return self.db.get(key)

    def _set_with_state(self, key, value, state):
        # ``state`` is accepted but not used by this stub.
        self.db[key] = value

    def mget(self, keys):
        keys = list(keys)
        values = [self.get(key) for key in keys]
        return dict(zip(keys, values)) if self.mget_returns_dict else values

    def delete(self, key):
        # Deleting an absent key is silently ignored.
        self.db.pop(key, None)
331
332
class DictBackend(BaseBackend):
    """BaseBackend stub with hard-coded group and task lookups."""

    def __init__(self, *args, **kwargs):
        BaseBackend.__init__(self, *args, **kwargs)
        # Group store: only 'can-delete' is present initially.
        self._data = {'can-delete': {'result': 'foo'}}

    def _restore_group(self, group_id):
        # Only the id 'exists' is known; anything else yields None.
        return {'result': 'group'} if group_id == 'exists' else None

    def _get_task_meta_for(self, task_id):
        # Only the id 'task-exists' is known; anything else yields None.
        if task_id != 'task-exists':
            return None
        return {'result': 'task'}

    def _delete_group(self, group_id):
        self._data.pop(group_id, None)
349
350
class test_BaseBackend_dict:
    """BaseBackend behaviour exercised mostly through ``DictBackend``."""

    def setup(self):
        self.b = DictBackend(app=self.app)

        @self.app.task(shared=False, bind=True)
        def bound_errback(self, result):
            pass

        @self.app.task(shared=False)
        def errback(arg1, arg2):
            # Record the sum so tests can observe that the errback ran.
            errback.last_result = arg1 + arg2

        self.bound_errback = bound_errback
        self.errback = errback

    def test_delete_group(self):
        self.b.delete_group('can-delete')
        assert 'can-delete' not in self.b._data

    def test_prepare_exception_json(self):
        x = DictBackend(self.app, serializer='json')
        e = x.prepare_exception(KeyError('foo'))
        assert 'exc_type' in e
        e = x.exception_to_python(e)
        assert e.__class__.__name__ == 'KeyError'
        # NOTE: strip('u') presumably tolerates a u'' prefix on Python 2.
        assert str(e).strip('u') == "'foo'"

    def test_save_group(self):
        b = BaseBackend(self.app)
        b._save_group = Mock()
        b.save_group('foofoo', 'xxx')
        b._save_group.assert_called_with('foofoo', 'xxx')

    def test_add_to_chord_interface(self):
        b = BaseBackend(self.app)
        with pytest.raises(NotImplementedError):
            b.add_to_chord('group_id', 'sig')

    def test_forget_interface(self):
        b = BaseBackend(self.app)
        with pytest.raises(NotImplementedError):
            b.forget('foo')

    def test_restore_group(self):
        # Each id is looked up twice so repeated lookups are covered too.
        assert self.b.restore_group('missing') is None
        assert self.b.restore_group('missing') is None
        assert self.b.restore_group('exists') == 'group'
        assert self.b.restore_group('exists') == 'group'
        assert self.b.restore_group('exists', cache=False) == 'group'

    def test_reload_group_result(self):
        self.b._cache = {}
        self.b.reload_group_result('exists')
        # The reloaded group meta must now be cached.
        assert self.b._cache['exists'] == {'result': 'group'}

    def test_reload_task_result(self):
        self.b._cache = {}
        self.b.reload_task_result('task-exists')
        # The reloaded task meta must now be cached.
        assert self.b._cache['task-exists'] == {'result': 'task'}

    def test_fail_from_current_stack(self):
        import inspect
        self.b.mark_as_failure = Mock()
        frame_list = []

        if (2, 7, 0) <= sys.version_info < (3, 0, 0):
            # Use the auto-restoring ``patching`` helper rather than
            # assigning to ``sys.exc_clear`` directly. A direct assignment
            # would leave the mock in place for every later test.
            self.patching('sys.exc_clear')

        def raise_dummy():
            # Remember this frame so its traceback entry can be found later.
            frame_str_temp = str(inspect.currentframe().__repr__)
            frame_list.append(frame_str_temp)
            raise KeyError('foo')
        try:
            raise_dummy()
        except KeyError as exc:
            self.b.fail_from_current_stack('task_id')
            self.b.mark_as_failure.assert_called()
            args = self.b.mark_as_failure.call_args[0]
            assert args[0] == 'task_id'
            assert args[1] is exc
            assert args[2]

            if sys.version_info >= (3, 5, 0):
                # The frame of raise_dummy must have had its locals cleared.
                tb_ = exc.__traceback__
                while tb_ is not None:
                    if str(tb_.tb_frame.__repr__) == frame_list[0]:
                        assert len(tb_.tb_frame.f_locals) == 0
                    tb_ = tb_.tb_next
            elif (2, 7, 0) <= sys.version_info < (3, 0, 0):
                sys.exc_clear.assert_called()

    def test_prepare_value_serializes_group_result(self):
        self.b.serializer = 'json'
        g = self.app.GroupResult('group_id', [self.app.AsyncResult('foo')])
        v = self.b.prepare_value(g)
        assert isinstance(v, (list, tuple))
        assert result_from_tuple(v, app=self.app) == g

        v2 = self.b.prepare_value(g[0])
        assert isinstance(v2, (list, tuple))
        assert result_from_tuple(v2, app=self.app) == g[0]

        # With pickle the GroupResult is passed through as-is.
        self.b.serializer = 'pickle'
        assert isinstance(self.b.prepare_value(g), self.app.GroupResult)

    def test_is_cached(self):
        b = BaseBackend(app=self.app, max_cached_results=1)
        b._cache['foo'] = 1
        assert b.is_cached('foo')
        assert not b.is_cached('false')

    def test_mark_as_done__chord(self):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()
        request = Mock(name='request')
        b.on_chord_part_return = Mock()
        b.mark_as_done('id', 10, request=request)
        b.on_chord_part_return.assert_called_with(request, states.SUCCESS, 10)

    def test_mark_as_failure__bound_errback_eager(self):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()
        request = Mock(name='request')
        request.delivery_info = {
            'is_eager': True
        }
        request.errbacks = [
            self.bound_errback.subtask(args=[1], immutable=True)]
        exc = KeyError()
        mocked_group = self.patching('celery.backends.base.group')
        b.mark_as_failure('id', exc, request=request)
        mocked_group.assert_called_with(request.errbacks, app=self.app)
        # Eager requests run the errback group synchronously via ``apply``.
        mocked_group.return_value.apply.assert_called_with(
            (request.id, ), parent_id=request.id, root_id=request.root_id)

    def test_mark_as_failure__bound_errback(self):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()
        request = Mock(name='request')
        request.delivery_info = {}
        request.errbacks = [
            self.bound_errback.subtask(args=[1], immutable=True)]
        exc = KeyError()
        mocked_group = self.patching('celery.backends.base.group')
        b.mark_as_failure('id', exc, request=request)
        mocked_group.assert_called_with(request.errbacks, app=self.app)
        # Non-eager requests dispatch the errback group via ``apply_async``.
        mocked_group.return_value.apply_async.assert_called_with(
            (request.id, ), parent_id=request.id, root_id=request.root_id)

    def test_mark_as_failure__errback(self):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()
        request = Mock(name='request')
        request.errbacks = [self.errback.subtask(args=[2, 3], immutable=True)]
        exc = KeyError()
        b.mark_as_failure('id', exc, request=request)
        assert self.errback.last_result == 5

    @patch('celery.backends.base.group')
    def test_class_based_task_can_be_used_as_error_callback(self, mock_group):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()

        class TaskBasedClass(Task):
            def run(self):
                pass

        TaskBasedClass = self.app.register_task(TaskBasedClass())

        request = Mock(name='request')
        request.errbacks = [TaskBasedClass.subtask(args=[], immutable=True)]
        exc = KeyError()
        b.mark_as_failure('id', exc, request=request)
        mock_group.assert_called_once_with(request.errbacks, app=self.app)

    @patch('celery.backends.base.group')
    def test_unregistered_task_can_be_used_as_error_callback(self, mock_group):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()

        request = Mock(name='request')
        request.errbacks = [signature('doesnotexist',
                                      immutable=True)]
        exc = KeyError()
        b.mark_as_failure('id', exc, request=request)
        mock_group.assert_called_once_with(request.errbacks, app=self.app)

    def test_mark_as_failure__chord(self):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()
        request = Mock(name='request')
        request.errbacks = []
        b.on_chord_part_return = Mock()
        exc = KeyError()
        b.mark_as_failure('id', exc, request=request)
        b.on_chord_part_return.assert_called_with(request, states.FAILURE, exc)

    def test_mark_as_revoked__chord(self):
        b = BaseBackend(app=self.app)
        b._store_result = Mock()
        request = Mock(name='request')
        request.errbacks = []
        b.on_chord_part_return = Mock()
        b.mark_as_revoked('id', 'revoked', request=request)
        b.on_chord_part_return.assert_called_with(request, states.REVOKED, ANY)

    def test_chord_error_from_stack_raises(self):
        b = BaseBackend(app=self.app)
        exc = KeyError()
        callback = Mock(name='callback')
        callback.options = {'link_error': []}
        task = self.app.tasks[callback.task] = Mock()
        b.fail_from_current_stack = Mock()
        mocked_group = self.patching('celery.group')
        # Make building the errback group fail ...
        mocked_group.side_effect = exc
        b.chord_error_from_stack(callback, exc=ValueError())
        # ... and expect that error to be reported through the callback
        # task's backend.
        task.backend.fail_from_current_stack.assert_called_with(
            callback.id, exc=exc)

    def test_exception_to_python_when_None(self):
        b = BaseBackend(app=self.app)
        assert b.exception_to_python(None) is None

    def test_exception_to_python_when_attribute_exception(self):
        b = BaseBackend(app=self.app)
        test_exception = {'exc_type': 'AttributeDoesNotExist',
                          'exc_module': 'celery',
                          'exc_message': ['Raise Custom Message']}

        result_exc = b.exception_to_python(test_exception)
        assert str(result_exc) == 'Raise Custom Message'

    def test_exception_to_python_when_type_error(self):
        b = BaseBackend(app=self.app)
        # Temporarily expose the class on the ``celery`` module so it can be
        # resolved by exc_module/exc_type. Always remove it, even on failure.
        celery.TestParamException = paramexception
        test_exception = {'exc_type': 'TestParamException',
                          'exc_module': 'celery',
                          'exc_message': []}
        try:
            result_exc = b.exception_to_python(test_exception)
        finally:
            del celery.TestParamException
        # paramexception cannot be built from an empty argument list, so a
        # generic "<class ...>([])" message is expected. Derive the class
        # text instead of hard-coding this test module's import path.
        assert str(result_exc) == '{0}([])'.format(paramexception)

    def test_wait_for__on_interval(self):
        self.patching('time.sleep')
        b = BaseBackend(app=self.app)
        b._get_task_meta_for = Mock()
        b._get_task_meta_for.return_value = {'status': states.PENDING}
        callback = Mock(name='callback')
        with pytest.raises(TimeoutError):
            b.wait_for(task_id='1', on_interval=callback, timeout=1)
        callback.assert_called_with()

        b._get_task_meta_for.return_value = {'status': states.SUCCESS}
        b.wait_for(task_id='1', timeout=None)

    def test_get_children(self):
        b = BaseBackend(app=self.app)
        b._get_task_meta_for = Mock()
        b._get_task_meta_for.return_value = {}
        assert b.get_children('id') is None
        b._get_task_meta_for.return_value = {'children': 3}
        assert b.get_children('id') == 3
615
616
617class test_KeyValueStoreBackend:
618
619    def setup(self):
620        self.b = KVBackend(app=self.app)
621
622    def test_on_chord_part_return(self):
623        assert not self.b.implements_incr
624        self.b.on_chord_part_return(None, None, None)
625
626    def test_get_store_delete_result(self):
627        tid = uuid()
628        self.b.mark_as_done(tid, 'Hello world')
629        assert self.b.get_result(tid) == 'Hello world'
630        assert self.b.get_state(tid) == states.SUCCESS
631        self.b.forget(tid)
632        assert self.b.get_state(tid) == states.PENDING
633
634    @pytest.mark.parametrize('serializer',
635                             ['json', 'pickle', 'yaml', 'msgpack'])
636    def test_store_result_parent_id(self, serializer):
637        self.app.conf.accept_content = ('json', serializer)
638        self.b = KVBackend(app=self.app, serializer=serializer)
639        tid = uuid()
640        pid = uuid()
641        state = 'SUCCESS'
642        result = 10
643        request = Context(parent_id=pid)
644        self.b.store_result(
645            tid, state=state, result=result, request=request,
646        )
647        stored_meta = self.b.decode(self.b.get(self.b.get_key_for_task(tid)))
648        assert stored_meta['parent_id'] == request.parent_id
649
650    def test_store_result_group_id(self):
651        tid = uuid()
652        state = 'SUCCESS'
653        result = 10
654        request = Context(group='gid', children=[])
655        self.b.store_result(
656            tid, state=state, result=result, request=request,
657        )
658        stored_meta = self.b.decode(self.b.get(self.b.get_key_for_task(tid)))
659        assert stored_meta['group_id'] == request.group
660
661    def test_store_result_race_second_write_should_ignore_if_previous_success(self):
662        tid = uuid()
663        state = 'SUCCESS'
664        result = 10
665        request = Context(group='gid', children=[])
666        self.b.store_result(
667            tid, state=state, result=result, request=request,
668        )
669        self.b.store_result(
670            tid, state=states.FAILURE, result=result, request=request,
671        )
672        stored_meta = self.b.decode(self.b.get(self.b.get_key_for_task(tid)))
673        assert stored_meta['status'] == states.SUCCESS
674
675    def test_strip_prefix(self):
676        x = self.b.get_key_for_task('x1b34')
677        assert self.b._strip_prefix(x) == 'x1b34'
678        assert self.b._strip_prefix('x1b34') == 'x1b34'
679
680    def test_get_many(self):
681        for is_dict in True, False:
682            self.b.mget_returns_dict = is_dict
683            ids = {uuid(): i for i in range(10)}
684            for id, i in items(ids):
685                self.b.mark_as_done(id, i)
686            it = self.b.get_many(list(ids), interval=0.01)
687            for i, (got_id, got_state) in enumerate(it):
688                assert got_state['result'] == ids[got_id]
689            assert i == 9
690            assert list(self.b.get_many(list(ids), interval=0.01))
691
692            self.b._cache.clear()
693            callback = Mock(name='callback')
694            it = self.b.get_many(
695                list(ids),
696                on_message=callback,
697                interval=0.05
698            )
699            for i, (got_id, got_state) in enumerate(it):
700                assert got_state['result'] == ids[got_id]
701            assert i == 9
702            assert list(
703                self.b.get_many(list(ids), interval=0.01)
704            )
705            callback.assert_has_calls([
706                call(ANY) for id in ids
707            ])
708
709    def test_get_many_times_out(self):
710        tasks = [uuid() for _ in range(4)]
711        self.b._cache[tasks[1]] = {'status': 'PENDING'}
712        with pytest.raises(self.b.TimeoutError):
713            list(self.b.get_many(tasks, timeout=0.01, interval=0.01))
714
715    def test_get_many_passes_ready_states(self):
716        tasks_length = 10
717        ready_states = frozenset({states.SUCCESS})
718
719        self.b._cache.clear()
720        ids = {uuid(): i for i in range(tasks_length)}
721        for id, i in items(ids):
722            if i % 2 == 0:
723                self.b.mark_as_done(id, i)
724            else:
725                self.b.mark_as_failure(id, Exception())
726
727        it = self.b.get_many(list(ids), interval=0.01, max_iterations=1, READY_STATES=ready_states)
728        it_list = list(it)
729
730        assert all([got_state['status'] in ready_states for (got_id, got_state) in it_list])
731        assert len(it_list) == tasks_length / 2
732
733    def test_chord_part_return_no_gid(self):
734        self.b.implements_incr = True
735        task = Mock()
736        state = 'SUCCESS'
737        result = 10
738        task.request.group = None
739        self.b.get_key_for_chord = Mock()
740        self.b.get_key_for_chord.side_effect = AssertionError(
741            'should not get here',
742        )
743        assert self.b.on_chord_part_return(
744            task.request, state, result) is None
745
746    @patch('celery.backends.base.GroupResult')
747    @patch('celery.backends.base.maybe_signature')
748    def test_chord_part_return_restore_raises(self, maybe_signature,
749                                              GroupResult):
750        self.b.implements_incr = True
751        GroupResult.restore.side_effect = KeyError()
752        self.b.chord_error_from_stack = Mock()
753        callback = Mock(name='callback')
754        request = Mock(name='request')
755        request.group = 'gid'
756        maybe_signature.return_value = callback
757        self.b.on_chord_part_return(request, states.SUCCESS, 10)
758        self.b.chord_error_from_stack.assert_called_with(
759            callback, ANY,
760        )
761
762    @patch('celery.backends.base.GroupResult')
763    @patch('celery.backends.base.maybe_signature')
764    def test_chord_part_return_restore_empty(self, maybe_signature,
765                                             GroupResult):
766        self.b.implements_incr = True
767        GroupResult.restore.return_value = None
768        self.b.chord_error_from_stack = Mock()
769        callback = Mock(name='callback')
770        request = Mock(name='request')
771        request.group = 'gid'
772        maybe_signature.return_value = callback
773        self.b.on_chord_part_return(request, states.SUCCESS, 10)
774        self.b.chord_error_from_stack.assert_called_with(
775            callback, ANY,
776        )
777
778    def test_filter_ready(self):
779        self.b.decode_result = Mock()
780        self.b.decode_result.side_effect = pass1
781        assert len(list(self.b._filter_ready([
782            (1, {'status': states.RETRY}),
783            (2, {'status': states.FAILURE}),
784            (3, {'status': states.SUCCESS}),
785        ]))) == 2
786
787    @contextmanager
788    def _chord_part_context(self, b):
789
790        @self.app.task(shared=False)
791        def callback(result):
792            pass
793
794        b.implements_incr = True
795        b.client = Mock()
796        with patch('celery.backends.base.GroupResult') as GR:
797            deps = GR.restore.return_value = Mock(name='DEPS')
798            deps.__len__ = Mock()
799            deps.__len__.return_value = 10
800            b.incr = Mock()
801            b.incr.return_value = 10
802            b.expire = Mock()
803            task = Mock()
804            task.request.group = 'grid'
805            cb = task.request.chord = callback.s()
806            task.request.chord.freeze()
807            callback.backend = b
808            callback.backend.fail_from_current_stack = Mock()
809            yield task, deps, cb
810
811    def test_chord_part_return_propagate_set(self):
812        with self._chord_part_context(self.b) as (task, deps, _):
813            self.b.on_chord_part_return(task.request, 'SUCCESS', 10)
814            self.b.expire.assert_not_called()
815            deps.delete.assert_called_with()
816            deps.join_native.assert_called_with(propagate=True, timeout=3.0)
817
818    def test_chord_part_return_propagate_default(self):
819        with self._chord_part_context(self.b) as (task, deps, _):
820            self.b.on_chord_part_return(task.request, 'SUCCESS', 10)
821            self.b.expire.assert_not_called()
822            deps.delete.assert_called_with()
823            deps.join_native.assert_called_with(propagate=True, timeout=3.0)
824
825    def test_chord_part_return_join_raises_internal(self):
826        with self._chord_part_context(self.b) as (task, deps, callback):
827            deps._failed_join_report = lambda: iter([])
828            deps.join_native.side_effect = KeyError('foo')
829            self.b.on_chord_part_return(task.request, 'SUCCESS', 10)
830            self.b.fail_from_current_stack.assert_called()
831            args = self.b.fail_from_current_stack.call_args
832            exc = args[1]['exc']
833            assert isinstance(exc, ChordError)
834            assert 'foo' in str(exc)
835
836    def test_chord_part_return_join_raises_task(self):
837        b = KVBackend(serializer='pickle', app=self.app)
838        with self._chord_part_context(b) as (task, deps, callback):
839            deps._failed_join_report = lambda: iter([
840                self.app.AsyncResult('culprit'),
841            ])
842            deps.join_native.side_effect = KeyError('foo')
843            b.on_chord_part_return(task.request, 'SUCCESS', 10)
844            b.fail_from_current_stack.assert_called()
845            args = b.fail_from_current_stack.call_args
846            exc = args[1]['exc']
847            assert isinstance(exc, ChordError)
848            assert 'Dependency culprit raised' in str(exc)
849
850    def test_restore_group_from_json(self):
851        b = KVBackend(serializer='json', app=self.app)
852        g = self.app.GroupResult(
853            'group_id',
854            [self.app.AsyncResult('a'), self.app.AsyncResult('b')],
855        )
856        b._save_group(g.id, g)
857        g2 = b._restore_group(g.id)['result']
858        assert g2 == g
859
860    def test_restore_group_from_pickle(self):
861        b = KVBackend(serializer='pickle', app=self.app)
862        g = self.app.GroupResult(
863            'group_id',
864            [self.app.AsyncResult('a'), self.app.AsyncResult('b')],
865        )
866        b._save_group(g.id, g)
867        g2 = b._restore_group(g.id)['result']
868        assert g2 == g
869
870    def test_chord_apply_fallback(self):
871        self.b.implements_incr = False
872        self.b.fallback_chord_unlock = Mock()
873        header_result = self.app.GroupResult(
874            'group_id',
875            [self.app.AsyncResult(x) for x in range(3)],
876        )
877        self.b.apply_chord(
878            header_result, 'body', foo=1,
879        )
880        self.b.fallback_chord_unlock.assert_called_with(
881            header_result, 'body', foo=1,
882        )
883
884    def test_get_missing_meta(self):
885        assert self.b.get_result('xxx-missing') is None
886        assert self.b.get_state('xxx-missing') == states.PENDING
887
888    def test_save_restore_delete_group(self):
889        tid = uuid()
890        tsr = self.app.GroupResult(
891            tid, [self.app.AsyncResult(uuid()) for _ in range(10)],
892        )
893        self.b.save_group(tid, tsr)
894        self.b.restore_group(tid)
895        assert self.b.restore_group(tid) == tsr
896        self.b.delete_group(tid)
897        assert self.b.restore_group(tid) is None
898
899    def test_restore_missing_group(self):
900        assert self.b.restore_group('xxx-nonexistant') is None
901
902
class test_KeyValueStoreBackend_interface:
    """The bare KeyValueStoreBackend leaves its storage primitives abstract."""

    def _kv(self):
        """Return a fresh KeyValueStoreBackend bound to the test app."""
        return KeyValueStoreBackend(self.app)

    def test_get(self):
        with pytest.raises(NotImplementedError):
            self._kv().get('a')

    def test_set(self):
        with pytest.raises(NotImplementedError):
            self._kv()._set_with_state('a', 1, states.SUCCESS)

    def test_incr(self):
        with pytest.raises(NotImplementedError):
            self._kv().incr('a')

    def test_cleanup(self):
        # cleanup() is not abstract; it just returns a falsy value.
        assert not self._kv().cleanup()

    def test_delete(self):
        with pytest.raises(NotImplementedError):
            self._kv().delete('a')

    def test_mget(self):
        with pytest.raises(NotImplementedError):
            self._kv().mget(['a'])

    def test_forget(self):
        with pytest.raises(NotImplementedError):
            self._kv().forget('a')
931
932
class test_DisabledBackend:
    """DisabledBackend accepts stores but refuses state lookups and chords."""

    def _disabled(self):
        """Return a DisabledBackend bound to the test app."""
        return DisabledBackend(self.app)

    def test_store_result(self):
        # Must simply not raise.
        self._disabled().store_result()

    def test_is_disabled(self):
        with pytest.raises(NotImplementedError):
            self._disabled().get_state('foo')

    def test_as_uri(self):
        uri = self._disabled().as_uri()
        assert uri == 'disabled://'

    @pytest.mark.celery(result_backend='disabled')
    def test_chord_raises_error(self):
        with pytest.raises(NotImplementedError):
            header = (self.add.s(i, i) for i in range(10))
            chord(header)(self.add.s([2]))

    @pytest.mark.celery(result_backend='disabled')
    def test_chain_with_chord_raises_error(self):
        with pytest.raises(NotImplementedError):
            workflow = (
                self.add.s(2, 2)
                | group(self.add.s(2, 2), self.add.s(5, 6))
                | self.add.s()
            )
            workflow.delay()
956
957
class test_as_uri:
    """BaseBackend.as_uri hides the password unless told to include it."""

    def setup(self):
        self.b = BaseBackend(
            url='sch://uuuu:pwpw@hostname.dom',
            app=self.app,
        )

    def test_as_uri_include_password(self):
        assert self.b.url == self.b.as_uri(True)

    def test_as_uri_exclude_password(self):
        masked = self.b.as_uri()
        assert masked == 'sch://uuuu:**@hostname.dom/'
971
972
class test_backend_retries:
    """Retry behaviour of ``BaseBackend.get_task_meta`` and ``store_result``.

    Each test overrides ``result_backend_always_retry`` (and, for the
    max-retries cases, ``result_backend_max_retries``) through
    :meth:`_conf`. The previous values are restored on exit, even when the
    test fails.

    Expected failures are asserted with ``pytest.raises``, so a call that
    unexpectedly succeeds fails the test instead of passing silently.
    """

    @contextmanager
    def _conf(self, **overrides):
        """Temporarily set app configuration values, restoring them on exit."""
        conf = self.app.conf
        previous = {name: getattr(conf, name) for name in overrides}
        try:
            for name, value in overrides.items():
                setattr(conf, name, value)
            yield
        finally:
            for name, value in previous.items():
                setattr(conf, name, value)

    def _make_backend(self, safe_to_retry=True):
        """Return a ``BaseBackend`` with ``_sleep`` and ``_get_task_meta_for`` mocked.

        Create it inside :meth:`_conf`; the backend presumably reads the retry
        settings at construction time. When *safe_to_retry* is true, every
        exception is reported as safe to retry. Otherwise the backend's own
        ``exception_safe_to_retry`` is kept.
        """
        backend = BaseBackend(app=self.app)
        if safe_to_retry:
            backend.exception_safe_to_retry = lambda exc: True
        backend._sleep = Mock()
        backend._get_task_meta_for = Mock()
        return backend

    @staticmethod
    def _retry_meta():
        """Return a fresh meta dict for a task currently in the RETRY state."""
        return {
            'status': states.RETRY, 'result': {"exc_type": "Exception", "exc_message": ["failed"], "exc_module": "builtins"}
        }

    def test_should_retry_exception(self):
        assert not BaseBackend(app=self.app).exception_safe_to_retry(Exception("test"))

    def test_get_failed_never_retries(self):
        expected_exc = Exception("failed")
        with self._conf(result_backend_always_retry=False):
            b = self._make_backend()
            b._get_task_meta_for.side_effect = [
                expected_exc,
                {'status': states.SUCCESS, 'result': 42},
            ]
            with pytest.raises(Exception) as excinfo:
                b.get_task_meta(sentinel.task_id)
            assert excinfo.value is expected_exc
            assert b._sleep.call_count == 0

    def test_get_with_retries(self):
        with self._conf(result_backend_always_retry=True):
            b = self._make_backend()
            b._get_task_meta_for.side_effect = [
                Exception("failed"),
                {'status': states.SUCCESS, 'result': 42},
            ]
            res = b.get_task_meta(sentinel.task_id)
            assert res == {'status': states.SUCCESS, 'result': 42}
            assert b._sleep.call_count == 1

    def test_get_reaching_max_retries(self):
        with self._conf(result_backend_always_retry=True,
                        result_backend_max_retries=0):
            b = self._make_backend()
            b._get_task_meta_for.side_effect = [
                Exception("failed"),
                {'status': states.SUCCESS, 'result': 42},
            ]
            with pytest.raises(BackendGetMetaError):
                b.get_task_meta(sentinel.task_id)
            assert b._sleep.call_count == 0

    def test_get_unsafe_exception(self):
        expected_exc = Exception("failed")
        with self._conf(result_backend_always_retry=True):
            # Keep the backend's real exception_safe_to_retry here.
            b = self._make_backend(safe_to_retry=False)
            b._get_task_meta_for.side_effect = [
                expected_exc,
                {'status': states.SUCCESS, 'result': 42},
            ]
            with pytest.raises(Exception) as excinfo:
                b.get_task_meta(sentinel.task_id)
            assert excinfo.value is expected_exc
            assert b._sleep.call_count == 0

    def test_store_result_never_retries(self):
        expected_exc = Exception("failed")
        with self._conf(result_backend_always_retry=False):
            b = self._make_backend()
            b._get_task_meta_for.return_value = self._retry_meta()
            b._store_result = Mock(side_effect=[expected_exc, 42])
            # Previously there was no failure path if store_result did not
            # raise; pytest.raises makes that case fail the test.
            with pytest.raises(Exception) as excinfo:
                b.store_result(sentinel.task_id, 42, states.SUCCESS)
            assert excinfo.value is expected_exc
            assert b._sleep.call_count == 0

    def test_store_result_with_retries(self):
        with self._conf(result_backend_always_retry=True):
            b = self._make_backend()
            b._get_task_meta_for.return_value = self._retry_meta()
            b._store_result = Mock(side_effect=[Exception("failed"), 42])
            res = b.store_result(sentinel.task_id, 42, states.SUCCESS)
            assert res == 42
            assert b._sleep.call_count == 1

    def test_store_result_reaching_max_retries(self):
        with self._conf(result_backend_always_retry=True,
                        result_backend_max_retries=0):
            b = self._make_backend()
            b._get_task_meta_for.return_value = self._retry_meta()
            b._store_result = Mock(side_effect=[Exception("failed"), 42])
            with pytest.raises(BackendStoreError):
                b.store_result(sentinel.task_id, 42, states.SUCCESS)
            assert b._sleep.call_count == 0
1133