# -*- coding: utf-8 -*-
"""Result backend base classes.

- :class:`BaseBackend` defines the interface.

- :class:`KeyValueStoreBackend` is a common base class
    using K/V semantics like _get and _put.
"""
from __future__ import absolute_import, unicode_literals

import sys
import time
import warnings
from collections import namedtuple
from datetime import datetime, timedelta
from functools import partial
from weakref import WeakValueDictionary

from billiard.einfo import ExceptionInfo
from kombu.serialization import dumps, loads, prepare_accept_content
from kombu.serialization import registry as serializer_registry
from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
from kombu.utils.url import maybe_sanitize_url

import celery.exceptions
from celery import current_app, group, maybe_signature, states
from celery._state import get_current_task
from celery.exceptions import (BackendGetMetaError, BackendStoreError,
                               ChordError, ImproperlyConfigured,
                               NotRegistered, TaskRevokedError, TimeoutError)
from celery.five import PY3, items
from celery.result import (GroupResult, ResultBase, ResultSet,
                           allow_join_result, result_from_tuple)
from celery.utils.collections import BufferMap
from celery.utils.functional import LRUCache, arity_greater
from celery.utils.log import get_logger
from celery.utils.serialization import (create_exception_cls,
                                        ensure_serializable,
                                        get_pickleable_exception,
                                        get_pickled_exception,
                                        raise_with_context)
from celery.utils.time import get_exponential_backoff_interval

__all__ = ('BaseBackend', 'KeyValueStoreBackend', 'DisabledBackend')

#: Serializers that can carry real exception objects (round-trippable).
EXCEPTION_ABLE_CODECS = frozenset({'pickle'})

logger = get_logger(__name__)

#: Max number of buffered messages kept per task in the pending-message map.
MESSAGE_BUFFER_MAX = 8192

#: Holds strong ("concrete") and weak references to pending result objects.
pending_results_t = namedtuple('pending_results_t', (
    'concrete', 'weak',
))

E_NO_BACKEND = """
No result backend is configured.
Please see the documentation for more information.
"""
E_CHORD_NO_BACKEND = """
Starting chords requires a result backend to be configured.

Note that a group chained with a task is also upgraded to be a chord,
as this pattern requires synchronization.

Result backends that support chords: Redis, Database, Memcached, and more.
"""


def unpickle_backend(cls, args, kwargs):
    """Return an unpickled backend.

    Used as the reconstructor in :meth:`Backend.__reduce__`; rebinds the
    backend to the *current* app instead of pickling the app itself.
    """
    return cls(*args, app=current_app._get_current_object(), **kwargs)


class _nulldict(dict):
    """Dict subclass that silently discards all writes.

    Used as a no-op result cache when caching is disabled
    (``result_cache_max == -1``).
    """

    def ignore(self, *a, **kw):
        pass

    # All mutating entry points become no-ops; reads behave like a dict.
    __setitem__ = update = setdefault = ignore
