# -*- coding: utf-8 -*-
"""Redis result store backend."""
from __future__ import absolute_import, unicode_literals

import time
from contextlib import contextmanager
from functools import partial
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED

from kombu.utils.functional import retry_over_time
from kombu.utils.objects import cached_property
from kombu.utils.url import _parse_url

from celery import states
from celery._state import task_join_will_block
from celery.canvas import maybe_signature
from celery.exceptions import ChordError, ImproperlyConfigured
from celery.five import string_t, text_t
from celery.utils import deprecated
from celery.utils.functional import dictfilter
from celery.utils.log import get_logger
from celery.utils.time import humanize_seconds

from .asynchronous import AsyncBackendMixin, BaseResultConsumer
from .base import BaseKeyValueStoreBackend

try:
    from urllib.parse import unquote
except ImportError:
    # Python 2
    from urlparse import unquote

try:
    import redis.connection
    from kombu.transport.redis import get_redis_error_classes
except ImportError:  # pragma: no cover
    redis = None  # noqa
    get_redis_error_classes = None  # noqa

try:
    import redis.sentinel
except ImportError:
    pass

__all__ = ('RedisBackend', 'SentinelBackend')

E_REDIS_MISSING = """
You need to install the redis library in order to use \
the Redis result store backend.
"""

E_REDIS_SENTINEL_MISSING = """
You need to install the redis library with sentinel support in order \
to use the Redis result store backend.
"""

W_REDIS_SSL_CERT_OPTIONAL = """
Setting ssl_cert_reqs=CERT_OPTIONAL when connecting to redis means that \
celery might not validate the identity of the redis broker when connecting. \
This leaves you vulnerable to man in the middle attacks.
"""

W_REDIS_SSL_CERT_NONE = """
Setting ssl_cert_reqs=CERT_NONE when connecting to redis means that celery \
will not validate the identity of the redis broker when connecting. This \
leaves you vulnerable to man in the middle attacks.
"""

E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH = """
SSL connection parameters have been provided but the specified URL scheme \
is redis://. A Redis SSL connection URL should use the scheme rediss://.
"""

E_REDIS_SSL_CERT_REQS_MISSING_INVALID = """
A rediss:// URL must have the parameter ssl_cert_reqs set to \
CERT_REQUIRED, CERT_OPTIONAL, or CERT_NONE.
"""

E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.'

E_RETRY_LIMIT_EXCEEDED = """
Retry limit exceeded while trying to reconnect to the Celery redis result \
store backend. The Celery application must be restarted.
"""

logger = get_logger(__name__)


class ResultConsumer(BaseResultConsumer):
    """Consumes task state changes published over Redis Pub/Sub."""

    _pubsub = None

    def __init__(self, *args, **kwargs):
        super(ResultConsumer, self).__init__(*args, **kwargs)
        self._get_key_for_task = self.backend.get_key_for_task
        self._decode_result = self.backend.decode_result
        self._ensure = self.backend.ensure
        self._connection_errors = self.backend.connection_errors
        self.subscribed_to = set()

    def on_after_fork(self):
        try:
            self.backend.client.connection_pool.reset()
            if self._pubsub is not None:
                self._pubsub.close()
        except KeyError as e:
            logger.warning(text_t(e))
        super(ResultConsumer, self).on_after_fork()

    def _reconnect_pubsub(self):
        self._pubsub = None
        self.backend.client.connection_pool.reset()
        # Task state might have changed when the connection was down, so we
        # retrieve meta for all subscribed tasks before going back into
        # pubsub mode.
        metas = self.backend.client.mget(self.subscribed_to)
        metas = [meta for meta in metas if meta]
        for meta in metas:
            self.on_state_change(self._decode_result(meta), None)
        self._pubsub = self.backend.client.pubsub(
            ignore_subscribe_messages=True,
        )
        self._pubsub.subscribe(*self.subscribed_to)

    @contextmanager
    def reconnect_on_error(self):
        try:
            yield
        except self._connection_errors:
            try:
                self._ensure(self._reconnect_pubsub, ())
            except self._connection_errors:
                logger.critical(E_RETRY_LIMIT_EXCEEDED)
                raise

    def _maybe_cancel_ready_task(self, meta):
        if meta['status'] in states.READY_STATES:
            self.cancel_for(meta['task_id'])

    def on_state_change(self, meta, message):
        super(ResultConsumer, self).on_state_change(meta, message)
        self._maybe_cancel_ready_task(meta)

    def start(self, initial_task_id, **kwargs):
        self._pubsub = self.backend.client.pubsub(
            ignore_subscribe_messages=True,
        )
        self._consume_from(initial_task_id)

    def on_wait_for_pending(self, result, **kwargs):
        for meta in result._iter_meta(**kwargs):
            if meta is not None:
                self.on_state_change(meta, None)

    def stop(self):
        if self._pubsub is not None:
            self._pubsub.close()

    def drain_events(self, timeout=None):
        if self._pubsub:
            with self.reconnect_on_error():
                message = self._pubsub.get_message(timeout=timeout)
                if message and message['type'] == 'message':
                    self.on_state_change(self._decode_result(message['data']), message)
        elif timeout:
            time.sleep(timeout)

    def consume_from(self, task_id):
        if self._pubsub is None:
            return self.start(task_id)
        self._consume_from(task_id)

    def _consume_from(self, task_id):
        key = self._get_key_for_task(task_id)
        if key not in self.subscribed_to:
            self.subscribed_to.add(key)
            with self.reconnect_on_error():
                self._pubsub.subscribe(key)

    def cancel_for(self, task_id):
        key = self._get_key_for_task(task_id)
        self.subscribed_to.discard(key)
        if self._pubsub:
            with self.reconnect_on_error():
                self._pubsub.unsubscribe(key)


class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
    """Redis task result store.

    It makes use of the following commands:
    GET, MGET, DEL, INCRBY, EXPIRE, SET, SETEX
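
    Example configuration (illustrative; any supported Redis URL works,
    and ``rediss://`` may be used for TLS)::

        app.conf.result_backend = 'redis://localhost:6379/0'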
    """

    ResultConsumer = ResultConsumer

    #: :pypi:`redis` client module.
    redis = redis

    #: Maximum number of connections in the pool.
    max_connections = None

    supports_autoexpire = True
    supports_native_join = True

    def __init__(self, host=None, port=None, db=None, password=None,
                 max_connections=None, url=None,
                 connection_pool=None, **kwargs):
        super(RedisBackend, self).__init__(expires_type=int, **kwargs)
        _get = self.app.conf.get
        if self.redis is None:
            raise ImproperlyConfigured(E_REDIS_MISSING.strip())

        if host and '://' in host:
            url, host = host, None

        self.max_connections = (
            max_connections or
            _get('redis_max_connections') or
            self.max_connections)
        self._ConnectionPool = connection_pool

        socket_timeout = _get('redis_socket_timeout')
        socket_connect_timeout = _get('redis_socket_connect_timeout')
        retry_on_timeout = _get('redis_retry_on_timeout')
        socket_keepalive = _get('redis_socket_keepalive')

        self.connparams = {
            'host': _get('redis_host') or 'localhost',
            'port': _get('redis_port') or 6379,
            'db': _get('redis_db') or 0,
            'password': _get('redis_password'),
            'max_connections': self.max_connections,
            'socket_timeout': socket_timeout and float(socket_timeout),
            'retry_on_timeout': retry_on_timeout or False,
            'socket_connect_timeout':
                socket_connect_timeout and float(socket_connect_timeout),
        }

        # socket_keepalive is not supported by
        # redis.connection.UnixDomainSocketConnection, so only pass it on
        # when it's explicitly enabled.
        if socket_keepalive:
            self.connparams['socket_keepalive'] = socket_keepalive

        # "redis_backend_use_ssl" must be a dict with the keys:
        # 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile'
        # (the same as "broker_use_ssl")
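        # A minimal illustrative example (the paths are placeholders, not
        # defaults):
        #
        #     app.conf.redis_backend_use_ssl = {
        #         'ssl_cert_reqs': ssl.CERT_REQUIRED,
        #         'ssl_ca_certs': '/path/to/ca.pem',
        #         'ssl_certfile': '/path/to/client.crt',
        #         'ssl_keyfile': '/path/to/client.key',
        #     }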
        ssl = _get('redis_backend_use_ssl')
        if ssl:
            self.connparams.update(ssl)
            self.connparams['connection_class'] = redis.SSLConnection

        if url:
            self.connparams = self._params_from_url(url, self.connparams)

        # If we've received SSL parameters via query string or the
        # redis_backend_use_ssl dict, check that ssl_cert_reqs is valid.
        # If set via query string, ssl_cert_reqs will be a string, so it is
        # converted here.
        if ('connection_class' in self.connparams and
                self.connparams['connection_class'] is redis.SSLConnection):
            ssl_cert_reqs_missing = 'MISSING'
            ssl_string_to_constant = {'CERT_REQUIRED': CERT_REQUIRED,
                                      'CERT_OPTIONAL': CERT_OPTIONAL,
                                      'CERT_NONE': CERT_NONE,
                                      'required': CERT_REQUIRED,
                                      'optional': CERT_OPTIONAL,
                                      'none': CERT_NONE}
            ssl_cert_reqs = self.connparams.get('ssl_cert_reqs', ssl_cert_reqs_missing)
            ssl_cert_reqs = ssl_string_to_constant.get(ssl_cert_reqs, ssl_cert_reqs)
            if ssl_cert_reqs not in ssl_string_to_constant.values():
                raise ValueError(E_REDIS_SSL_CERT_REQS_MISSING_INVALID)

            if ssl_cert_reqs == CERT_OPTIONAL:
                logger.warning(W_REDIS_SSL_CERT_OPTIONAL)
            elif ssl_cert_reqs == CERT_NONE:
                logger.warning(W_REDIS_SSL_CERT_NONE)
            self.connparams['ssl_cert_reqs'] = ssl_cert_reqs

        self.url = url

        self.connection_errors, self.channel_errors = (
            get_redis_error_classes() if get_redis_error_classes
            else ((), ()))
        self.result_consumer = self.ResultConsumer(
            self, self.app, self.accept,
            self._pending_results, self._pending_messages,
        )

    def _params_from_url(self, url, defaults):
        scheme, host, port, _, password, path, query = _parse_url(url)
        connparams = dict(
            defaults, **dictfilter({
                'host': host, 'port': port, 'password': password,
                'db': query.pop('virtual_host', None)})
        )

        if scheme == 'socket':
            # Use 'path' as the path to the socket; in this case the
            # database number should be given in 'query'.
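            # An illustrative URL (the socket path is a placeholder):
            #
            #     socket:///var/run/redis/redis.sock?virtual_host=1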
            connparams.update({
                'connection_class': self.redis.UnixDomainSocketConnection,
                'path': '/' + path,
            })
            # host+port are invalid options when using this connection type.
            connparams.pop('host', None)
            connparams.pop('port', None)
            connparams.pop('socket_connect_timeout')
        else:
            connparams['db'] = path

        ssl_param_keys = ['ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile',
                          'ssl_cert_reqs']

        if scheme == 'redis':
            # If connparams or query string contain ssl params, raise error
            if (any(key in connparams for key in ssl_param_keys) or
                    any(key in query for key in ssl_param_keys)):
                raise ValueError(E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH)

        if scheme == 'rediss':
            connparams['connection_class'] = redis.SSLConnection
            # The following parameters, if present in the URL, are encoded. We
            # must add the decoded values to connparams.
            for ssl_setting in ssl_param_keys:
                ssl_val = query.pop(ssl_setting, None)
                if ssl_val:
                    connparams[ssl_setting] = unquote(ssl_val)

        # db may be a string and start with '/', as in kombu.
        db = connparams.get('db') or 0
        db = db.strip('/') if isinstance(db, string_t) else db
        connparams['db'] = int(db)

        for key, value in query.items():
            if key in redis.connection.URL_QUERY_ARGUMENT_PARSERS:
                query[key] = redis.connection.URL_QUERY_ARGUMENT_PARSERS[key](
                    value
                )

        # Query parameters override other parameters
        connparams.update(query)
        return connparams

    def on_task_call(self, producer, task_id):
        if not task_join_will_block():
            self.result_consumer.consume_from(task_id)

    def get(self, key):
        return self.client.get(key)

    def mget(self, keys):
        return self.client.mget(keys)

    def ensure(self, fun, args, **policy):
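        """Call ``fun`` with retries, using the backend's retry policy.

        The policy keys are the keyword arguments accepted by kombu's
        :func:`~kombu.utils.functional.retry_over_time`, e.g. ``max_retries``,
        ``interval_start``, ``interval_step`` and ``interval_max``.
        """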
        retry_policy = dict(self.retry_policy, **policy)
        max_retries = retry_policy.get('max_retries')
        return retry_over_time(
            fun, self.connection_errors, args, {},
            partial(self.on_connection_error, max_retries),
            **retry_policy)

    def on_connection_error(self, max_retries, exc, intervals, retries):
        tts = next(intervals)
        logger.error(
            E_LOST.strip(),
            retries, max_retries or 'Inf', humanize_seconds(tts, 'in '))
        return tts

    def set(self, key, value, **retry_policy):
        return self.ensure(self._set, (key, value), **retry_policy)

    def _set(self, key, value):
        with self.client.pipeline() as pipe:
            if self.expires:
                pipe.setex(key, self.expires, value)
            else:
                pipe.set(key, value)
            # Publish the result so that subscribed ResultConsumers are
            # notified immediately.
            pipe.publish(key, value)
            pipe.execute()

    def forget(self, task_id):
        super(RedisBackend, self).forget(task_id)
        self.result_consumer.cancel_for(task_id)

    def delete(self, key):
        self.client.delete(key)

    def incr(self, key):
        return self.client.incr(key)

    def expire(self, key, value):
        return self.client.expire(key, value)

    def add_to_chord(self, group_id, result):
        self.client.incr(self.get_key_for_group(group_id, '.t'), 1)

    def _unpack_chord_result(self, tup, decode,
                             EXCEPTION_STATES=states.EXCEPTION_STATES,
                             PROPAGATE_STATES=states.PROPAGATE_STATES):
        _, tid, state, retval = decode(tup)
        if state in EXCEPTION_STATES:
            retval = self.exception_to_python(retval)
        if state in PROPAGATE_STATES:
            raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval))
        return retval

    def apply_chord(self, header_result, body, **kwargs):
        # This is overridden to avoid calling GroupResult.save.
        # pylint: disable=method-hidden
        # Note that KeyValueStoreBackend.__init__ sets self.apply_chord
        # if the implements_incr attr is set; the Redis backend doesn't set
        # this flag.
        pass

    @cached_property
    def _chord_zset(self):
        """Whether chord results are stored in a sorted set rather than a
        list (enabled by the ``result_chord_ordered`` transport option)."""
        transport_options = self.app.conf.get(
            'result_backend_transport_options', {}
        )
        return transport_options.get('result_chord_ordered', False)

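    # How on_chord_part_return works (a summary of the code below): each
    # finished header task appends its encoded result to the group's ".j" key
    # (a sorted set when result_chord_ordered is enabled, a list otherwise)
    # and reads the ".t" counter maintained by add_to_chord.  Once the number
    # of stored results reaches the chord size (adjusted by ".t"), all results
    # are fetched, the body callback is applied, and both keys are deleted.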
    def on_chord_part_return(self, request, state, result,
                             propagate=None, **kwargs):
        app = self.app
        tid, gid, group_index = request.id, request.group, request.group_index
        if not gid or not tid:
            return
        if group_index is None:
            group_index = '+inf'

        client = self.client
        jkey = self.get_key_for_group(gid, '.j')
        tkey = self.get_key_for_group(gid, '.t')
        result = self.encode_result(result, state)
        with client.pipeline() as pipe:
            if self._chord_zset:
                pipeline = (pipe
                    .zadd(jkey, {
                        self.encode([1, tid, state, result]): group_index
                    })
                    .zcount(jkey, '-inf', '+inf')
                )
            else:
                pipeline = (pipe
                    .rpush(jkey, self.encode([1, tid, state, result]))
                    .llen(jkey)
                )
            pipeline = pipeline.get(tkey)

            if self.expires is not None:
                pipeline = pipeline \
                    .expire(jkey, self.expires) \
                    .expire(tkey, self.expires)

            _, readycount, totaldiff = pipeline.execute()[:3]

        totaldiff = int(totaldiff or 0)

        try:
            callback = maybe_signature(request.chord, app=app)
            total = callback['chord_size'] + totaldiff
            if readycount == total:
                decode, unpack = self.decode, self._unpack_chord_result
                with client.pipeline() as pipe:
                    if self._chord_zset:
                        pipeline = pipe.zrange(jkey, 0, -1)
                    else:
                        pipeline = pipe.lrange(jkey, 0, total)
                    resl, = pipeline.execute()
                try:
                    callback.delay([unpack(tup, decode) for tup in resl])
                    with client.pipeline() as pipe:
                        _, _ = pipe \
                            .delete(jkey) \
                            .delete(tkey) \
                            .execute()
                except Exception as exc:  # pylint: disable=broad-except
                    logger.exception(
                        'Chord callback for %r raised: %r', request.group, exc)
                    return self.chord_error_from_stack(
                        callback,
                        ChordError('Callback error: {0!r}'.format(exc)),
                    )
        except ChordError as exc:
            logger.exception('Chord %r raised: %r', request.group, exc)
            return self.chord_error_from_stack(callback, exc)
        except Exception as exc:  # pylint: disable=broad-except
            logger.exception('Chord %r raised: %r', request.group, exc)
            return self.chord_error_from_stack(
                callback,
                ChordError('Join error: {0!r}'.format(exc)),
            )

    def _create_client(self, **params):
        return self._get_client()(
            connection_pool=self._get_pool(**params),
        )

    def _get_client(self):
        return self.redis.StrictRedis

    def _get_pool(self, **params):
        return self.ConnectionPool(**params)

    @property
    def ConnectionPool(self):
        if self._ConnectionPool is None:
            self._ConnectionPool = self.redis.ConnectionPool
        return self._ConnectionPool

    @cached_property
    def client(self):
        return self._create_client(**self.connparams)

    def __reduce__(self, args=(), kwargs=None):
        kwargs = {} if not kwargs else kwargs
        return super(RedisBackend, self).__reduce__(
            (self.url,), {'expires': self.expires},
        )

    @deprecated.Property(4.0, 5.0)
    def host(self):
        return self.connparams['host']

    @deprecated.Property(4.0, 5.0)
    def port(self):
        return self.connparams['port']

    @deprecated.Property(4.0, 5.0)
    def db(self):
        return self.connparams['db']

    @deprecated.Property(4.0, 5.0)
    def password(self):
        return self.connparams['password']


class SentinelBackend(RedisBackend):
    """Redis sentinel task result store."""
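    # Illustrative configuration (assumes two sentinels guarding a master
    # named "mymaster"):
    #
    #     app.conf.result_backend = (
    #         'sentinel://localhost:26379/0;sentinel://localhost:26380/0')
    #     app.conf.result_backend_transport_options = {
    #         'master_name': 'mymaster',
    #     }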

    sentinel = getattr(redis, "sentinel", None)

    def __init__(self, *args, **kwargs):
        if self.sentinel is None:
            raise ImproperlyConfigured(E_REDIS_SENTINEL_MISSING.strip())

        super(SentinelBackend, self).__init__(*args, **kwargs)

    def _params_from_url(self, url, defaults):
        # URL looks like sentinel://0.0.0.0:26347/3;sentinel://0.0.0.0:26348/3.
        chunks = url.split(";")
        connparams = dict(defaults, hosts=[])
        for chunk in chunks:
            data = super(SentinelBackend, self)._params_from_url(
                url=chunk, defaults=defaults)
            connparams['hosts'].append(data)
        for param in ("host", "port", "db", "password"):
            connparams.pop(param)

        # Add the db/password of the first host to connparams so we connect
        # to the correct instance.
        for param in ("db", "password"):
            if connparams['hosts'] and param in connparams['hosts'][0]:
                connparams[param] = connparams['hosts'][0].get(param)
        return connparams

    def _get_sentinel_instance(self, **params):
        connparams = params.copy()

        hosts = connparams.pop("hosts")
        result_backend_transport_opts = self.app.conf.get(
            "result_backend_transport_options", {})
        min_other_sentinels = result_backend_transport_opts.get(
            "min_other_sentinels", 0)
        sentinel_kwargs = result_backend_transport_opts.get(
            "sentinel_kwargs", {})

        sentinel_instance = self.sentinel.Sentinel(
            [(cp['host'], cp['port']) for cp in hosts],
            min_other_sentinels=min_other_sentinels,
            sentinel_kwargs=sentinel_kwargs,
            **connparams)

        return sentinel_instance

    def _get_pool(self, **params):
        sentinel_instance = self._get_sentinel_instance(**params)

        result_backend_transport_opts = self.app.conf.get(
            "result_backend_transport_options", {})
        master_name = result_backend_transport_opts.get("master_name", None)

        return sentinel_instance.master_for(
            service_name=master_name,
            redis_class=self._get_client(),
        ).connection_pool