# -*- coding: utf-8 -*-
"""Result backend base classes.

- :class:`BaseBackend` defines the interface.

- :class:`KeyValueStoreBackend` is a common base class
    using K/V semantics (``get``, ``mget``, ``set`` and ``delete``).
"""
from __future__ import absolute_import, unicode_literals

import sys
import time
import warnings
from collections import namedtuple
from datetime import datetime, timedelta
from functools import partial
from weakref import WeakValueDictionary

from billiard.einfo import ExceptionInfo
from kombu.serialization import dumps, loads, prepare_accept_content
from kombu.serialization import registry as serializer_registry
from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
from kombu.utils.url import maybe_sanitize_url

import celery.exceptions
from celery import current_app, group, maybe_signature, states
from celery._state import get_current_task
from celery.exceptions import (BackendGetMetaError, BackendStoreError,
                               ChordError, ImproperlyConfigured,
                               NotRegistered, TaskRevokedError, TimeoutError)
from celery.five import PY3, items
from celery.result import (GroupResult, ResultBase, ResultSet,
                           allow_join_result, result_from_tuple)
from celery.utils.collections import BufferMap
from celery.utils.functional import LRUCache, arity_greater
from celery.utils.log import get_logger
from celery.utils.serialization import (create_exception_cls,
                                        ensure_serializable,
                                        get_pickleable_exception,
                                        get_pickled_exception,
                                        raise_with_context)
from celery.utils.time import get_exponential_backoff_interval

__all__ = ('BaseBackend', 'KeyValueStoreBackend', 'DisabledBackend')

EXCEPTION_ABLE_CODECS = frozenset({'pickle'})

logger = get_logger(__name__)

MESSAGE_BUFFER_MAX = 8192

pending_results_t = namedtuple('pending_results_t', (
    'concrete', 'weak',
))

E_NO_BACKEND = """
No result backend is configured.
Please see the documentation for more information.
"""

E_CHORD_NO_BACKEND = """
Starting chords requires a result backend to be configured.

Note that a group chained with a task is also upgraded to be a chord,
as this pattern requires synchronization.

Result backends that support chords: Redis, Database, Memcached, and more.
"""


def unpickle_backend(cls, args, kwargs):
    """Return an unpickled backend."""
    return cls(*args, app=current_app._get_current_object(), **kwargs)


class _nulldict(dict):
    def ignore(self, *a, **kw):
        pass

    __setitem__ = update = setdefault = ignore


class Backend(object):
    READY_STATES = states.READY_STATES
    UNREADY_STATES = states.UNREADY_STATES
    EXCEPTION_STATES = states.EXCEPTION_STATES

    TimeoutError = TimeoutError

    #: Time to sleep between polling each individual item
    #: in `ResultSet.iterate`, as opposed to the `interval`
    #: argument, which applies to each pass.
    subpolling_interval = None

    #: If true the backend must implement :meth:`get_many`.
    supports_native_join = False

    #: If true the backend must automatically expire results.
    #: The daily backend_cleanup periodic task won't be triggered
    #: in this case.
    supports_autoexpire = False

    #: Set to true if the backend is persistent by default.
    persistent = True

    retry_policy = {
        'max_retries': 20,
        'interval_start': 0,
        'interval_step': 1,
        'interval_max': 1,
    }

    def __init__(self, app,
                 serializer=None, max_cached_results=None, accept=None,
                 expires=None, expires_type=None, url=None, **kwargs):
        self.app = app
        conf = self.app.conf
        self.serializer = serializer or conf.result_serializer
        (self.content_type,
         self.content_encoding,
         self.encoder) = serializer_registry._encoders[self.serializer]
        cmax = max_cached_results or conf.result_cache_max
        self._cache = _nulldict() if cmax == -1 else LRUCache(limit=cmax)

        self.expires = self.prepare_expires(expires, expires_type)

        # precedence: accept, conf.result_accept_content, conf.accept_content
        self.accept = conf.result_accept_content if accept is None else accept
        self.accept = conf.accept_content if self.accept is None else self.accept  # noqa: E501
        self.accept = prepare_accept_content(self.accept)

        self.always_retry = conf.get('result_backend_always_retry', False)
        self.max_sleep_between_retries_ms = conf.get('result_backend_max_sleep_between_retries_ms', 10000)
        self.base_sleep_between_retries_ms = conf.get('result_backend_base_sleep_between_retries_ms', 10)
        self.max_retries = conf.get('result_backend_max_retries', float("inf"))

        self._pending_results = pending_results_t({}, WeakValueDictionary())
        self._pending_messages = BufferMap(MESSAGE_BUFFER_MAX)
        self.url = url

    def as_uri(self, include_password=False):
        """Return the backend as a URI, sanitizing the password or not."""
        # when using maybe_sanitize_url(), "/" is added
        # we're stripping it for consistency
        if include_password:
            return self.url
        url = maybe_sanitize_url(self.url or '')
        return url[:-1] if url.endswith(':///') else url
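
    # Illustrative sketch (the URL below is a made-up example): with
    # url = 'redis://:secret@localhost:6379/0', as_uri() returns the
    # sanitized form with the password masked by kombu, roughly
    # 'redis://:**@localhost:6379/0', while as_uri(include_password=True)
    # returns the URL unchanged.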

    def mark_as_started(self, task_id, **meta):
        """Mark a task as started."""
        return self.store_result(task_id, meta, states.STARTED)

    def mark_as_done(self, task_id, result,
                     request=None, store_result=True, state=states.SUCCESS):
        """Mark task as successfully executed."""
        if store_result:
            self.store_result(task_id, result, state, request=request)
        if request and request.chord:
            self.on_chord_part_return(request, state, result)

    def mark_as_failure(self, task_id, exc,
                        traceback=None, request=None,
                        store_result=True, call_errbacks=True,
                        state=states.FAILURE):
        """Mark task as executed with failure."""
        if store_result:
            self.store_result(task_id, exc, state,
                              traceback=traceback, request=request)
        if request:
            if request.chord:
                self.on_chord_part_return(request, state, exc)
            if call_errbacks and request.errbacks:
                self._call_task_errbacks(request, exc, traceback)

    def _call_task_errbacks(self, request, exc, traceback):
        old_signature = []
        for errback in request.errbacks:
            errback = self.app.signature(errback)
            if not errback._app:
                # Ensure all signatures have an application
                errback._app = self.app
            try:
                if (
                        # Celery tasks created with the @task decorator have
                        # the __header__ attribute, but tasks defined by
                        # subclassing Task do not.  That's why we check that
                        # the attribute exists before checking whether the
                        # header is a partial function.
                        hasattr(errback.type, '__header__') and

                        # workaround to support tasks with bind=True executed
                        # as link errors.  Otherwise retries can't be used.
                        not isinstance(errback.type.__header__, partial) and
                        arity_greater(errback.type.__header__, 1)
                ):
                    errback(request, exc, traceback)
                else:
                    old_signature.append(errback)
            except NotRegistered:
                # Task may not be present in this worker.
                # We simply send it forward for another worker to consume.
                # If the task is not registered there, the worker will raise
                # NotRegistered.
                old_signature.append(errback)

        if old_signature:
            # Previously errback was called as a task so we still
            # need to do so if the errback only takes a single task_id arg.
            task_id = request.id
            root_id = request.root_id or task_id
            g = group(old_signature, app=self.app)
            if self.app.conf.task_always_eager or request.delivery_info.get('is_eager', False):
                g.apply(
                    (task_id,), parent_id=task_id, root_id=root_id
                )
            else:
                g.apply_async(
                    (task_id,), parent_id=task_id, root_id=root_id
                )

    def mark_as_revoked(self, task_id, reason='',
                        request=None, store_result=True, state=states.REVOKED):
        exc = TaskRevokedError(reason)
        if store_result:
            self.store_result(task_id, exc, state,
                              traceback=None, request=request)
        if request and request.chord:
            self.on_chord_part_return(request, state, exc)

    def mark_as_retry(self, task_id, exc, traceback=None,
                      request=None, store_result=True, state=states.RETRY):
        """Mark task as being retried.

        Note:
            Stores the current exception (if any).
        """
        return self.store_result(task_id, exc, state,
                                 traceback=traceback, request=request)

    def chord_error_from_stack(self, callback, exc=None):
        # need below import for test for some crazy reason
        from celery import group  # pylint: disable
        app = self.app
        try:
            backend = app._tasks[callback.task].backend
        except KeyError:
            backend = self
        try:
            group(
                [app.signature(errback)
                 for errback in callback.options.get('link_error') or []],
                app=app,
            ).apply_async((callback.id,))
        except Exception as eb_exc:  # pylint: disable=broad-except
            return backend.fail_from_current_stack(callback.id, exc=eb_exc)
        else:
            return backend.fail_from_current_stack(callback.id, exc=exc)

    def fail_from_current_stack(self, task_id, exc=None):
        type_, real_exc, tb = sys.exc_info()
        try:
            exc = real_exc if exc is None else exc
            exception_info = ExceptionInfo((type_, exc, tb))
            self.mark_as_failure(task_id, exc, exception_info.traceback)
            return exception_info
        finally:
            if sys.version_info >= (3, 5, 0):
                while tb is not None:
                    try:
                        tb.tb_frame.clear()
                        tb.tb_frame.f_locals
                    except RuntimeError:
                        # Ignore the exception raised if the frame is still executing.
                        pass
                    tb = tb.tb_next

            elif (2, 7, 0) <= sys.version_info < (3, 0, 0):
                sys.exc_clear()

            del tb

    def prepare_exception(self, exc, serializer=None):
        """Prepare exception for serialization."""
        serializer = self.serializer if serializer is None else serializer
        if serializer in EXCEPTION_ABLE_CODECS:
            return get_pickleable_exception(exc)
        exctype = type(exc)
        return {'exc_type': getattr(exctype, '__qualname__', exctype.__name__),
                'exc_message': ensure_serializable(exc.args, self.encode),
                'exc_module': exctype.__module__}

    def exception_to_python(self, exc):
        """Convert serialized exception to Python exception."""
        if exc:
            if not isinstance(exc, BaseException):
                exc_module = exc.get('exc_module')
                if exc_module is None:
                    cls = create_exception_cls(
                        from_utf8(exc['exc_type']), __name__)
                else:
                    exc_module = from_utf8(exc_module)
                    exc_type = from_utf8(exc['exc_type'])
                    try:
                        # Load the module and find the exception class in it
                        cls = sys.modules[exc_module]
                        # The type can contain a qualified name with parent
                        # classes
                        for name in exc_type.split('.'):
                            cls = getattr(cls, name)
                    except (KeyError, AttributeError):
                        cls = create_exception_cls(exc_type,
                                                   celery.exceptions.__name__)
                exc_msg = exc['exc_message']
                try:
                    if isinstance(exc_msg, (tuple, list)):
                        exc = cls(*exc_msg)
                    else:
                        exc = cls(exc_msg)
                except Exception as err:  # noqa
                    exc = Exception('{}({})'.format(cls, exc_msg))
            if self.serializer in EXCEPTION_ABLE_CODECS:
                exc = get_pickled_exception(exc)
        return exc
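
    # A hedged round-trip sketch, assuming the 'json' serializer on
    # Python 3 (values are illustrative):
    #
    #     meta = backend.prepare_exception(ValueError('bad input'))
    #     # -> {'exc_type': 'ValueError',
    #     #     'exc_message': ('bad input',),
    #     #     'exc_module': 'builtins'}
    #     exc = backend.exception_to_python(meta)
    #     # -> ValueError('bad input'), rebuilt by looking up
    #     # sys.modules['builtins'] and resolving the qualified name.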

    def prepare_value(self, result):
        """Prepare value for storage."""
        if self.serializer != 'pickle' and isinstance(result, ResultBase):
            return result.as_tuple()
        return result

    def encode(self, data):
        _, _, payload = self._encode(data)
        return payload

    def _encode(self, data):
        return dumps(data, serializer=self.serializer)

    def meta_from_decoded(self, meta):
        if meta['status'] in self.EXCEPTION_STATES:
            meta['result'] = self.exception_to_python(meta['result'])
        return meta

    def decode_result(self, payload):
        return self.meta_from_decoded(self.decode(payload))

    def decode(self, payload):
        if payload is None:
            return payload
        payload = PY3 and payload or str(payload)
        return loads(payload,
                     content_type=self.content_type,
                     content_encoding=self.content_encoding,
                     accept=self.accept)
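
    # Example sketch: with the default 'json' serializer, encode() and
    # decode() round-trip plain data structures:
    #
    #     payload = backend.encode({'status': 'SUCCESS', 'result': 42})
    #     backend.decode(payload)  # -> {'status': 'SUCCESS', 'result': 42}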

    def prepare_expires(self, value, type=None):
        if value is None:
            value = self.app.conf.result_expires
        if isinstance(value, timedelta):
            value = value.total_seconds()
        if value is not None and type:
            return type(value)
        return value
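
    # Example: prepare_expires() normalizes timedeltas to seconds and
    # applies the optional coercion type:
    #
    #     backend.prepare_expires(timedelta(minutes=5), int)  # -> 300
    #     backend.prepare_expires(86400.0)                    # -> 86400.0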

    def prepare_persistent(self, enabled=None):
        if enabled is not None:
            return enabled
        persistent = self.app.conf.result_persistent
        return self.persistent if persistent is None else persistent

    def encode_result(self, result, state):
        if state in self.EXCEPTION_STATES and isinstance(result, Exception):
            return self.prepare_exception(result)
        return self.prepare_value(result)

    def is_cached(self, task_id):
        return task_id in self._cache

    def _get_result_meta(self, result,
                         state, traceback, request, format_date=True,
                         encode=False):
        if state in self.READY_STATES:
            date_done = datetime.utcnow()
            if format_date:
                date_done = date_done.isoformat()
        else:
            date_done = None

        meta = {
            'status': state,
            'result': result,
            'traceback': traceback,
            'children': self.current_task_children(request),
            'date_done': date_done,
        }

        if request and getattr(request, 'group', None):
            meta['group_id'] = request.group
        if request and getattr(request, 'parent_id', None):
            meta['parent_id'] = request.parent_id

        if self.app.conf.find_value_for_key('extended', 'result'):
            if request:
                request_meta = {
                    'name': getattr(request, 'task', None),
                    'args': getattr(request, 'args', None),
                    'kwargs': getattr(request, 'kwargs', None),
                    'worker': getattr(request, 'hostname', None),
                    'retries': getattr(request, 'retries', None),
                    'queue': request.delivery_info.get('routing_key')
                    if hasattr(request, 'delivery_info') and
                    request.delivery_info else None
                }

                if encode:
                    # args and kwargs need to be encoded properly before saving
                    encode_needed_fields = {"args", "kwargs"}
                    for field in encode_needed_fields:
                        value = request_meta[field]
                        encoded_value = self.encode(value)
                        request_meta[field] = ensure_bytes(encoded_value)

                meta.update(request_meta)

        return meta
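
    # Shape of the returned mapping (values are illustrative only):
    #
    #     {'status': 'SUCCESS', 'result': 42, 'traceback': None,
    #      'children': [], 'date_done': '2009-11-17T12:30:56.527191'}
    #
    # plus 'group_id'/'parent_id' when set on the request, and the extended
    # fields (name, args, kwargs, worker, retries, queue) when the
    # result_extended setting is enabled.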

    def _sleep(self, amount):
        time.sleep(amount)

    def store_result(self, task_id, result, state,
                     traceback=None, request=None, **kwargs):
        """Update task state and result.

        If the ``result_backend_always_retry`` setting is enabled and a
        recoverable exception occurs, the operation is retried with an
        exponential backoff until the retry limit is reached.
        """
        result = self.encode_result(result, state)

        retries = 0

        while True:
            try:
                self._store_result(task_id, result, state, traceback,
                                   request=request, **kwargs)
                return result
            except Exception as exc:
                if self.always_retry and self.exception_safe_to_retry(exc):
                    if retries < self.max_retries:
                        retries += 1

                        # get_exponential_backoff_interval computes an
                        # integer number of milliseconds, while time.sleep
                        # accepts floats, so divide for sub-second sleeps.
                        sleep_amount = get_exponential_backoff_interval(
                            self.base_sleep_between_retries_ms, retries,
                            self.max_sleep_between_retries_ms, True) / 1000
                        self._sleep(sleep_amount)
                    else:
                        raise_with_context(
                            BackendStoreError("failed to store result on the backend", task_id=task_id, state=state),
                        )
                else:
                    raise
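
    # Backoff sketch with the default settings (base 10 ms, cap 10000 ms):
    # retry n sleeps a jittered amount of up to
    # min(10000, 10 * 2 ** n) / 1000 seconds, i.e. up to 0.02 s on the
    # first retry, 0.32 s on the fifth, and 10 s once the cap is reached.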

    def forget(self, task_id):
        self._cache.pop(task_id, None)
        self._forget(task_id)

    def _forget(self, task_id):
        raise NotImplementedError('backend does not implement forget.')

    def get_state(self, task_id):
        """Get the state of a task."""
        return self.get_task_meta(task_id)['status']

    get_status = get_state  # XXX compat

    def get_traceback(self, task_id):
        """Get the traceback for a failed task."""
        return self.get_task_meta(task_id).get('traceback')

    def get_result(self, task_id):
        """Get the result of a task."""
        return self.get_task_meta(task_id).get('result')

    def get_children(self, task_id):
        """Get the list of subtasks sent by a task."""
        try:
            return self.get_task_meta(task_id)['children']
        except KeyError:
            pass

    def _ensure_not_eager(self):
        if self.app.conf.task_always_eager:
            warnings.warn(
                "Shouldn't retrieve result with task_always_eager enabled.",
                RuntimeWarning
            )

    def exception_safe_to_retry(self, exc):
        """Check if an exception is safe to retry.

        Backends have to override this method with the correct predicate
        for their exception types.

        By default no exception is safe to retry; it's up to the backend
        implementation to define which exceptions are.
        """
        return False
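
    # A minimal override sketch (the exception type chosen here is an
    # assumption; real backends use their client library's errors):
    #
    #     class MyBackend(BaseBackend):
    #         def exception_safe_to_retry(self, exc):
    #             return isinstance(exc, ConnectionError)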

    def get_task_meta(self, task_id, cache=True):
        """Get task meta from backend.

        If the ``result_backend_always_retry`` setting is enabled and a
        recoverable exception occurs, the operation is retried with an
        exponential backoff until the retry limit is reached.
        """
        self._ensure_not_eager()
        if cache:
            try:
                return self._cache[task_id]
            except KeyError:
                pass
        retries = 0
        while True:
            try:
                meta = self._get_task_meta_for(task_id)
                break
            except Exception as exc:
                if self.always_retry and self.exception_safe_to_retry(exc):
                    if retries < self.max_retries:
                        retries += 1

                        # get_exponential_backoff_interval computes an
                        # integer number of milliseconds, while time.sleep
                        # accepts floats, so divide for sub-second sleeps.
                        sleep_amount = get_exponential_backoff_interval(
                            self.base_sleep_between_retries_ms, retries,
                            self.max_sleep_between_retries_ms, True) / 1000
                        self._sleep(sleep_amount)
                    else:
                        raise_with_context(
                            BackendGetMetaError("failed to get meta", task_id=task_id),
                        )
                else:
                    raise

        if cache and meta.get('status') == states.SUCCESS:
            self._cache[task_id] = meta
        return meta

    def reload_task_result(self, task_id):
        """Reload task result, even if it has been previously fetched."""
        self._cache[task_id] = self.get_task_meta(task_id, cache=False)

    def reload_group_result(self, group_id):
        """Reload group result, even if it has been previously fetched."""
        self._cache[group_id] = self.get_group_meta(group_id, cache=False)

    def get_group_meta(self, group_id, cache=True):
        self._ensure_not_eager()
        if cache:
            try:
                return self._cache[group_id]
            except KeyError:
                pass

        meta = self._restore_group(group_id)
        if cache and meta is not None:
            self._cache[group_id] = meta
        return meta

    def restore_group(self, group_id, cache=True):
        """Get the result for a group."""
        meta = self.get_group_meta(group_id, cache=cache)
        if meta:
            return meta['result']

    def save_group(self, group_id, result):
        """Store the result of an executed group."""
        return self._save_group(group_id, result)

    def delete_group(self, group_id):
        self._cache.pop(group_id, None)
        return self._delete_group(group_id)

    def cleanup(self):
        """Backend cleanup.

        Note:
            This is run by the daily ``backend_cleanup`` periodic task.
        """

    def process_cleanup(self):
        """Cleanup actions to do at the end of a task worker process."""

    def on_task_call(self, producer, task_id):
        return {}

    def add_to_chord(self, chord_id, result):
        raise NotImplementedError('Backend does not support add_to_chord')

    def on_chord_part_return(self, request, state, result, **kwargs):
        pass

    def fallback_chord_unlock(self, header_result, body, countdown=1,
                              **kwargs):
        kwargs['result'] = [r.as_tuple() for r in header_result]
        queue = body.options.get('queue', getattr(body.type, 'queue', None))
        priority = body.options.get('priority', getattr(body.type, 'priority', 0))
        self.app.tasks['celery.chord_unlock'].apply_async(
            (header_result.id, body,), kwargs,
            countdown=countdown,
            queue=queue,
            priority=priority,
        )

    def ensure_chords_allowed(self):
        pass

    def apply_chord(self, header_result, body, **kwargs):
        self.ensure_chords_allowed()
        self.fallback_chord_unlock(header_result, body, **kwargs)

    def current_task_children(self, request=None):
        request = request or getattr(get_current_task(), 'request', None)
        if request:
            return [r.as_tuple() for r in getattr(request, 'children', [])]

    def __reduce__(self, args=(), kwargs=None):
        kwargs = {} if not kwargs else kwargs
        return (unpickle_backend, (self.__class__, args, kwargs))


class SyncBackendMixin(object):
    def iter_native(self, result, timeout=None, interval=0.5, no_ack=True,
                    on_message=None, on_interval=None):
        self._ensure_not_eager()
        results = result.results
        if not results:
            return

        task_ids = set()
        for result in results:
            if isinstance(result, ResultSet):
                yield result.id, result.results
            else:
                task_ids.add(result.id)

        for task_id, meta in self.get_many(
            task_ids,
            timeout=timeout, interval=interval, no_ack=no_ack,
            on_message=on_message, on_interval=on_interval,
        ):
            yield task_id, meta

    def wait_for_pending(self, result, timeout=None, interval=0.5,
                         no_ack=True, on_message=None, on_interval=None,
                         callback=None, propagate=True):
        self._ensure_not_eager()
        if on_message is not None:
            raise ImproperlyConfigured(
                'Backend does not support on_message callback')

        meta = self.wait_for(
            result.id, timeout=timeout,
            interval=interval,
            on_interval=on_interval,
            no_ack=no_ack,
        )
        if meta:
            result._maybe_set_cache(meta)
            return result.maybe_throw(propagate=propagate, callback=callback)

    def wait_for(self, task_id,
                 timeout=None, interval=0.5, no_ack=True, on_interval=None):
        """Wait for task and return its result.

        If the task raises an exception, this exception
        will be re-raised by :func:`wait_for`.

        Raises:
            celery.exceptions.TimeoutError:
                If `timeout` is not :const:`None`, and the operation
                takes longer than `timeout` seconds.
        """
        self._ensure_not_eager()

        time_elapsed = 0.0

        while 1:
            meta = self.get_task_meta(task_id)
            if meta['status'] in states.READY_STATES:
                return meta
            if on_interval:
                on_interval()
            # avoid hammering the CPU checking status.
            time.sleep(interval)
            time_elapsed += interval
            if timeout and time_elapsed >= timeout:
                raise TimeoutError('The operation timed out.')
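
    # Usage sketch (hypothetical task id):
    #
    #     meta = backend.wait_for('some-task-id', timeout=10, interval=0.5)
    #     meta['status']  # one of states.READY_STATES
    #
    # With timeout=10 and interval=0.5 the loop polls at most ~20 times
    # before raising TimeoutError.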

    def add_pending_result(self, result, weak=False):
        return result

    def remove_pending_result(self, result):
        return result

    @property
    def is_async(self):
        return False


class BaseBackend(Backend, SyncBackendMixin):
    """Base (synchronous) result backend."""


BaseDictBackend = BaseBackend  # noqa: E305 XXX compat


class BaseKeyValueStoreBackend(Backend):
    key_t = ensure_bytes
    task_keyprefix = 'celery-task-meta-'
    group_keyprefix = 'celery-taskset-meta-'
    chord_keyprefix = 'chord-unlock-'
    implements_incr = False

    def __init__(self, *args, **kwargs):
        if hasattr(self.key_t, '__func__'):  # pragma: no cover
            self.key_t = self.key_t.__func__  # remove binding
        self._encode_prefixes()
        super(BaseKeyValueStoreBackend, self).__init__(*args, **kwargs)
        if self.implements_incr:
            self.apply_chord = self._apply_chord_incr

    def _encode_prefixes(self):
        self.task_keyprefix = self.key_t(self.task_keyprefix)
        self.group_keyprefix = self.key_t(self.group_keyprefix)
        self.chord_keyprefix = self.key_t(self.chord_keyprefix)

    def get(self, key):
        raise NotImplementedError('Must implement the get method.')

    def mget(self, keys):
        raise NotImplementedError('Does not support get_many')

    def _set_with_state(self, key, value, state):
        return self.set(key, value)

    def set(self, key, value):
        raise NotImplementedError('Must implement the set method.')

    def delete(self, key):
        raise NotImplementedError('Must implement the delete method')

    def incr(self, key):
        raise NotImplementedError('Does not implement incr')

    def expire(self, key, value):
        pass

    def get_key_for_task(self, task_id, key=''):
        """Get the cache key for a task by id."""
        key_t = self.key_t
        return key_t('').join([
            self.task_keyprefix, key_t(task_id), key_t(key),
        ])

    def get_key_for_group(self, group_id, key=''):
        """Get the cache key for a group by id."""
        key_t = self.key_t
        return key_t('').join([
            self.group_keyprefix, key_t(group_id), key_t(key),
        ])

    def get_key_for_chord(self, group_id, key=''):
        """Get the cache key for the chord waiting on group with given id."""
        key_t = self.key_t
        return key_t('').join([
            self.chord_keyprefix, key_t(group_id), key_t(key),
        ])
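
    # Key construction example (hypothetical id), with the default
    # prefixes and key_t = ensure_bytes:
    #
    #     backend.get_key_for_task('a1b2')
    #     # -> b'celery-task-meta-a1b2'
    #     backend.get_key_for_group('a1b2')
    #     # -> b'celery-taskset-meta-a1b2'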

    def _strip_prefix(self, key):
        """Take bytes: emit string."""
        key = self.key_t(key)
        for prefix in self.task_keyprefix, self.group_keyprefix:
            if key.startswith(prefix):
                return bytes_to_str(key[len(prefix):])
        return bytes_to_str(key)

    def _filter_ready(self, values, READY_STATES=states.READY_STATES):
        for k, value in values:
            if value is not None:
                value = self.decode_result(value)
                if value['status'] in READY_STATES:
                    yield k, value

    def _mget_to_results(self, values, keys, READY_STATES=states.READY_STATES):
        if hasattr(values, 'items'):
            # client returns dict so mapping preserved.
            return {
                self._strip_prefix(k): v
                for k, v in self._filter_ready(items(values), READY_STATES)
            }
        else:
            # client returns list so need to recreate mapping.
            return {
                bytes_to_str(keys[i]): v
                for i, v in self._filter_ready(enumerate(values), READY_STATES)
            }

    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
                 on_message=None, on_interval=None, max_iterations=None,
                 READY_STATES=states.READY_STATES):
        interval = 0.5 if interval is None else interval
        ids = task_ids if isinstance(task_ids, set) else set(task_ids)
        cached_ids = set()
        cache = self._cache
        for task_id in ids:
            try:
                cached = cache[task_id]
            except KeyError:
                pass
            else:
                if cached['status'] in READY_STATES:
                    yield bytes_to_str(task_id), cached
                    cached_ids.add(task_id)

        ids.difference_update(cached_ids)
        iterations = 0
        while ids:
            keys = list(ids)
            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                 for k in keys]), keys, READY_STATES)
            cache.update(r)
            ids.difference_update({bytes_to_str(v) for v in r})
            for key, value in items(r):
                if on_message is not None:
                    on_message(value)
                yield bytes_to_str(key), value
            if timeout and iterations * interval >= timeout:
                raise TimeoutError('Operation timed out ({0})'.format(timeout))
            if on_interval:
                on_interval()
            time.sleep(interval)  # don't busy loop.
            iterations += 1
            if max_iterations and iterations >= max_iterations:
                break
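
    # Usage sketch (hypothetical ids): get_many() is a generator that
    # first yields ready results from the local cache, then polls mget()
    # until the remaining ids are ready or the timeout is hit:
    #
    #     for task_id, meta in backend.get_many({'id1', 'id2'}, timeout=5):
    #         print(task_id, meta['status'])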

    def _forget(self, task_id):
        self.delete(self.get_key_for_task(task_id))

    def _store_result(self, task_id, result, state,
                      traceback=None, request=None, **kwargs):
        meta = self._get_result_meta(result=result, state=state,
                                     traceback=traceback, request=request)
        meta['task_id'] = bytes_to_str(task_id)

        # Retrieve metadata from the backend: if the status is already
        # a success, we ignore any subsequent update to the state.
        # This solves a task deduplication issue caused by network
        # partitioning or lost workers, where a race condition let a
        # lost task overwrite the last successful result in the
        # result backend.
        current_meta = self._get_task_meta_for(task_id)

        if current_meta['status'] == states.SUCCESS:
            return result

        self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state)
        return result

    def _save_group(self, group_id, result):
        self._set_with_state(self.get_key_for_group(group_id),
                             self.encode({'result': result.as_tuple()}), states.SUCCESS)
        return result

    def _delete_group(self, group_id):
        self.delete(self.get_key_for_group(group_id))

    def _get_task_meta_for(self, task_id):
        """Get task meta-data for a task by id."""
        meta = self.get(self.get_key_for_task(task_id))
        if not meta:
            return {'status': states.PENDING, 'result': None}
        return self.decode_result(meta)

    def _restore_group(self, group_id):
        """Get group meta-data for a group by id."""
        meta = self.get(self.get_key_for_group(group_id))
        # previously this was always pickled, but later this
        # was extended to support other serializers, so the
        # structure is kind of weird.
        if meta:
            meta = self.decode(meta)
            result = meta['result']
            meta['result'] = result_from_tuple(result, self.app)
            return meta

    def _apply_chord_incr(self, header_result, body, **kwargs):
        self.ensure_chords_allowed()
        header_result.save(backend=self)

    def on_chord_part_return(self, request, state, result, **kwargs):
        if not self.implements_incr:
            return
        app = self.app
        gid = request.group
        if not gid:
            return
        key = self.get_key_for_chord(gid)
        try:
            deps = GroupResult.restore(gid, backend=self)
        except Exception as exc:  # pylint: disable=broad-except
            callback = maybe_signature(request.chord, app=app)
            logger.exception('Chord %r raised: %r', gid, exc)
            return self.chord_error_from_stack(
                callback,
                ChordError('Cannot restore group: {0!r}'.format(exc)),
            )
        if deps is None:
            try:
                raise ValueError(gid)
            except ValueError as exc:
                callback = maybe_signature(request.chord, app=app)
                logger.exception('Chord callback %r raised: %r', gid, exc)
                return self.chord_error_from_stack(
                    callback,
                    ChordError('GroupResult {0} no longer exists'.format(gid)),
                )
        val = self.incr(key)
        size = len(deps)
        if val > size:  # pragma: no cover
            logger.warning('Chord counter incremented too many times for %r',
                           gid)
        elif val == size:
            callback = maybe_signature(request.chord, app=app)
            j = deps.join_native if deps.supports_native_join else deps.join
            try:
                with allow_join_result():
                    ret = j(timeout=3.0, propagate=True)
            except Exception as exc:  # pylint: disable=broad-except
                try:
                    culprit = next(deps._failed_join_report())
                    reason = 'Dependency {0.id} raised {1!r}'.format(
                        culprit, exc,
                    )
                except StopIteration:
                    reason = repr(exc)

                logger.exception('Chord %r raised: %r', gid, reason)
                self.chord_error_from_stack(callback, ChordError(reason))
            else:
                try:
                    callback.delay(ret)
                except Exception as exc:  # pylint: disable=broad-except
                    logger.exception('Chord %r raised: %r', gid, exc)
                    self.chord_error_from_stack(
                        callback,
                        ChordError('Callback error: {0!r}'.format(exc)),
                    )
            finally:
                deps.delete()
                self.client.delete(key)
        else:
            self.expire(key, self.expires)
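
    # Counter walk-through for incr-capable backends (sizes are
    # illustrative): each returning part of a chord of size 3 increments
    # the b'chord-unlock-<gid>' key; the part that brings the counter to
    # 3 joins the group results and calls callback.delay([r1, r2, r3]),
    # then deletes the group and the counter key.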


class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin):
    """Result backend base class for key/value stores."""


class DisabledBackend(BaseBackend):
    """Dummy result backend."""

    _cache = {}  # need this attribute to reset cache in tests.

    def store_result(self, *args, **kwargs):
        pass

    def ensure_chords_allowed(self):
        raise NotImplementedError(E_CHORD_NO_BACKEND.strip())

    def _is_disabled(self, *args, **kwargs):
        raise NotImplementedError(E_NO_BACKEND.strip())

    def as_uri(self, *args, **kwargs):
        return 'disabled://'

    get_state = get_status = get_result = get_traceback = _is_disabled
    get_task_meta_for = wait_for = get_many = _is_disabled