class Backend(object):
    """Base class implementing the result-backend API."""

    READY_STATES = states.READY_STATES
    UNREADY_STATES = states.UNREADY_STATES
    EXCEPTION_STATES = states.EXCEPTION_STATES

    TimeoutError = TimeoutError

    #: Time to sleep between polling each individual item
    #: in `ResultSet.iterate`, as opposed to the `interval`
    #: argument which is for each pass.
    subpolling_interval = None

    #: If true the backend must implement :meth:`get_many`.
    supports_native_join = False

    #: If true the backend must automatically expire results.
    #: The daily backend_cleanup periodic task won't be triggered
    #: in this case.
    supports_autoexpire = False

    #: Set to true if the backend is persistent by default.
    persistent = True

    retry_policy = {
        'max_retries': 20,
        'interval_start': 0,
        'interval_step': 1,
        'interval_max': 1,
    }

    def __init__(self, app,
                 serializer=None, max_cached_results=None, accept=None,
                 expires=None, expires_type=None, url=None, **kwargs):
        self.app = app
        conf = self.app.conf
        self.serializer = serializer or conf.result_serializer
        (self.content_type,
         self.content_encoding,
         self.encoder) = serializer_registry._encoders[self.serializer]
        cache_size = max_cached_results or conf.result_cache_max
        # A limit of -1 disables caching entirely (writes are discarded).
        self._cache = _nulldict() if cache_size == -1 else LRUCache(limit=cache_size)

        self.expires = self.prepare_expires(expires, expires_type)

        # precedence: accept, conf.result_accept_content, conf.accept_content
        self.accept = conf.result_accept_content if accept is None else accept
        self.accept = conf.accept_content if self.accept is None else self.accept  # noqa: E501
        self.accept = prepare_accept_content(self.accept)

        self.always_retry = conf.get('result_backend_always_retry', False)
        self.max_sleep_between_retries_ms = conf.get('result_backend_max_sleep_between_retries_ms', 10000)
        self.base_sleep_between_retries_ms = conf.get('result_backend_base_sleep_between_retries_ms', 10)
        self.max_retries = conf.get('result_backend_max_retries', float("inf"))

        self._pending_results = pending_results_t({}, WeakValueDictionary())
        self._pending_messages = BufferMap(MESSAGE_BUFFER_MAX)
        self.url = url

    def as_uri(self, include_password=False):
        """Return the backend as an URI, sanitizing the password or not."""
        if include_password:
            return self.url
        # maybe_sanitize_url() appends "/"; strip it for consistency.
        sanitized = maybe_sanitize_url(self.url or '')
        return sanitized[:-1] if sanitized.endswith(':///') else sanitized

    def mark_as_started(self, task_id, **meta):
        """Mark a task as started."""
        return self.store_result(task_id, meta, states.STARTED)

    def mark_as_done(self, task_id, result,
                     request=None, store_result=True, state=states.SUCCESS):
        """Mark task as successfully executed."""
        if store_result:
            self.store_result(task_id, result, state, request=request)
        if request and request.chord:
            self.on_chord_part_return(request, state, result)

    def mark_as_failure(self, task_id, exc,
                        traceback=None, request=None,
                        store_result=True, call_errbacks=True,
                        state=states.FAILURE):
        """Mark task as executed with failure."""
        if store_result:
            self.store_result(task_id, exc, state,
                              traceback=traceback, request=request)
        if request:
            if request.chord:
                self.on_chord_part_return(request, state, exc)
            if call_errbacks and request.errbacks:
                self._call_task_errbacks(request, exc, traceback)

    def _call_task_errbacks(self, request, exc, traceback):
        """Invoke the error callbacks attached to *request*.

        Errbacks whose signature accepts ``(request, exc, traceback)``
        are called directly; legacy single-argument errbacks are
        collected and dispatched as a group with only the task id.
        """
        legacy_errbacks = []
        for errback in request.errbacks:
            errback = self.app.signature(errback)
            if not errback._app:
                # Ensure all signatures have an application
                errback._app = self.app
            try:
                if (
                    # Celery tasks type created with the @task decorator have
                    # the __header__ property, but Celery task created from
                    # Task class do not have this property.
                    # That's why we have to check if this property exists
                    # before checking is it partial function.
                    hasattr(errback.type, '__header__') and

                    # workaround to support tasks with bind=True executed as
                    # link errors. Otherwise retries can't be used
                    not isinstance(errback.type.__header__, partial) and
                    arity_greater(errback.type.__header__, 1)
                ):
                    errback(request, exc, traceback)
                else:
                    legacy_errbacks.append(errback)
            except NotRegistered:
                # Task may not be present in this worker.
                # We simply send it forward for another worker to consume.
                # If the task is not registered there, the worker will raise
                # NotRegistered.
                legacy_errbacks.append(errback)

        if legacy_errbacks:
            # Previously errback was called as a task so we still
            # need to do so if the errback only takes a single task_id arg.
            task_id = request.id
            root_id = request.root_id or task_id
            errback_group = group(legacy_errbacks, app=self.app)
            if self.app.conf.task_always_eager or request.delivery_info.get('is_eager', False):
                errback_group.apply(
                    (task_id,), parent_id=task_id, root_id=root_id
                )
            else:
                errback_group.apply_async(
                    (task_id,), parent_id=task_id, root_id=root_id
                )

    def mark_as_revoked(self, task_id, reason='',
                        request=None, store_result=True, state=states.REVOKED):
        """Mark task as revoked, storing a :exc:`TaskRevokedError` result."""
        exc = TaskRevokedError(reason)
        if store_result:
            self.store_result(task_id, exc, state,
                              traceback=None, request=request)
        if request and request.chord:
            self.on_chord_part_return(request, state, exc)

    def mark_as_retry(self, task_id, exc, traceback=None,
                      request=None, store_result=True, state=states.RETRY):
        """Mark task as being retried.

        Note:
            Stores the current exception (if any).
        """
        return self.store_result(task_id, exc, state,
                                 traceback=traceback, request=request)

    def chord_error_from_stack(self, callback, exc=None):
        """Fail the chord *callback*, firing its error callbacks first."""
        # need below import for test for some crazy reason
        from celery import group  # pylint: disable
        app = self.app
        try:
            backend = app._tasks[callback.task].backend
        except KeyError:
            backend = self
        try:
            group(
                [app.signature(errback)
                 for errback in callback.options.get('link_error') or []],
                app=app,
            ).apply_async((callback.id,))
        except Exception as eb_exc:  # pylint: disable=broad-except
            return backend.fail_from_current_stack(callback.id, exc=eb_exc)
        else:
            return backend.fail_from_current_stack(callback.id, exc=exc)

    def fail_from_current_stack(self, task_id, exc=None):
        """Mark *task_id* failed using the exception currently being handled."""
        type_, real_exc, tb = sys.exc_info()
        try:
            exc = real_exc if exc is None else exc
            exception_info = ExceptionInfo((type_, exc, tb))
            self.mark_as_failure(task_id, exc, exception_info.traceback)
            return exception_info
        finally:
            # Break traceback reference cycles so frames are freed promptly.
            if sys.version_info >= (3, 5, 0):
                while tb is not None:
                    try:
                        tb.tb_frame.clear()
                        tb.tb_frame.f_locals
                    except RuntimeError:
                        # Ignore the exception raised if the frame is still
                        # executing.
                        pass
                    tb = tb.tb_next

            elif (2, 7, 0) <= sys.version_info < (3, 0, 0):
                sys.exc_clear()

            del tb

    def prepare_exception(self, exc, serializer=None):
        """Prepare exception for serialization."""
        serializer = self.serializer if serializer is None else serializer
        if serializer in EXCEPTION_ABLE_CODECS:
            # pickle can carry the exception object itself.
            return get_pickleable_exception(exc)
        # Other serializers get a JSON-safe description of the exception.
        exctype = type(exc)
        return {'exc_type': getattr(exctype, '__qualname__', exctype.__name__),
                'exc_message': ensure_serializable(exc.args, self.encode),
                'exc_module': exctype.__module__}

    def exception_to_python(self, exc):
        """Convert serialized exception to Python exception."""
        if exc:
            if not isinstance(exc, BaseException):
                exc_module = exc.get('exc_module')
                if exc_module is None:
                    cls = create_exception_cls(
                        from_utf8(exc['exc_type']), __name__)
                else:
                    exc_module = from_utf8(exc_module)
                    exc_type = from_utf8(exc['exc_type'])
                    try:
                        # Load module and find exception class in that
                        cls = sys.modules[exc_module]
                        # The type can contain qualified name with parent
                        # classes
                        for name in exc_type.split('.'):
                            cls = getattr(cls, name)
                    except (KeyError, AttributeError):
                        cls = create_exception_cls(exc_type,
                                                   celery.exceptions.__name__)
                exc_msg = exc['exc_message']
                try:
                    if isinstance(exc_msg, (tuple, list)):
                        exc = cls(*exc_msg)
                    else:
                        exc = cls(exc_msg)
                except Exception as err:  # noqa
                    # Fall back to a plain Exception when reconstruction fails.
                    exc = Exception('{}({})'.format(cls, exc_msg))
            if self.serializer in EXCEPTION_ABLE_CODECS:
                exc = get_pickled_exception(exc)
        return exc

    def prepare_value(self, result):
        """Prepare value for storage."""
        if self.serializer != 'pickle' and isinstance(result, ResultBase):
            return result.as_tuple()
        return result

    def encode(self, data):
        """Serialize *data* and return only the payload."""
        _, _, payload = self._encode(data)
        return payload

    def _encode(self, data):
        return dumps(data, serializer=self.serializer)

    def meta_from_decoded(self, meta):
        """Rebuild the exception object for failed-state metadata."""
        if meta['status'] in self.EXCEPTION_STATES:
            meta['result'] = self.exception_to_python(meta['result'])
        return meta

    def decode_result(self, payload):
        return self.meta_from_decoded(self.decode(payload))

    def decode(self, payload):
        if payload is None:
            return payload
        payload = PY3 and payload or str(payload)
        return loads(payload,
                     content_type=self.content_type,
                     content_encoding=self.content_encoding,
                     accept=self.accept)

    def prepare_expires(self, value, type=None):
        """Normalize the expiry setting to seconds (optionally coerced)."""
        if value is None:
            value = self.app.conf.result_expires
        if isinstance(value, timedelta):
            value = value.total_seconds()
        if value is not None and type:
            return type(value)
        return value

    def prepare_persistent(self, enabled=None):
        """Resolve persistence flag: argument > config > class default."""
        if enabled is not None:
            return enabled
        persistent = self.app.conf.result_persistent
        return self.persistent if persistent is None else persistent

    def encode_result(self, result, state):
        if state in self.EXCEPTION_STATES and isinstance(result, Exception):
            return self.prepare_exception(result)
        return self.prepare_value(result)

    def is_cached(self, task_id):
        return task_id in self._cache

    def _get_result_meta(self, result,
                         state, traceback, request, format_date=True,
                         encode=False):
        """Build the metadata dict stored for a task result."""
        if state in self.READY_STATES:
            date_done = datetime.utcnow()
            if format_date:
                date_done = date_done.isoformat()
        else:
            date_done = None

        meta = {
            'status': state,
            'result': result,
            'traceback': traceback,
            'children': self.current_task_children(request),
            'date_done': date_done,
        }

        if request and getattr(request, 'group', None):
            meta['group_id'] = request.group
        if request and getattr(request, 'parent_id', None):
            meta['parent_id'] = request.parent_id

        if self.app.conf.find_value_for_key('extended', 'result'):
            if request:
                request_meta = {
                    'name': getattr(request, 'task', None),
                    'args': getattr(request, 'args', None),
                    'kwargs': getattr(request, 'kwargs', None),
                    'worker': getattr(request, 'hostname', None),
                    'retries': getattr(request, 'retries', None),
                    'queue': request.delivery_info.get('routing_key')
                    if hasattr(request, 'delivery_info') and
                    request.delivery_info else None
                }

                if encode:
                    # args and kwargs need to be encoded properly before saving
                    encode_needed_fields = {"args", "kwargs"}
                    for field in encode_needed_fields:
                        value = request_meta[field]
                        encoded_value = self.encode(value)
                        request_meta[field] = ensure_bytes(encoded_value)

                meta.update(request_meta)

        return meta

    def _sleep(self, amount):
        time.sleep(amount)

    def store_result(self, task_id, result, state,
                     traceback=None, request=None, **kwargs):
        """Update task state and result.

        if always_retry_backend_operation is activated, in the event of a
        recoverable exception, then retry operation with an exponential
        backoff until a limit has been reached.
        """
        result = self.encode_result(result, state)

        retries = 0

        while True:
            try:
                self._store_result(task_id, result, state, traceback,
                                   request=request, **kwargs)
                return result
            except Exception as exc:
                if self.always_retry and self.exception_safe_to_retry(exc):
                    if retries < self.max_retries:
                        retries += 1

                        # get_exponential_backoff_interval computes integers
                        # and time.sleep accept floats for sub second sleep
                        sleep_amount = get_exponential_backoff_interval(
                            self.base_sleep_between_retries_ms, retries,
                            self.max_sleep_between_retries_ms, True) / 1000
                        self._sleep(sleep_amount)
                    else:
                        raise_with_context(
                            BackendStoreError("failed to store result on the backend", task_id=task_id, state=state),
                        )
                else:
                    raise

    def forget(self, task_id):
        """Forget a task's result, both cached and stored."""
        self._cache.pop(task_id, None)
        self._forget(task_id)

    def _forget(self, task_id):
        raise NotImplementedError('backend does not implement forget.')

    def get_state(self, task_id):
        """Get the state of a task."""
        return self.get_task_meta(task_id)['status']

    get_status = get_state  # XXX compat

    def get_traceback(self, task_id):
        """Get the traceback for a failed task."""
        return self.get_task_meta(task_id).get('traceback')

    def get_result(self, task_id):
        """Get the result of a task."""
        return self.get_task_meta(task_id).get('result')

    def get_children(self, task_id):
        """Get the list of subtasks sent by a task."""
        try:
            return self.get_task_meta(task_id)['children']
        except KeyError:
            pass

    def _ensure_not_eager(self):
        # Results for eagerly executed tasks never reach the backend,
        # so retrieving them is almost certainly a mistake.
        if self.app.conf.task_always_eager:
            warnings.warn(
                "Shouldn't retrieve result with task_always_eager enabled.",
                RuntimeWarning
            )

    def exception_safe_to_retry(self, exc):
        """Check if an exception is safe to retry.

        Backends have to overload this method with correct predicates
        dealing with their exceptions.

        By default no exception is safe to retry, it's up to backend
        implementation to define which exceptions are safe.
        """
        return False

    def get_task_meta(self, task_id, cache=True):
        """Get task meta from backend.

        if always_retry_backend_operation is activated, in the event of a
        recoverable exception, then retry operation with an exponential
        backoff until a limit has been reached.
        """
        self._ensure_not_eager()
        if cache:
            try:
                return self._cache[task_id]
            except KeyError:
                pass
        retries = 0
        while True:
            try:
                meta = self._get_task_meta_for(task_id)
                break
            except Exception as exc:
                if self.always_retry and self.exception_safe_to_retry(exc):
                    if retries < self.max_retries:
                        retries += 1

                        # get_exponential_backoff_interval computes integers
                        # and time.sleep accept floats for sub second sleep
                        sleep_amount = get_exponential_backoff_interval(
                            self.base_sleep_between_retries_ms, retries,
                            self.max_sleep_between_retries_ms, True) / 1000
                        self._sleep(sleep_amount)
                    else:
                        raise_with_context(
                            BackendGetMetaError("failed to get meta", task_id=task_id),
                        )
                else:
                    raise

        # Only terminal SUCCESS results are safe to cache forever.
        if cache and meta.get('status') == states.SUCCESS:
            self._cache[task_id] = meta
        return meta

    def reload_task_result(self, task_id):
        """Reload task result, even if it has been previously fetched."""
        self._cache[task_id] = self.get_task_meta(task_id, cache=False)

    def reload_group_result(self, group_id):
        """Reload group result, even if it has been previously fetched."""
        self._cache[group_id] = self.get_group_meta(group_id, cache=False)

    def get_group_meta(self, group_id, cache=True):
        self._ensure_not_eager()
        if cache:
            try:
                return self._cache[group_id]
            except KeyError:
                pass

        meta = self._restore_group(group_id)
        if cache and meta is not None:
            self._cache[group_id] = meta
        return meta

    def restore_group(self, group_id, cache=True):
        """Get the result for a group."""
        meta = self.get_group_meta(group_id, cache=cache)
        if meta:
            return meta['result']

    def save_group(self, group_id, result):
        """Store the result of an executed group."""
        return self._save_group(group_id, result)

    def delete_group(self, group_id):
        self._cache.pop(group_id, None)
        return self._delete_group(group_id)

    def cleanup(self):
        """Backend cleanup.

        Note:
            This is run by :class:`celery.task.DeleteExpiredTaskMetaTask`.
        """

    def process_cleanup(self):
        """Cleanup actions to do at the end of a task worker process."""

    def on_task_call(self, producer, task_id):
        return {}

    def add_to_chord(self, chord_id, result):
        raise NotImplementedError('Backend does not support add_to_chord')

    def on_chord_part_return(self, request, state, result, **kwargs):
        pass

    def fallback_chord_unlock(self, header_result, body, countdown=1,
                              **kwargs):
        """Poll for chord completion via the celery.chord_unlock task."""
        kwargs['result'] = [r.as_tuple() for r in header_result]
        queue = body.options.get('queue', getattr(body.type, 'queue', None))
        priority = body.options.get('priority', getattr(body.type, 'priority', 0))
        self.app.tasks['celery.chord_unlock'].apply_async(
            (header_result.id, body,), kwargs,
            countdown=countdown,
            queue=queue,
            priority=priority,
        )

    def ensure_chords_allowed(self):
        pass

    def apply_chord(self, header_result, body, **kwargs):
        self.ensure_chords_allowed()
        self.fallback_chord_unlock(header_result, body, **kwargs)

    def current_task_children(self, request=None):
        request = request or getattr(get_current_task(), 'request', None)
        if request:
            return [r.as_tuple() for r in getattr(request, 'children', [])]

    def __reduce__(self, args=(), kwargs=None):
        # Pickle via unpickle_backend so the backend rebinds to the
        # current app on restore.
        kwargs = {} if not kwargs else kwargs
        return (unpickle_backend, (self.__class__, args, kwargs))
