1from __future__ import absolute_import, unicode_literals
2
3import json
4import random
5import ssl
6from contextlib import contextmanager
7from datetime import timedelta
8from pickle import dumps, loads
9
10import pytest
11from case import ANY, ContextMock, Mock, call, mock, patch, skip
12
13from celery import signature, states, uuid
14from celery.canvas import Signature
15from celery.exceptions import (ChordError, CPendingDeprecationWarning,
16                               ImproperlyConfigured)
17from celery.utils.collections import AttributeDict
18
19
def raise_on_second_call(mock, exc, *retval):
    """Let *mock* succeed exactly once, then raise *exc* on later calls.

    :param mock: mock object whose ``side_effect`` is replaced.
    :param exc: value installed as the new ``side_effect`` once the first
        call has happened (anything ``side_effect`` accepts to raise).
    :param retval: optionally a single value stored as
        ``mock.return_value``; the first call returns whatever
        ``mock.return_value`` holds at that moment.
    """
    def first_call_then_raise(*args, **kwargs):
        # Re-arm the mock so every subsequent call raises ``exc``.
        mock.side_effect = exc
        return mock.return_value

    mock.side_effect = first_call_then_raise
    if retval:
        # Exactly one value is expected; passing more raises ValueError.
        (mock.return_value,) = retval
28
29
class ConnectionError(Exception):
    """Simulated redis connection failure used by the consumer tests.

    NOTE: this name shadows the builtin ``ConnectionError`` inside this
    module; it is a plain ``Exception`` subclass, not an ``OSError``.
    """
32
33
class Connection(object):
    """Fake redis connection that only tracks whether it is connected."""

    # Class-level default; ``disconnect`` shadows it on the instance.
    connected = True

    def disconnect(self):
        """Flag this connection as no longer connected."""
        self.connected = False
39
40
class Pipeline(object):
    """Fake redis pipeline that queues client calls until ``execute``.

    Accessing any attribute that is not defined here returns a recorder.
    When called, the recorder looks up the same-named method on the wrapped
    client and queues it with its arguments. It then returns the pipeline,
    so calls can be chained.
    """

    def __init__(self, client):
        self.client = client
        self.steps = []

    def __getattr__(self, attr):
        def _enqueue(*args, **kwargs):
            method = getattr(self.client, attr)
            self.steps.append((method, args, kwargs))
            return self
        return _enqueue

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Falsy return: exceptions raised inside the ``with`` propagate.
        return None

    def execute(self):
        """Run every queued step in order and return their results.

        The queue is not cleared afterwards.
        """
        results = []
        for method, args, kwargs in self.steps:
            results.append(method(*args, **kwargs))
        return results
61
62
class PubSub(mock.MockCallbacks):
    """Fake redis ``PubSub`` that only tracks the subscribed channel set."""

    def __init__(self, ignore_subscribe_messages=False):
        # ``ignore_subscribe_messages`` is accepted for signature
        # compatibility only; it has no effect here.
        self._subscribed_to = set()

    def close(self):
        """Drop every subscription by installing a fresh, empty set."""
        self._subscribed_to = set()

    def subscribe(self, *args):
        """Add each channel in *args* to the subscribed set."""
        self._subscribed_to |= set(args)

    def unsubscribe(self, *args):
        """Remove each channel in *args* from the subscribed set."""
        self._subscribed_to -= set(args)

    def get_message(self, timeout=None):
        """Never yields a message; always returns None."""
        return None
78
79
class Redis(mock.MockCallbacks):
    """In-memory fake of the subset of the redis client API used here.

    Plain values live in ``keyspace``. TTLs passed to ``expire``/``setex``
    are only recorded in ``expiry`` and are never enforced. Lists are stored
    as Python lists in push order. Sorted sets are stored as lists of
    ``(score, member)`` tuples, re-sorted after every ``zadd``.
    """
    Connection = Connection
    Pipeline = Pipeline
    pubsub = PubSub

    def __init__(self, host=None, port=None, db=None, password=None, **kw):
        # Extra keyword arguments (e.g. ``connection_pool``) are ignored.
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.keyspace = {}
        self.expiry = {}
        self.connection = self.Connection()

    def get(self, key):
        return self.keyspace.get(key)

    def mget(self, keys):
        return [self.get(key) for key in keys]

    def setex(self, key, expires, value):
        # Delegates to ``set`` and ``expire``.
        self.set(key, value)
        self.expire(key, expires)

    def set(self, key, value):
        self.keyspace[key] = value

    def expire(self, key, expires):
        self.expiry[key] = expires
        return expires

    def delete(self, key):
        return bool(self.keyspace.pop(key, None))

    def pipeline(self):
        return self.Pipeline(self)

    def _get_unsorted_list(self, key):
        # We simply store the values in append (rpush) order
        return self.keyspace.setdefault(key, list())

    def rpush(self, key, value):
        self._get_unsorted_list(key).append(value)

    def lrange(self, key, start, stop):
        # As in ``zrange``: `stop` is inclusive in Redis, so use `stop + 1`
        # unless that would turn -1 (the last element) into 0.
        stop = stop + 1 if stop != -1 else None
        return self._get_unsorted_list(key)[start:stop]

    def llen(self, key):
        return len(self._get_unsorted_list(key))

    def _get_sorted_set(self, key):
        # We store 2-tuples of (score, value) and sort after each append (zadd)
        return self.keyspace.setdefault(key, list())

    def zadd(self, key, mapping):
        # Store elements as 2-tuples with the score first so we can sort it
        # once the new items have been inserted
        fake_sorted_set = self._get_sorted_set(key)
        fake_sorted_set.extend(
            (score, value) for value, score in mapping.items()
        )
        fake_sorted_set.sort()

    def zrange(self, key, start, stop):
        # `stop` is inclusive in Redis so we use `stop + 1` unless that would
        # cause us to move from negative (right-most) indices to positive
        stop = stop + 1 if stop != -1 else None
        return [e[1] for e in self._get_sorted_set(key)[start:stop]]

    def zrangebyscore(self, key, min_, max_):
        # Both bounds are compared with the *score* (first tuple item);
        # the literals "-inf" / "+inf" disable the respective bound.
        return [
            member for score, member in self._get_sorted_set(key)
            if (min_ == "-inf" or score >= min_) and
            (max_ == "+inf" or score <= max_)
        ]

    def zcount(self, key, min_, max_):
        # Delegates to ``zrangebyscore`` via the instance.
        return len(self.zrangebyscore(key, min_, max_))
