1# -*- coding: utf-8 -*- 2"""Redis result store backend.""" 3from __future__ import absolute_import, unicode_literals 4 5import time 6from contextlib import contextmanager 7from functools import partial 8from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED 9 10from kombu.utils.functional import retry_over_time 11from kombu.utils.objects import cached_property 12from kombu.utils.url import _parse_url 13 14from celery import states 15from celery._state import task_join_will_block 16from celery.canvas import maybe_signature 17from celery.exceptions import ChordError, ImproperlyConfigured 18from celery.five import string_t, text_t 19from celery.utils import deprecated 20from celery.utils.functional import dictfilter 21from celery.utils.log import get_logger 22from celery.utils.time import humanize_seconds 23 24from .asynchronous import AsyncBackendMixin, BaseResultConsumer 25from .base import BaseKeyValueStoreBackend 26 27try: 28 from urllib.parse import unquote 29except ImportError: 30 # Python 2 31 from urlparse import unquote 32 33try: 34 import redis.connection 35 from kombu.transport.redis import get_redis_error_classes 36except ImportError: # pragma: no cover 37 redis = None # noqa 38 get_redis_error_classes = None # noqa 39 40try: 41 import redis.sentinel 42except ImportError: 43 pass 44 45__all__ = ('RedisBackend', 'SentinelBackend') 46 47E_REDIS_MISSING = """ 48You need to install the redis library in order to use \ 49the Redis result store backend. 50""" 51 52E_REDIS_SENTINEL_MISSING = """ 53You need to install the redis library with support of \ 54sentinel in order to use the Redis result store backend. 55""" 56 57W_REDIS_SSL_CERT_OPTIONAL = """ 58Setting ssl_cert_reqs=CERT_OPTIONAL when connecting to redis means that \ 59celery might not valdate the identity of the redis broker when connecting. \ 60This leaves you vulnerable to man in the middle attacks. 61""" 62 63W_REDIS_SSL_CERT_NONE = """ 64Setting ssl_cert_reqs=CERT_NONE when connecting to redis means that celery \ 65will not valdate the identity of the redis broker when connecting. This \ 66leaves you vulnerable to man in the middle attacks. 67""" 68 69E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH = """ 70SSL connection parameters have been provided but the specified URL scheme \ 71is redis://. A Redis SSL connection URL should use the scheme rediss://. 72""" 73 74E_REDIS_SSL_CERT_REQS_MISSING_INVALID = """ 75A rediss:// URL must have parameter ssl_cert_reqs and this must be set to \ 76CERT_REQUIRED, CERT_OPTIONAL, or CERT_NONE 77""" 78 79E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.' 80 81E_RETRY_LIMIT_EXCEEDED = """ 82Retry limit exceeded while trying to reconnect to the Celery redis result \ 83store backend. The Celery application must be restarted. 84""" 85 86logger = get_logger(__name__) 87 88 89class ResultConsumer(BaseResultConsumer): 90 _pubsub = None 91 92 def __init__(self, *args, **kwargs): 93 super(ResultConsumer, self).__init__(*args, **kwargs) 94 self._get_key_for_task = self.backend.get_key_for_task 95 self._decode_result = self.backend.decode_result 96 self._ensure = self.backend.ensure 97 self._connection_errors = self.backend.connection_errors 98 self.subscribed_to = set() 99 100 def on_after_fork(self): 101 try: 102 self.backend.client.connection_pool.reset() 103 if self._pubsub is not None: 104 self._pubsub.close() 105 except KeyError as e: 106 logger.warning(text_t(e)) 107 super(ResultConsumer, self).on_after_fork() 108 109 def _reconnect_pubsub(self): 110 self._pubsub = None 111 self.backend.client.connection_pool.reset() 112 # task state might have changed when the connection was down so we 113 # retrieve meta for all subscribed tasks before going into pubsub mode 114 metas = self.backend.client.mget(self.subscribed_to) 115 metas = [meta for meta in metas if meta] 116 for meta in metas: 117 self.on_state_change(self._decode_result(meta), None) 118 self._pubsub = self.backend.client.pubsub( 119 ignore_subscribe_messages=True, 120 ) 121 self._pubsub.subscribe(*self.subscribed_to) 122 123 @contextmanager 124 def reconnect_on_error(self): 125 try: 126 yield 127 except self._connection_errors: 128 try: 129 self._ensure(self._reconnect_pubsub, ()) 130 except self._connection_errors: 131 logger.critical(E_RETRY_LIMIT_EXCEEDED) 132 raise 133 134 def _maybe_cancel_ready_task(self, meta): 135 if meta['status'] in states.READY_STATES: 136 self.cancel_for(meta['task_id']) 137 138 def on_state_change(self, meta, message): 139 super(ResultConsumer, self).on_state_change(meta, message) 140 self._maybe_cancel_ready_task(meta) 141 142 def start(self, initial_task_id, **kwargs): 143 self._pubsub = self.backend.client.pubsub( 144 ignore_subscribe_messages=True, 145 ) 146 self._consume_from(initial_task_id) 147 148 def on_wait_for_pending(self, result, **kwargs): 149 for meta in result._iter_meta(**kwargs): 150 if meta is not None: 151 self.on_state_change(meta, None) 152 153 def stop(self): 154 if self._pubsub is not None: 155 self._pubsub.close() 156 157 def drain_events(self, timeout=None): 158 if self._pubsub: 159 with self.reconnect_on_error(): 160 message = self._pubsub.get_message(timeout=timeout) 161 if message and message['type'] == 'message': 162 self.on_state_change(self._decode_result(message['data']), message) 163 elif timeout: 164 time.sleep(timeout) 165 166 def consume_from(self, task_id): 167 if self._pubsub is None: 168 return self.start(task_id) 169 self._consume_from(task_id) 170 171 def _consume_from(self, task_id): 172 key = self._get_key_for_task(task_id) 173 if key not in self.subscribed_to: 174 self.subscribed_to.add(key) 175 with self.reconnect_on_error(): 176 self._pubsub.subscribe(key) 177 178 def cancel_for(self, task_id): 179 key = self._get_key_for_task(task_id) 180 self.subscribed_to.discard(key) 181 if self._pubsub: 182 with self.reconnect_on_error(): 183 self._pubsub.unsubscribe(key) 184 185 186class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin): 187 """Redis task result store. 188 189 It makes use of the following commands: 190 GET, MGET, DEL, INCRBY, EXPIRE, SET, SETEX 191 """ 192 193 ResultConsumer = ResultConsumer 194 195 #: :pypi:`redis` client module. 196 redis = redis 197 198 #: Maximum number of connections in the pool. 199 max_connections = None 200 201 supports_autoexpire = True 202 supports_native_join = True 203 204 def __init__(self, host=None, port=None, db=None, password=None, 205 max_connections=None, url=None, 206 connection_pool=None, **kwargs): 207 super(RedisBackend, self).__init__(expires_type=int, **kwargs) 208 _get = self.app.conf.get 209 if self.redis is None: 210 raise ImproperlyConfigured(E_REDIS_MISSING.strip()) 211 212 if host and '://' in host: 213 url, host = host, None 214 215 self.max_connections = ( 216 max_connections or 217 _get('redis_max_connections') or 218 self.max_connections) 219 self._ConnectionPool = connection_pool 220 221 socket_timeout = _get('redis_socket_timeout') 222 socket_connect_timeout = _get('redis_socket_connect_timeout') 223 retry_on_timeout = _get('redis_retry_on_timeout') 224 socket_keepalive = _get('redis_socket_keepalive') 225 226 self.connparams = { 227 'host': _get('redis_host') or 'localhost', 228 'port': _get('redis_port') or 6379, 229 'db': _get('redis_db') or 0, 230 'password': _get('redis_password'), 231 'max_connections': self.max_connections, 232 'socket_timeout': socket_timeout and float(socket_timeout), 233 'retry_on_timeout': retry_on_timeout or False, 234 'socket_connect_timeout': 235 socket_connect_timeout and float(socket_connect_timeout), 236 } 237 238 # absent in redis.connection.UnixDomainSocketConnection 239 if socket_keepalive: 240 self.connparams['socket_keepalive'] = socket_keepalive 241 242 # "redis_backend_use_ssl" must be a dict with the keys: 243 # 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile' 244 # (the same as "broker_use_ssl") 245 ssl = _get('redis_backend_use_ssl') 246 if ssl: 247 self.connparams.update(ssl) 248 self.connparams['connection_class'] = redis.SSLConnection 249 250 if url: 251 self.connparams = self._params_from_url(url, self.connparams) 252 253 # If we've received SSL parameters via query string or the 254 # redis_backend_use_ssl dict, check ssl_cert_reqs is valid. If set 255 # via query string ssl_cert_reqs will be a string so convert it here 256 if ('connection_class' in self.connparams and 257 self.connparams['connection_class'] is redis.SSLConnection): 258 ssl_cert_reqs_missing = 'MISSING' 259 ssl_string_to_constant = {'CERT_REQUIRED': CERT_REQUIRED, 260 'CERT_OPTIONAL': CERT_OPTIONAL, 261 'CERT_NONE': CERT_NONE, 262 'required': CERT_REQUIRED, 263 'optional': CERT_OPTIONAL, 264 'none': CERT_NONE} 265 ssl_cert_reqs = self.connparams.get('ssl_cert_reqs', ssl_cert_reqs_missing) 266 ssl_cert_reqs = ssl_string_to_constant.get(ssl_cert_reqs, ssl_cert_reqs) 267 if ssl_cert_reqs not in ssl_string_to_constant.values(): 268 raise ValueError(E_REDIS_SSL_CERT_REQS_MISSING_INVALID) 269 270 if ssl_cert_reqs == CERT_OPTIONAL: 271 logger.warning(W_REDIS_SSL_CERT_OPTIONAL) 272 elif ssl_cert_reqs == CERT_NONE: 273 logger.warning(W_REDIS_SSL_CERT_NONE) 274 self.connparams['ssl_cert_reqs'] = ssl_cert_reqs 275 276 self.url = url 277 278 self.connection_errors, self.channel_errors = ( 279 get_redis_error_classes() if get_redis_error_classes 280 else ((), ())) 281 self.result_consumer = self.ResultConsumer( 282 self, self.app, self.accept, 283 self._pending_results, self._pending_messages, 284 ) 285 286 def _params_from_url(self, url, defaults): 287 scheme, host, port, _, password, path, query = _parse_url(url) 288 connparams = dict( 289 defaults, **dictfilter({ 290 'host': host, 'port': port, 'password': password, 291 'db': query.pop('virtual_host', None)}) 292 ) 293 294 if scheme == 'socket': 295 # use 'path' as path to the socket… in this case 296 # the database number should be given in 'query' 297 connparams.update({ 298 'connection_class': self.redis.UnixDomainSocketConnection, 299 'path': '/' + path, 300 }) 301 # host+port are invalid options when using this connection type. 302 connparams.pop('host', None) 303 connparams.pop('port', None) 304 connparams.pop('socket_connect_timeout') 305 else: 306 connparams['db'] = path 307 308 ssl_param_keys = ['ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile', 309 'ssl_cert_reqs'] 310 311 if scheme == 'redis': 312 # If connparams or query string contain ssl params, raise error 313 if (any(key in connparams for key in ssl_param_keys) or 314 any(key in query for key in ssl_param_keys)): 315 raise ValueError(E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH) 316 317 if scheme == 'rediss': 318 connparams['connection_class'] = redis.SSLConnection 319 # The following parameters, if present in the URL, are encoded. We 320 # must add the decoded values to connparams. 321 for ssl_setting in ssl_param_keys: 322 ssl_val = query.pop(ssl_setting, None) 323 if ssl_val: 324 connparams[ssl_setting] = unquote(ssl_val) 325 326 # db may be string and start with / like in kombu. 327 db = connparams.get('db') or 0 328 db = db.strip('/') if isinstance(db, string_t) else db 329 connparams['db'] = int(db) 330 331 for key, value in query.items(): 332 if key in redis.connection.URL_QUERY_ARGUMENT_PARSERS: 333 query[key] = redis.connection.URL_QUERY_ARGUMENT_PARSERS[key]( 334 value 335 ) 336 337 # Query parameters override other parameters 338 connparams.update(query) 339 return connparams 340 341 def on_task_call(self, producer, task_id): 342 if not task_join_will_block(): 343 self.result_consumer.consume_from(task_id) 344 345 def get(self, key): 346 return self.client.get(key) 347 348 def mget(self, keys): 349 return self.client.mget(keys) 350 351 def ensure(self, fun, args, **policy): 352 retry_policy = dict(self.retry_policy, **policy) 353 max_retries = retry_policy.get('max_retries') 354 return retry_over_time( 355 fun, self.connection_errors, args, {}, 356 partial(self.on_connection_error, max_retries), 357 **retry_policy) 358 359 def on_connection_error(self, max_retries, exc, intervals, retries): 360 tts = next(intervals) 361 logger.error( 362 E_LOST.strip(), 363 retries, max_retries or 'Inf', humanize_seconds(tts, 'in ')) 364 return tts 365 366 def set(self, key, value, **retry_policy): 367 return self.ensure(self._set, (key, value), **retry_policy) 368 369 def _set(self, key, value): 370 with self.client.pipeline() as pipe: 371 if self.expires: 372 pipe.setex(key, self.expires, value) 373 else: 374 pipe.set(key, value) 375 pipe.publish(key, value) 376 pipe.execute() 377 378 def forget(self, task_id): 379 super(RedisBackend, self).forget(task_id) 380 self.result_consumer.cancel_for(task_id) 381 382 def delete(self, key): 383 self.client.delete(key) 384 385 def incr(self, key): 386 return self.client.incr(key) 387 388 def expire(self, key, value): 389 return self.client.expire(key, value) 390 391 def add_to_chord(self, group_id, result): 392 self.client.incr(self.get_key_for_group(group_id, '.t'), 1) 393 394 def _unpack_chord_result(self, tup, decode, 395 EXCEPTION_STATES=states.EXCEPTION_STATES, 396 PROPAGATE_STATES=states.PROPAGATE_STATES): 397 _, tid, state, retval = decode(tup) 398 if state in EXCEPTION_STATES: 399 retval = self.exception_to_python(retval) 400 if state in PROPAGATE_STATES: 401 raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval)) 402 return retval 403 404 def apply_chord(self, header_result, body, **kwargs): 405 # Overrides this to avoid calling GroupResult.save 406 # pylint: disable=method-hidden 407 # Note that KeyValueStoreBackend.__init__ sets self.apply_chord 408 # if the implements_incr attr is set. Redis backend doesn't set 409 # this flag. 410 pass 411 412 @cached_property 413 def _chord_zset(self): 414 transport_options = self.app.conf.get( 415 'result_backend_transport_options', {} 416 ) 417 return transport_options.get('result_chord_ordered', False) 418 419 def on_chord_part_return(self, request, state, result, 420 propagate=None, **kwargs): 421 app = self.app 422 tid, gid, group_index = request.id, request.group, request.group_index 423 if not gid or not tid: 424 return 425 if group_index is None: 426 group_index = '+inf' 427 428 client = self.client 429 jkey = self.get_key_for_group(gid, '.j') 430 tkey = self.get_key_for_group(gid, '.t') 431 result = self.encode_result(result, state) 432 with client.pipeline() as pipe: 433 if self._chord_zset: 434 pipeline = (pipe 435 .zadd(jkey, { 436 self.encode([1, tid, state, result]): group_index 437 }) 438 .zcount(jkey, '-inf', '+inf') 439 ) 440 else: 441 pipeline = (pipe 442 .rpush(jkey, self.encode([1, tid, state, result])) 443 .llen(jkey) 444 ) 445 pipeline = pipeline.get(tkey) 446 447 if self.expires is not None: 448 pipeline = pipeline \ 449 .expire(jkey, self.expires) \ 450 .expire(tkey, self.expires) 451 452 _, readycount, totaldiff = pipeline.execute()[:3] 453 454 totaldiff = int(totaldiff or 0) 455 456 try: 457 callback = maybe_signature(request.chord, app=app) 458 total = callback['chord_size'] + totaldiff 459 if readycount == total: 460 decode, unpack = self.decode, self._unpack_chord_result 461 with client.pipeline() as pipe: 462 if self._chord_zset: 463 pipeline = pipe.zrange(jkey, 0, -1) 464 else: 465 pipeline = pipe.lrange(jkey, 0, total) 466 resl, = pipeline.execute() 467 try: 468 callback.delay([unpack(tup, decode) for tup in resl]) 469 with client.pipeline() as pipe: 470 _, _ = pipe \ 471 .delete(jkey) \ 472 .delete(tkey) \ 473 .execute() 474 except Exception as exc: # pylint: disable=broad-except 475 logger.exception( 476 'Chord callback for %r raised: %r', request.group, exc) 477 return self.chord_error_from_stack( 478 callback, 479 ChordError('Callback error: {0!r}'.format(exc)), 480 ) 481 except ChordError as exc: 482 logger.exception('Chord %r raised: %r', request.group, exc) 483 return self.chord_error_from_stack(callback, exc) 484 except Exception as exc: # pylint: disable=broad-except 485 logger.exception('Chord %r raised: %r', request.group, exc) 486 return self.chord_error_from_stack( 487 callback, 488 ChordError('Join error: {0!r}'.format(exc)), 489 ) 490 491 def _create_client(self, **params): 492 return self._get_client()( 493 connection_pool=self._get_pool(**params), 494 ) 495 496 def _get_client(self): 497 return self.redis.StrictRedis 498 499 def _get_pool(self, **params): 500 return self.ConnectionPool(**params) 501 502 @property 503 def ConnectionPool(self): 504 if self._ConnectionPool is None: 505 self._ConnectionPool = self.redis.ConnectionPool 506 return self._ConnectionPool 507 508 @cached_property 509 def client(self): 510 return self._create_client(**self.connparams) 511 512 def __reduce__(self, args=(), kwargs=None): 513 kwargs = {} if not kwargs else kwargs 514 return super(RedisBackend, self).__reduce__( 515 (self.url,), {'expires': self.expires}, 516 ) 517 518 @deprecated.Property(4.0, 5.0) 519 def host(self): 520 return self.connparams['host'] 521 522 @deprecated.Property(4.0, 5.0) 523 def port(self): 524 return self.connparams['port'] 525 526 @deprecated.Property(4.0, 5.0) 527 def db(self): 528 return self.connparams['db'] 529 530 @deprecated.Property(4.0, 5.0) 531 def password(self): 532 return self.connparams['password'] 533 534 535class SentinelBackend(RedisBackend): 536 """Redis sentinel task result store.""" 537 538 sentinel = getattr(redis, "sentinel", None) 539 540 def __init__(self, *args, **kwargs): 541 if self.sentinel is None: 542 raise ImproperlyConfigured(E_REDIS_SENTINEL_MISSING.strip()) 543 544 super(SentinelBackend, self).__init__(*args, **kwargs) 545 546 def _params_from_url(self, url, defaults): 547 # URL looks like sentinel://0.0.0.0:26347/3;sentinel://0.0.0.0:26348/3. 548 chunks = url.split(";") 549 connparams = dict(defaults, hosts=[]) 550 for chunk in chunks: 551 data = super(SentinelBackend, self)._params_from_url( 552 url=chunk, defaults=defaults) 553 connparams['hosts'].append(data) 554 for param in ("host", "port", "db", "password"): 555 connparams.pop(param) 556 557 # Adding db/password in connparams to connect to the correct instance 558 for param in ("db", "password"): 559 if connparams['hosts'] and param in connparams['hosts'][0]: 560 connparams[param] = connparams['hosts'][0].get(param) 561 return connparams 562 563 def _get_sentinel_instance(self, **params): 564 connparams = params.copy() 565 566 hosts = connparams.pop("hosts") 567 result_backend_transport_opts = self.app.conf.get( 568 "result_backend_transport_options", {}) 569 min_other_sentinels = result_backend_transport_opts.get( 570 "min_other_sentinels", 0) 571 sentinel_kwargs = result_backend_transport_opts.get( 572 "sentinel_kwargs", {}) 573 574 sentinel_instance = self.sentinel.Sentinel( 575 [(cp['host'], cp['port']) for cp in hosts], 576 min_other_sentinels=min_other_sentinels, 577 sentinel_kwargs=sentinel_kwargs, 578 **connparams) 579 580 return sentinel_instance 581 582 def _get_pool(self, **params): 583 sentinel_instance = self._get_sentinel_instance(**params) 584 585 result_backend_transport_opts = self.app.conf.get( 586 "result_backend_transport_options", {}) 587 master_name = result_backend_transport_opts.get("master_name", None) 588 589 return sentinel_instance.master_for( 590 service_name=master_name, 591 redis_class=self._get_client(), 592 ).connection_pool 593