class SyncBackendMixin(object):
    """Mixin providing blocking (synchronous) result retrieval."""

    def iter_native(self, result, timeout=None, interval=0.5, no_ack=True,
                    on_message=None, on_interval=None):
        self._ensure_not_eager()
        results = result.results
        if not results:
            return

        task_ids = set()
        for node in results:
            # Nested result sets are yielded whole; plain results are
            # gathered and fetched in one get_many() call below.
            if isinstance(node, ResultSet):
                yield node.id, node.results
            else:
                task_ids.add(node.id)

        for task_id, meta in self.get_many(
            task_ids,
            timeout=timeout, interval=interval, no_ack=no_ack,
            on_message=on_message, on_interval=on_interval,
        ):
            yield task_id, meta

    def wait_for_pending(self, result, timeout=None, interval=0.5,
                         no_ack=True, on_message=None, on_interval=None,
                         callback=None, propagate=True):
        self._ensure_not_eager()
        if on_message is not None:
            raise ImproperlyConfigured(
                'Backend does not support on_message callback')

        meta = self.wait_for(
            result.id, timeout=timeout,
            interval=interval,
            on_interval=on_interval,
            no_ack=no_ack,
        )
        if meta:
            result._maybe_set_cache(meta)
            return result.maybe_throw(propagate=propagate, callback=callback)

    def wait_for(self, task_id,
                 timeout=None, interval=0.5, no_ack=True, on_interval=None):
        """Wait for task and return its result.

        If the task raises an exception, this exception
        will be re-raised by :func:`wait_for`.

        Raises:
            celery.exceptions.TimeoutError:
                If `timeout` is not :const:`None`, and the operation
                takes longer than `timeout` seconds.
        """
        self._ensure_not_eager()

        time_elapsed = 0.0

        while 1:
            meta = self.get_task_meta(task_id)
            if meta['status'] in states.READY_STATES:
                return meta
            if on_interval:
                on_interval()
            # avoid hammering the CPU checking status.
            time.sleep(interval)
            time_elapsed += interval
            if timeout and time_elapsed >= timeout:
                raise TimeoutError('The operation timed out.')

    def add_pending_result(self, result, weak=False):
        return result

    def remove_pending_result(self, result):
        return result

    @property
    def is_async(self):
        return False