158
159
class Sentinel(mock.MockCallbacks):
    """Fake ``redis.sentinel.Sentinel`` built from fake :class:`Redis` nodes."""

    def __init__(self, sentinels, min_other_sentinels=0, sentinel_kwargs=None,
                 **connection_kwargs):
        self.sentinel_kwargs = sentinel_kwargs
        # ``sentinel_kwargs`` defaults to None, which cannot be expanded with
        # ``**``; treat it as "no extra keyword arguments".
        node_kwargs = sentinel_kwargs or {}
        self.sentinels = [Redis(hostname, port, **node_kwargs)
                          for hostname, port in sentinels]
        self.min_other_sentinels = min_other_sentinels
        self.connection_kwargs = connection_kwargs

    def master_for(self, service_name, redis_class):
        """Return one of the fake sentinel nodes, chosen at random.

        ``service_name`` and ``redis_class`` are accepted for API
        compatibility but are not used.
        """
        return random.choice(self.sentinels)
171
172
class redis(object):
    """Stand-in for the ``redis`` module as seen by the backend under test."""

    StrictRedis = Redis

    class ConnectionPool(object):
        """No-op connection pool; accepts and discards keyword arguments."""

        def __init__(self, **kwargs):
            """Ignore every keyword argument."""

    class UnixDomainSocketConnection(object):
        """Placeholder connection class; tests compare it by identity."""

        def __init__(self, **kwargs):
            """Ignore every keyword argument."""
183
184
class sentinel(object):
    """Stand-in for the ``redis.sentinel`` module exposing the fake Sentinel."""

    Sentinel = Sentinel
187
188
class test_RedisResultConsumer:
    """Tests for the pub/sub result consumer of the redis result backend."""

    def get_backend(self):
        """Return a ``RedisBackend`` instance wired to the fake ``redis``."""
        from celery.backends.redis import RedisBackend

        class _RedisBackend(RedisBackend):
            redis = redis

        return _RedisBackend(app=self.app)

    def get_consumer(self):
        """Return the backend's result consumer.

        The consumer treats the module-local fake ``ConnectionError`` as a
        connection error.
        """
        consumer = self.get_backend().result_consumer
        consumer._connection_errors = (ConnectionError,)
        return consumer

    @patch('celery.backends.asynchronous.BaseResultConsumer.on_after_fork')
    def test_on_after_fork(self, parent_method):
        consumer = self.get_consumer()
        consumer.start('none')
        consumer.on_after_fork()
        parent_method.assert_called_once()
        consumer.backend.client.connection_pool.reset.assert_called_once()
        consumer._pubsub.close.assert_called_once()

        # PubSub instance not initialized - exception would be raised
        # when calling .close()
        consumer._pubsub = None
        parent_method.reset_mock()
        consumer.backend.client.connection_pool.reset.reset_mock()
        consumer.on_after_fork()
        parent_method.assert_called_once()
        consumer.backend.client.connection_pool.reset.assert_called_once()

        # Continues on KeyError: the pool is still reset, close() is still
        # attempted, and the parent hook still runs.
        failing_pubsub = consumer._pubsub = Mock()
        failing_pubsub.close = Mock(side_effect=KeyError)
        parent_method.reset_mock()
        consumer.backend.client.connection_pool.reset.reset_mock()
        consumer.on_after_fork()
        parent_method.assert_called_once()
        consumer.backend.client.connection_pool.reset.assert_called_once()
        failing_pubsub.close.assert_called_once()

    @patch('celery.backends.redis.ResultConsumer.cancel_for')
    @patch('celery.backends.asynchronous.BaseResultConsumer.on_state_change')
    def test_on_state_change(self, parent_method, cancel_for):
        consumer = self.get_consumer()
        meta = {'task_id': 'testing', 'status': states.SUCCESS}
        message = 'hello'
        consumer.on_state_change(meta, message)
        parent_method.assert_called_once_with(meta, message)
        cancel_for.assert_called_once_with(meta['task_id'])

        # Does not call cancel_for for other states
        meta = {'task_id': 'testing2', 'status': states.PENDING}
        parent_method.reset_mock()
        cancel_for.reset_mock()
        consumer.on_state_change(meta, message)
        parent_method.assert_called_once_with(meta, message)
        cancel_for.assert_not_called()

    def test_drain_events_before_start(self):
        consumer = self.get_consumer()
        # drain_events shouldn't crash when called before start
        consumer.drain_events(0.001)

    def test_consume_from_connection_error(self):
        consumer = self.get_consumer()
        consumer.start('initial')
        consumer._pubsub.subscribe.side_effect = (ConnectionError(), None)
        consumer.consume_from('some-task')
        assert consumer._pubsub._subscribed_to == {
            b'celery-task-meta-initial',
            b'celery-task-meta-some-task',
        }

    def test_cancel_for_connection_error(self):
        consumer = self.get_consumer()
        consumer.start('initial')
        consumer._pubsub.unsubscribe.side_effect = ConnectionError()
        consumer.consume_from('some-task')
        consumer.cancel_for('some-task')
        assert consumer._pubsub._subscribed_to == {
            b'celery-task-meta-initial',
        }

    @patch('celery.backends.redis.ResultConsumer.cancel_for')
    @patch('celery.backends.asynchronous.BaseResultConsumer.on_state_change')
    def test_drain_events_connection_error(self, parent_on_state_change,
                                           cancel_for):
        meta = {'task_id': 'initial', 'status': states.SUCCESS}
        consumer = self.get_consumer()
        consumer.start('initial')
        consumer.backend._set_with_state(
            b'celery-task-meta-initial', json.dumps(meta), states.SUCCESS,
        )
        consumer._pubsub.get_message.side_effect = ConnectionError()
        consumer.drain_events()
        parent_on_state_change.assert_called_with(meta, None)
        assert consumer._pubsub._subscribed_to == {
            b'celery-task-meta-initial',
        }
277
278
279class test_RedisBackend:
280    def get_backend(self):
281        from celery.backends.redis import RedisBackend
282
283        class _RedisBackend(RedisBackend):
284            redis = redis
285
286        return _RedisBackend
287
288    def get_E_LOST(self):
289        from celery.backends.redis import E_LOST
290        return E_LOST
291
292    def setup(self):
293        self.Backend = self.get_backend()
294        self.E_LOST = self.get_E_LOST()
295        self.b = self.Backend(app=self.app)
296
297    @pytest.mark.usefixtures('depends_on_current_app')
298    @skip.unless_module('redis')
299    def test_reduce(self):
300        from celery.backends.redis import RedisBackend
301        x = RedisBackend(app=self.app)
302        assert loads(dumps(x))
303
304    def test_no_redis(self):
305        self.Backend.redis = None
306        with pytest.raises(ImproperlyConfigured):
307            self.Backend(app=self.app)
308
309    def test_url(self):
310        self.app.conf.redis_socket_timeout = 30.0
311        self.app.conf.redis_socket_connect_timeout = 100.0
312        x = self.Backend(
313            'redis://:bosco@vandelay.com:123//1', app=self.app,
314        )
315        assert x.connparams
316        assert x.connparams['host'] == 'vandelay.com'
317        assert x.connparams['db'] == 1
318        assert x.connparams['port'] == 123
319        assert x.connparams['password'] == 'bosco'
320        assert x.connparams['socket_timeout'] == 30.0
321        assert x.connparams['socket_connect_timeout'] == 100.0
322
323    @skip.unless_module('redis')
324    def test_timeouts_in_url_coerced(self):
325        x = self.Backend(
326            ('redis://:bosco@vandelay.com:123//1?'
327             'socket_timeout=30&socket_connect_timeout=100'),
328            app=self.app,
329        )
330        assert x.connparams
331        assert x.connparams['host'] == 'vandelay.com'
332        assert x.connparams['db'] == 1
333        assert x.connparams['port'] == 123
334        assert x.connparams['password'] == 'bosco'
335        assert x.connparams['socket_timeout'] == 30
336        assert x.connparams['socket_connect_timeout'] == 100
337
338    @skip.unless_module('redis')
339    def test_socket_url(self):
340        self.app.conf.redis_socket_timeout = 30.0
341        self.app.conf.redis_socket_connect_timeout = 100.0
342        x = self.Backend(
343            'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
344        )
345        assert x.connparams
346        assert x.connparams['path'] == '/tmp/redis.sock'
347        assert (x.connparams['connection_class'] is
348                redis.UnixDomainSocketConnection)
349        assert 'host' not in x.connparams
350        assert 'port' not in x.connparams
351        assert x.connparams['socket_timeout'] == 30.0
352        assert 'socket_connect_timeout' not in x.connparams
353        assert 'socket_keepalive' not in x.connparams
354        assert x.connparams['db'] == 3
355
356    @skip.unless_module('redis')
357    def test_backend_ssl(self):
358        self.app.conf.redis_backend_use_ssl = {
359            'ssl_cert_reqs': ssl.CERT_REQUIRED,
360            'ssl_ca_certs': '/path/to/ca.crt',
361            'ssl_certfile': '/path/to/client.crt',
362            'ssl_keyfile': '/path/to/client.key',
363        }
364        self.app.conf.redis_socket_timeout = 30.0
365        self.app.conf.redis_socket_connect_timeout = 100.0
366        x = self.Backend(
367            'rediss://:bosco@vandelay.com:123//1', app=self.app,
368        )
369        assert x.connparams
370        assert x.connparams['host'] == 'vandelay.com'
371        assert x.connparams['db'] == 1
372        assert x.connparams['port'] == 123
373        assert x.connparams['password'] == 'bosco'
374        assert x.connparams['socket_timeout'] == 30.0
375        assert x.connparams['socket_connect_timeout'] == 100.0
376        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
377        assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'
378        assert x.connparams['ssl_certfile'] == '/path/to/client.crt'
379        assert x.connparams['ssl_keyfile'] == '/path/to/client.key'
380
381        from redis.connection import SSLConnection
382        assert x.connparams['connection_class'] is SSLConnection
383
384    @skip.unless_module('redis')
385    @pytest.mark.parametrize('cert_str', [
386        "required",
387        "CERT_REQUIRED",
388    ])
389    def test_backend_ssl_certreq_str(self, cert_str):
390        self.app.conf.redis_backend_use_ssl = {
391            'ssl_cert_reqs': cert_str,
392            'ssl_ca_certs': '/path/to/ca.crt',
393            'ssl_certfile': '/path/to/client.crt',
394            'ssl_keyfile': '/path/to/client.key',
395        }
396        self.app.conf.redis_socket_timeout = 30.0
397        self.app.conf.redis_socket_connect_timeout = 100.0
398        x = self.Backend(
399            'rediss://:bosco@vandelay.com:123//1', app=self.app,
400        )
401        assert x.connparams
402        assert x.connparams['host'] == 'vandelay.com'
403        assert x.connparams['db'] == 1
404        assert x.connparams['port'] == 123
405        assert x.connparams['password'] == 'bosco'
406        assert x.connparams['socket_timeout'] == 30.0
407        assert x.connparams['socket_connect_timeout'] == 100.0
408        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
409        assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'
410        assert x.connparams['ssl_certfile'] == '/path/to/client.crt'
411        assert x.connparams['ssl_keyfile'] == '/path/to/client.key'
412
413        from redis.connection import SSLConnection
414        assert x.connparams['connection_class'] is SSLConnection
415
416    @skip.unless_module('redis')
417    @pytest.mark.parametrize('cert_str', [
418        "required",
419        "CERT_REQUIRED",
420    ])
421    def test_backend_ssl_url(self, cert_str):
422        self.app.conf.redis_socket_timeout = 30.0
423        self.app.conf.redis_socket_connect_timeout = 100.0
424        x = self.Backend(
425            'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=%s' % cert_str,
426            app=self.app,
427        )
428        assert x.connparams
429        assert x.connparams['host'] == 'vandelay.com'
430        assert x.connparams['db'] == 1
431        assert x.connparams['port'] == 123
432        assert x.connparams['password'] == 'bosco'
433        assert x.connparams['socket_timeout'] == 30.0
434        assert x.connparams['socket_connect_timeout'] == 100.0
435        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
436
437        from redis.connection import SSLConnection
438        assert x.connparams['connection_class'] is SSLConnection
439
440    @skip.unless_module('redis')
441    @pytest.mark.parametrize('cert_str', [
442        "none",
443        "CERT_NONE",
444    ])
445    def test_backend_ssl_url_options(self, cert_str):
446        x = self.Backend(
447            (
448                'rediss://:bosco@vandelay.com:123//1'
449                '?ssl_cert_reqs={cert_str}'
450                '&ssl_ca_certs=%2Fvar%2Fssl%2Fmyca.pem'
451                '&ssl_certfile=%2Fvar%2Fssl%2Fredis-server-cert.pem'
452                '&ssl_keyfile=%2Fvar%2Fssl%2Fprivate%2Fworker-key.pem'
453            ).format(cert_str=cert_str),
454            app=self.app,
455        )
456        assert x.connparams
457        assert x.connparams['host'] == 'vandelay.com'
458        assert x.connparams['db'] == 1
459        assert x.connparams['port'] == 123
460        assert x.connparams['password'] == 'bosco'
461        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_NONE
462        assert x.connparams['ssl_ca_certs'] == '/var/ssl/myca.pem'
463        assert x.connparams['ssl_certfile'] == '/var/ssl/redis-server-cert.pem'
464        assert x.connparams['ssl_keyfile'] == '/var/ssl/private/worker-key.pem'
465
466    @skip.unless_module('redis')
467    @pytest.mark.parametrize('cert_str', [
468        "optional",
469        "CERT_OPTIONAL",
470    ])
471    def test_backend_ssl_url_cert_none(self, cert_str):
472        x = self.Backend(
473            'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=%s' % cert_str,
474            app=self.app,
475        )
476        assert x.connparams
477        assert x.connparams['host'] == 'vandelay.com'
478        assert x.connparams['db'] == 1
479        assert x.connparams['port'] == 123
480        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_OPTIONAL
481
482        from redis.connection import SSLConnection
483        assert x.connparams['connection_class'] is SSLConnection
484
485    @skip.unless_module('redis')
486    @pytest.mark.parametrize("uri", [
487        'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=CERT_KITTY_CATS',
488        'rediss://:bosco@vandelay.com:123//1'
489    ])
490    def test_backend_ssl_url_invalid(self, uri):
491        with pytest.raises(ValueError):
492            self.Backend(
493                uri,
494                app=self.app,
495            )
496
497    def test_compat_propertie(self):
498        x = self.Backend(
499            'redis://:bosco@vandelay.com:123//1', app=self.app,
500        )
501        with pytest.warns(CPendingDeprecationWarning):
502            assert x.host == 'vandelay.com'
503        with pytest.warns(CPendingDeprecationWarning):
504            assert x.db == 1
505        with pytest.warns(CPendingDeprecationWarning):
506            assert x.port == 123
507        with pytest.warns(CPendingDeprecationWarning):
508            assert x.password == 'bosco'
509
510    def test_conf_raises_KeyError(self):
511        self.app.conf = AttributeDict({
512            'result_serializer': 'json',
513            'result_cache_max': 1,
514            'result_expires': None,
515            'accept_content': ['json'],
516            'result_accept_content': ['json'],
517        })
518        self.Backend(app=self.app)
519
520    @patch('celery.backends.redis.logger')
521    def test_on_connection_error(self, logger):
522        intervals = iter([10, 20, 30])
523        exc = KeyError()
524        assert self.b.on_connection_error(None, exc, intervals, 1) == 10
525        logger.error.assert_called_with(
526            self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
527        assert self.b.on_connection_error(10, exc, intervals, 2) == 20
528        logger.error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
529        assert self.b.on_connection_error(10, exc, intervals, 3) == 30
530        logger.error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
531
532    def test_incr(self):
533        self.b.client = Mock(name='client')
534        self.b.incr('foo')
535        self.b.client.incr.assert_called_with('foo')
536
537    def test_expire(self):
538        self.b.client = Mock(name='client')
539        self.b.expire('foo', 300)
540        self.b.client.expire.assert_called_with('foo', 300)
541
542    def test_apply_chord(self, unlock='celery.chord_unlock'):
543        self.app.tasks[unlock] = Mock()
544        header_result = self.app.GroupResult(
545            uuid(),
546            [self.app.AsyncResult(x) for x in range(3)],
547        )
548        self.b.apply_chord(header_result, None)
549        assert self.app.tasks[unlock].apply_async.call_count == 0
550
551    def test_unpack_chord_result(self):
552        self.b.exception_to_python = Mock(name='etp')
553        decode = Mock(name='decode')
554        exc = KeyError()
555        tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
556        with pytest.raises(ChordError):
557            self.b._unpack_chord_result(tup, decode)
558        decode.assert_called_with(tup)
559        self.b.exception_to_python.assert_called_with(exc)
560
561        exc = ValueError()
562        tup = decode.return_value = (2, 'id2', states.RETRY, exc)
563        ret = self.b._unpack_chord_result(tup, decode)
564        self.b.exception_to_python.assert_called_with(exc)
565        assert ret is self.b.exception_to_python()
566
567    def test_on_chord_part_return_no_gid_or_tid(self):
568        request = Mock(name='request')
569        request.id = request.group = request.group_index = None
570        assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None
571
572    def test_ConnectionPool(self):
573        self.b.redis = Mock(name='redis')
574        assert self.b._ConnectionPool is None
575        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
576        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
577
578    def test_expires_defaults_to_config(self):
579        self.app.conf.result_expires = 10
580        b = self.Backend(expires=None, app=self.app)
581        assert b.expires == 10
582
583    def test_expires_is_int(self):
584        b = self.Backend(expires=48, app=self.app)
585        assert b.expires == 48
586
587    def test_add_to_chord(self):
588        b = self.Backend('redis://', app=self.app)
589        gid = uuid()
590        b.add_to_chord(gid, 'sig')
591        b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
592
593    def test_expires_is_None(self):
594        b = self.Backend(expires=None, app=self.app)
595        assert b.expires == self.app.conf.result_expires.total_seconds()
596
597    def test_expires_is_timedelta(self):
598        b = self.Backend(expires=timedelta(minutes=1), app=self.app)
599        assert b.expires == 60
600
601    def test_mget(self):
602        assert self.b.mget(['a', 'b', 'c'])
603        self.b.client.mget.assert_called_with(['a', 'b', 'c'])
604
605    def test_set_no_expire(self):
606        self.b.expires = None
607        self.b._set_with_state('foo', 'bar', states.SUCCESS)
608
609    def create_task(self, i):
610        tid = uuid()
611        task = Mock(name='task-{0}'.format(tid))
612        task.name = 'foobarbaz'
613        self.app.tasks['foobarbaz'] = task
614        task.request.chord = signature(task)
615        task.request.id = tid
616        task.request.chord['chord_size'] = 10
617        task.request.group = 'group_id'
618        task.request.group_index = i
619        return task
620
621    @patch('celery.result.GroupResult.restore')
622    def test_on_chord_part_return(self, restore):
623        tasks = [self.create_task(i) for i in range(10)]
624        random.shuffle(tasks)
625
626        for i in range(10):
627            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
628            assert self.b.client.rpush.call_count
629            self.b.client.rpush.reset_mock()
630        assert self.b.client.lrange.call_count
631        jkey = self.b.get_key_for_group('group_id', '.j')
632        tkey = self.b.get_key_for_group('group_id', '.t')
633        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
634        self.b.client.expire.assert_has_calls([
635            call(jkey, 86400), call(tkey, 86400),
636        ])
637
638    @patch('celery.result.GroupResult.restore')
639    def test_on_chord_part_return__unordered(self, restore):
640        self.app.conf.result_backend_transport_options = dict(
641            result_chord_ordered=False,
642        )
643
644        tasks = [self.create_task(i) for i in range(10)]
645        random.shuffle(tasks)
646
647        for i in range(10):
648            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
649            assert self.b.client.rpush.call_count
650            self.b.client.rpush.reset_mock()
651        assert self.b.client.lrange.call_count
652        jkey = self.b.get_key_for_group('group_id', '.j')
653        tkey = self.b.get_key_for_group('group_id', '.t')
654        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
655        self.b.client.expire.assert_has_calls([
656            call(jkey, 86400), call(tkey, 86400),
657        ])
658
659    @patch('celery.result.GroupResult.restore')
660    def test_on_chord_part_return__ordered(self, restore):
661        self.app.conf.result_backend_transport_options = dict(
662            result_chord_ordered=True,
663        )
664
665        tasks = [self.create_task(i) for i in range(10)]
666        random.shuffle(tasks)
667
668        for i in range(10):
669            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
670            assert self.b.client.zadd.call_count
671            self.b.client.zadd.reset_mock()
672        assert self.b.client.zrangebyscore.call_count
673        jkey = self.b.get_key_for_group('group_id', '.j')
674        tkey = self.b.get_key_for_group('group_id', '.t')
675        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
676        self.b.client.expire.assert_has_calls([
677            call(jkey, 86400), call(tkey, 86400),
678        ])
679
680    @patch('celery.result.GroupResult.restore')
681    def test_on_chord_part_return_no_expiry(self, restore):
682        old_expires = self.b.expires
683        self.b.expires = None
684        tasks = [self.create_task(i) for i in range(10)]
685
686        for i in range(10):
687            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
688            assert self.b.client.rpush.call_count
689            self.b.client.rpush.reset_mock()
690        assert self.b.client.lrange.call_count
691        jkey = self.b.get_key_for_group('group_id', '.j')
692        tkey = self.b.get_key_for_group('group_id', '.t')
693        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
694        self.b.client.expire.assert_not_called()
695
696        self.b.expires = old_expires
697
698    @patch('celery.result.GroupResult.restore')
699    def test_on_chord_part_return_no_expiry__unordered(self, restore):
700        self.app.conf.result_backend_transport_options = dict(
701            result_chord_ordered=False,
702        )
703
704        old_expires = self.b.expires
705        self.b.expires = None
706        tasks = [self.create_task(i) for i in range(10)]
707
708        for i in range(10):
709            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
710            assert self.b.client.rpush.call_count
711            self.b.client.rpush.reset_mock()
712        assert self.b.client.lrange.call_count
713        jkey = self.b.get_key_for_group('group_id', '.j')
714        tkey = self.b.get_key_for_group('group_id', '.t')
715        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
716        self.b.client.expire.assert_not_called()
717
718        self.b.expires = old_expires
719
720    @patch('celery.result.GroupResult.restore')
721    def test_on_chord_part_return_no_expiry__ordered(self, restore):
722        self.app.conf.result_backend_transport_options = dict(
723            result_chord_ordered=True,
724        )
725
726        old_expires = self.b.expires
727        self.b.expires = None
728        tasks = [self.create_task(i) for i in range(10)]
729
730        for i in range(10):
731            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
732            assert self.b.client.zadd.call_count
733            self.b.client.zadd.reset_mock()
734        assert self.b.client.zrangebyscore.call_count
735        jkey = self.b.get_key_for_group('group_id', '.j')
736        tkey = self.b.get_key_for_group('group_id', '.t')
737        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
738        self.b.client.expire.assert_not_called()
739
740        self.b.expires = old_expires
741
742    def test_on_chord_part_return__success(self):
743        with self.chord_context(2) as (_, request, callback):
744            self.b.on_chord_part_return(request, states.SUCCESS, 10)
745            callback.delay.assert_not_called()
746            self.b.on_chord_part_return(request, states.SUCCESS, 20)
747            callback.delay.assert_called_with([10, 20])
748
749    def test_on_chord_part_return__success__unordered(self):
750        self.app.conf.result_backend_transport_options = dict(
751            result_chord_ordered=False,
752        )
753
754        with self.chord_context(2) as (_, request, callback):
755            self.b.on_chord_part_return(request, states.SUCCESS, 10)
756            callback.delay.assert_not_called()
757            self.b.on_chord_part_return(request, states.SUCCESS, 20)
758            callback.delay.assert_called_with([10, 20])
759
760    def test_on_chord_part_return__success__ordered(self):
761        self.app.conf.result_backend_transport_options = dict(
762            result_chord_ordered=True,
763        )
764
765        with self.chord_context(2) as (_, request, callback):
766            self.b.on_chord_part_return(request, states.SUCCESS, 10)
767            callback.delay.assert_not_called()
768            self.b.on_chord_part_return(request, states.SUCCESS, 20)
769            callback.delay.assert_called_with([10, 20])
770
771    def test_on_chord_part_return__callback_raises(self):
772        with self.chord_context(1) as (_, request, callback):
773            callback.delay.side_effect = KeyError(10)
774            task = self.app._tasks['add'] = Mock(name='add_task')
775            self.b.on_chord_part_return(request, states.SUCCESS, 10)
776            task.backend.fail_from_current_stack.assert_called_with(
777                callback.id, exc=ANY,
778            )
779
780    def test_on_chord_part_return__callback_raises__unordered(self):
781        self.app.conf.result_backend_transport_options = dict(
782            result_chord_ordered=False,
783        )
784
785        with self.chord_context(1) as (_, request, callback):
786            callback.delay.side_effect = KeyError(10)
787            task = self.app._tasks['add'] = Mock(name='add_task')
788            self.b.on_chord_part_return(request, states.SUCCESS, 10)
789            task.backend.fail_from_current_stack.assert_called_with(
790                callback.id, exc=ANY,
791            )
792
793    def test_on_chord_part_return__callback_raises__ordered(self):
794        self.app.conf.result_backend_transport_options = dict(
795            result_chord_ordered=True,
796        )
797
798        with self.chord_context(1) as (_, request, callback):
799            callback.delay.side_effect = KeyError(10)
800            task = self.app._tasks['add'] = Mock(name='add_task')
801            self.b.on_chord_part_return(request, states.SUCCESS, 10)
802            task.backend.fail_from_current_stack.assert_called_with(
803                callback.id, exc=ANY,
804            )
805
    def test_on_chord_part_return__ChordError(self):
        """A ChordError from a later ``client.pipeline()`` call fails the callback.

        Uses the default transport options and mocks the ``rpush``/``llen``
        pipeline chain (the same chain as the ``__unordered`` variant).
        """
        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            # First client.pipeline() call returns the mock pipeline; every
            # later call raises ChordError (see raise_on_second_call).
            raise_on_second_call(self.b.client.pipeline, ChordError())
            # Canned execute() result for the chained pipeline commands.
            # NOTE(review): values presumably make the size-1 chord look
            # complete; confirm against RedisBackend.on_chord_part_return.
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            # Registered under the callback's task name ('add') so its
            # backend mock captures the failure report.
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )
817
    def test_on_chord_part_return__ChordError__unordered(self):
        """Unordered chords: a later ChordError from the pipeline fails the callback.

        With ``result_chord_ordered=False`` the mocked pipeline chain is the
        ``rpush``/``llen`` one.
        """
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            # First client.pipeline() call returns the mock pipeline; every
            # later call raises ChordError (see raise_on_second_call).
            raise_on_second_call(self.b.client.pipeline, ChordError())
            # Canned execute() result for the chained pipeline commands.
            # NOTE(review): values presumably make the size-1 chord look
            # complete; confirm against RedisBackend.on_chord_part_return.
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            # Registered under the callback's task name ('add') so its
            # backend mock captures the failure report.
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )
833
    def test_on_chord_part_return__ChordError__ordered(self):
        """Ordered chords: a later ChordError from the pipeline fails the callback.

        With ``result_chord_ordered=True`` the mocked pipeline chain uses
        ``zadd``/``zcount`` (Redis sorted-set commands) instead of
        ``rpush``/``llen``.
        """
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            # First client.pipeline() call returns the mock pipeline; every
            # later call raises ChordError (see raise_on_second_call).
            raise_on_second_call(self.b.client.pipeline, ChordError())
            # Canned execute() result for the chained pipeline commands.
            # NOTE(review): values presumably make the size-1 chord look
            # complete; confirm against RedisBackend.on_chord_part_return.
            self.b.client.pipeline.return_value.zadd().zcount().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            # Registered under the callback's task name ('add') so its
            # backend mock captures the failure report.
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )
849
    def test_on_chord_part_return__other_error(self):
        """A non-chord error (RuntimeError) from the pipeline also fails the callback.

        Uses the default transport options and mocks the ``rpush``/``llen``
        pipeline chain (the same chain as the ``__unordered`` variant).
        """
        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            # First client.pipeline() call returns the mock pipeline; every
            # later call raises RuntimeError (see raise_on_second_call).
            raise_on_second_call(self.b.client.pipeline, RuntimeError())
            # Canned execute() result for the chained pipeline commands.
            # NOTE(review): values presumably make the size-1 chord look
            # complete; confirm against RedisBackend.on_chord_part_return.
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            # Registered under the callback's task name ('add') so its
            # backend mock captures the failure report.
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )
861
    def test_on_chord_part_return__other_error__unordered(self):
        """Unordered chords: a later RuntimeError from the pipeline fails the callback.

        With ``result_chord_ordered=False`` the mocked pipeline chain is the
        ``rpush``/``llen`` one.
        """
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            # First client.pipeline() call returns the mock pipeline; every
            # later call raises RuntimeError (see raise_on_second_call).
            raise_on_second_call(self.b.client.pipeline, RuntimeError())
            # Canned execute() result for the chained pipeline commands.
            # NOTE(review): values presumably make the size-1 chord look
            # complete; confirm against RedisBackend.on_chord_part_return.
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            # Registered under the callback's task name ('add') so its
            # backend mock captures the failure report.
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )
877
    def test_on_chord_part_return__other_error__ordered(self):
        """Ordered chords: a later RuntimeError from the pipeline fails the callback.

        With ``result_chord_ordered=True`` the mocked pipeline chain uses
        ``zadd``/``zcount`` (Redis sorted-set commands) instead of
        ``rpush``/``llen``.
        """
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            # First client.pipeline() call returns the mock pipeline; every
            # later call raises RuntimeError (see raise_on_second_call).
            raise_on_second_call(self.b.client.pipeline, RuntimeError())
            # Canned execute() result for the chained pipeline commands.
            # NOTE(review): values presumably make the size-1 chord look
            # complete; confirm against RedisBackend.on_chord_part_return.
            self.b.client.pipeline.return_value.zadd().zcount().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            # Registered under the callback's task name ('add') so its
            # backend mock captures the failure report.
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )
893
894    @contextmanager
895    def chord_context(self, size=1):
896        with patch('celery.backends.redis.maybe_signature') as ms:
897            tasks = [self.create_task(i) for i in range(size)]
898            request = Mock(name='request')
899            request.id = 'id1'
900            request.group = 'gid1'
901            request.group_index = None
902            callback = ms.return_value = Signature('add')
903            callback.id = 'id1'
904            callback['chord_size'] = size
905            callback.delay = Mock(name='callback.delay')
906            yield tasks, request, callback
907
908    def test_process_cleanup(self):
909        self.b.process_cleanup()
910
911    def test_get_set_forget(self):
912        tid = uuid()
913        self.b.store_result(tid, 42, states.SUCCESS)
914        assert self.b.get_state(tid) == states.SUCCESS
915        assert self.b.get_result(tid) == 42
916        self.b.forget(tid)
917        assert self.b.get_state(tid) == states.PENDING
918
919    def test_set_expires(self):
920        self.b = self.Backend(expires=512, app=self.app)
921        tid = uuid()
922        key = self.b.get_key_for_task(tid)
923        self.b.store_result(tid, 42, states.SUCCESS)
924        self.b.client.expire.assert_called_with(
925            key, 512,
926        )
927
928
class test_SentinelBackend:
    """Tests for ``SentinelBackend`` built on the module-level test doubles.

    NOTE(review): ``redis`` / ``sentinel`` are module-level names defined
    outside this class (presumably this file's fakes) -- confirm.
    """

    # Two sentinel nodes, both with password ``test`` and database 1.
    SENTINEL_URL = (
        'sentinel://:test@github.com:123/1;'
        'sentinel://:test@github.com:124/1'
    )

    def get_backend(self):
        """Return a ``SentinelBackend`` subclass wired to ``redis``/``sentinel``."""
        from celery.backends.redis import SentinelBackend

        class _SentinelBackend(SentinelBackend):
            redis = redis
            sentinel = sentinel

        return _SentinelBackend

    def get_E_LOST(self):
        """Return ``E_LOST`` from ``celery.backends.redis``."""
        from celery.backends.redis import E_LOST as lost_message
        return lost_message

    def setup(self):
        backend_cls = self.get_backend()
        self.Backend = backend_cls
        self.E_LOST = self.get_E_LOST()
        self.b = backend_cls(app=self.app)

    @pytest.mark.usefixtures('depends_on_current_app')
    @skip.unless_module('redis')
    def test_reduce(self):
        """A real ``SentinelBackend`` survives a pickle round-trip."""
        from celery.backends.redis import SentinelBackend
        backend = SentinelBackend(app=self.app)
        restored = loads(dumps(backend))
        assert restored

    def test_no_redis(self):
        """Without a redis module the backend refuses to construct."""
        backend_cls = self.Backend
        backend_cls.redis = None
        with pytest.raises(ImproperlyConfigured):
            backend_cls(app=self.app)

    def test_url(self):
        """A ``;``-separated sentinel URL yields one params dict per node.

        NOTE(review): the socket timeouts are configured but not asserted.
        """
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        backend = self.Backend(self.SENTINEL_URL, app=self.app)

        params = backend.connparams
        assert params
        assert "host" not in params
        assert params['db'] == 1
        assert "port" not in params
        assert params['password'] == "test"

        hosts = params['hosts']
        assert len(hosts) == 2
        expectations = (
            ('host', ["github.com", "github.com"]),
            ('port', [123, 124]),
            ('password', ["test", "test"]),
            ('db', [1, 1]),
        )
        for field, expected in expectations:
            assert [node[field] for node in hosts] == expected

    def test_get_sentinel_instance(self):
        """The sentinel instance carries db/password and both nodes."""
        backend = self.Backend(self.SENTINEL_URL, app=self.app)
        instance = backend._get_sentinel_instance(**backend.connparams)
        assert instance.sentinel_kwargs == {}
        conn_kwargs = instance.connection_kwargs
        assert conn_kwargs['db'] == 1
        assert conn_kwargs['password'] == "test"
        assert len(instance.sentinels) == 2

    def test_get_pool(self):
        """``_get_pool`` returns a truthy pool for the sentinel params."""
        backend = self.Backend(self.SENTINEL_URL, app=self.app)
        assert backend._get_pool(**backend.connparams)
1010