class BaseBackend(Backend, SyncBackendMixin):
    """Base (synchronous) result backend."""


BaseDictBackend = BaseBackend  # noqa: E305 XXX compat


class BaseKeyValueStoreBackend(Backend):
    """Backend base for stores exposing get/mget/set/delete primitives."""

    key_t = ensure_bytes
    task_keyprefix = 'celery-task-meta-'
    group_keyprefix = 'celery-taskset-meta-'
    chord_keyprefix = 'chord-unlock-'
    implements_incr = False

    def __init__(self, *args, **kwargs):
        if hasattr(self.key_t, '__func__'):  # pragma: no cover
            self.key_t = self.key_t.__func__  # remove binding
        self._encode_prefixes()
        super(BaseKeyValueStoreBackend, self).__init__(*args, **kwargs)
        if self.implements_incr:
            # Counter-based chords replace the polling fallback.
            self.apply_chord = self._apply_chord_incr

    def _encode_prefixes(self):
        self.task_keyprefix = self.key_t(self.task_keyprefix)
        self.group_keyprefix = self.key_t(self.group_keyprefix)
        self.chord_keyprefix = self.key_t(self.chord_keyprefix)

    def get(self, key):
        raise NotImplementedError('Must implement the get method.')

    def mget(self, keys):
        raise NotImplementedError('Does not support get_many')

    def _set_with_state(self, key, value, state):
        return self.set(key, value)

    def set(self, key, value):
        raise NotImplementedError('Must implement the set method.')

    def delete(self, key):
        raise NotImplementedError('Must implement the delete method')

    def incr(self, key):
        raise NotImplementedError('Does not implement incr')

    def expire(self, key, value):
        pass

    def get_key_for_task(self, task_id, key=''):
        """Get the cache key for a task by id."""
        key_t = self.key_t
        return key_t('').join([
            self.task_keyprefix, key_t(task_id), key_t(key),
        ])

    def get_key_for_group(self, group_id, key=''):
        """Get the cache key for a group by id."""
        key_t = self.key_t
        return key_t('').join([
            self.group_keyprefix, key_t(group_id), key_t(key),
        ])

    def get_key_for_chord(self, group_id, key=''):
        """Get the cache key for the chord waiting on group with given id."""
        key_t = self.key_t
        return key_t('').join([
            self.chord_keyprefix, key_t(group_id), key_t(key),
        ])

    def _strip_prefix(self, key):
        """Take bytes: emit string."""
        key = self.key_t(key)
        for prefix in self.task_keyprefix, self.group_keyprefix:
            if key.startswith(prefix):
                return bytes_to_str(key[len(prefix):])
        return bytes_to_str(key)

    def _filter_ready(self, values, READY_STATES=states.READY_STATES):
        # Decode only non-missing entries, and keep those in a ready state.
        for k, value in values:
            if value is not None:
                value = self.decode_result(value)
                if value['status'] in READY_STATES:
                    yield k, value

    def _mget_to_results(self, values, keys, READY_STATES=states.READY_STATES):
        if hasattr(values, 'items'):
            # client returns dict so mapping preserved.
            return {
                self._strip_prefix(k): v
                for k, v in self._filter_ready(items(values), READY_STATES)
            }
        else:
            # client returns list so need to recreate mapping.
            return {
                bytes_to_str(keys[i]): v
                for i, v in self._filter_ready(enumerate(values), READY_STATES)
            }

    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
                 on_message=None, on_interval=None, max_iterations=None,
                 READY_STATES=states.READY_STATES):
        interval = 0.5 if interval is None else interval
        ids = task_ids if isinstance(task_ids, set) else set(task_ids)
        cached_ids = set()
        cache = self._cache
        # Serve ready results from the local cache first.
        for task_id in ids:
            try:
                cached = cache[task_id]
            except KeyError:
                pass
            else:
                if cached['status'] in READY_STATES:
                    yield bytes_to_str(task_id), cached
                    cached_ids.add(task_id)

        ids.difference_update(cached_ids)
        iterations = 0
        # Poll the store until all requested ids are ready (or limits hit).
        while ids:
            keys = list(ids)
            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                 for k in keys]),
                                      keys, READY_STATES)
            cache.update(r)
            ids.difference_update({bytes_to_str(v) for v in r})
            for key, value in items(r):
                if on_message is not None:
                    on_message(value)
                yield bytes_to_str(key), value
            if timeout and iterations * interval >= timeout:
                raise TimeoutError('Operation timed out ({0})'.format(timeout))
            if on_interval:
                on_interval()
            time.sleep(interval)  # don't busy loop.
            iterations += 1
            if max_iterations and iterations >= max_iterations:
                break

    def _forget(self, task_id):
        self.delete(self.get_key_for_task(task_id))

    def _store_result(self, task_id, result, state,
                      traceback=None, request=None, **kwargs):
        meta = self._get_result_meta(result=result, state=state,
                                     traceback=traceback, request=request)
        meta['task_id'] = bytes_to_str(task_id)

        # Retrieve metadata from the backend, if the status
        # is a success then we ignore any following update to the state.
        # This solves a task deduplication issue because of network
        # partitioning or lost workers. This issue involved a race condition
        # making a lost task overwrite the last successful result in the
        # result backend.
        current_meta = self._get_task_meta_for(task_id)

        if current_meta['status'] == states.SUCCESS:
            return result

        self._set_with_state(self.get_key_for_task(task_id),
                             self.encode(meta), state)
        return result

    def _save_group(self, group_id, result):
        self._set_with_state(self.get_key_for_group(group_id),
                             self.encode({'result': result.as_tuple()}),
                             states.SUCCESS)
        return result

    def _delete_group(self, group_id):
        self.delete(self.get_key_for_group(group_id))

    def _get_task_meta_for(self, task_id):
        """Get task meta-data for a task by id."""
        meta = self.get(self.get_key_for_task(task_id))
        if not meta:
            return {'status': states.PENDING, 'result': None}
        return self.decode_result(meta)

    def _restore_group(self, group_id):
        """Get task meta-data for a task by id."""
        meta = self.get(self.get_key_for_group(group_id))
        # previously this was always pickled, but later this
        # was extended to support other serializers, so the
        # structure is kind of weird.
        if meta:
            meta = self.decode(meta)
            result = meta['result']
            meta['result'] = result_from_tuple(result, self.app)
            return meta

    def _apply_chord_incr(self, header_result, body, **kwargs):
        self.ensure_chords_allowed()
        header_result.save(backend=self)

    def on_chord_part_return(self, request, state, result, **kwargs):
        """Count a finished chord part; fire the callback when all done."""
        if not self.implements_incr:
            return
        app = self.app
        gid = request.group
        if not gid:
            return
        key = self.get_key_for_chord(gid)
        try:
            deps = GroupResult.restore(gid, backend=self)
        except Exception as exc:  # pylint: disable=broad-except
            callback = maybe_signature(request.chord, app=app)
            logger.exception('Chord %r raised: %r', gid, exc)
            return self.chord_error_from_stack(
                callback,
                ChordError('Cannot restore group: {0!r}'.format(exc)),
            )
        if deps is None:
            try:
                raise ValueError(gid)
            except ValueError as exc:
                callback = maybe_signature(request.chord, app=app)
                logger.exception('Chord callback %r raised: %r', gid, exc)
                return self.chord_error_from_stack(
                    callback,
                    ChordError('GroupResult {0} no longer exists'.format(gid)),
                )
        val = self.incr(key)
        size = len(deps)
        if val > size:  # pragma: no cover
            logger.warning('Chord counter incremented too many times for %r',
                           gid)
        elif val == size:
            # Last part arrived: join the group and invoke the callback.
            callback = maybe_signature(request.chord, app=app)
            j = deps.join_native if deps.supports_native_join else deps.join
            try:
                with allow_join_result():
                    ret = j(timeout=3.0, propagate=True)
            except Exception as exc:  # pylint: disable=broad-except
                try:
                    culprit = next(deps._failed_join_report())
                    reason = 'Dependency {0.id} raised {1!r}'.format(
                        culprit, exc,
                    )
                except StopIteration:
                    reason = repr(exc)

                logger.exception('Chord %r raised: %r', gid, reason)
                self.chord_error_from_stack(callback, ChordError(reason))
            else:
                try:
                    callback.delay(ret)
                except Exception as exc:  # pylint: disable=broad-except
                    logger.exception('Chord %r raised: %r', gid, exc)
                    self.chord_error_from_stack(
                        callback,
                        ChordError('Callback error: {0!r}'.format(exc)),
                    )
            finally:
                deps.delete()
                self.client.delete(key)
        else:
            self.expire(key, self.expires)


class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin):
    """Result backend base class for key/value stores."""


class DisabledBackend(BaseBackend):
    """Dummy result backend."""

    _cache = {}  # need this attribute to reset cache in tests.

    def store_result(self, *args, **kwargs):
        pass

    def ensure_chords_allowed(self):
        raise NotImplementedError(E_CHORD_NO_BACKEND.strip())

    def _is_disabled(self, *args, **kwargs):
        raise NotImplementedError(E_NO_BACKEND.strip())

    def as_uri(self, *args, **kwargs):
        return 'disabled://'

    get_state = get_status = get_result = get_traceback = _is_disabled
    get_task_meta_for = wait_for = get_many = _is_disabled