"""Unit tests for the Redis and Sentinel result backends.

The module defines small in-memory stand-ins for the ``redis`` client
library (``Redis``, ``Pipeline``, ``PubSub``, ``Sentinel``) and uses them
to exercise ``RedisBackend``, ``SentinelBackend`` and the Redis result
consumer without a running server.
"""
from __future__ import absolute_import, unicode_literals

import json
import random
import ssl
from contextlib import contextmanager
from datetime import timedelta
from pickle import dumps, loads

import pytest
from case import ANY, ContextMock, Mock, call, mock, patch, skip

from celery import signature, states, uuid
from celery.canvas import Signature
from celery.exceptions import (ChordError, CPendingDeprecationWarning,
                               ImproperlyConfigured)
from celery.utils.collections import AttributeDict


def raise_on_second_call(mock, exc, *retval):
    """Configure ``mock`` to succeed once and raise ``exc`` afterwards.

    The first call returns ``mock.return_value`` and swaps the side effect
    to ``exc``, so every later call uses ``exc`` as its side effect.

    :param mock: the mock object to configure.
    :param exc: side effect installed after the first call.
    :param retval: optional single value to use as ``mock.return_value``
        (passing more than one value fails the 1-tuple unpacking).
    """
    def on_first_call(*args, **kwargs):
        mock.side_effect = exc
        return mock.return_value

    mock.side_effect = on_first_call
    if retval:
        mock.return_value, = retval


class ConnectionError(Exception):
    """Stand-in connection error installed on the consumer under test."""


class Connection(object):
    """Minimal connection object that only tracks a ``connected`` flag."""

    connected = True

    def disconnect(self):
        """Mark the connection as disconnected."""
        self.connected = False


class Pipeline(object):
    """Fake pipeline that records client calls and replays them later.

    Any attribute access returns a recorder; calling it stores the
    corresponding client method with its arguments and returns the
    pipeline itself so calls can be chained.
    """

    def __init__(self, client):
        self.client = client
        self.steps = []

    def __getattr__(self, attr):
        def add_step(*args, **kwargs):
            # The client attribute is looked up when the step is recorded,
            # not when the pipeline is executed.
            self.steps.append((getattr(self.client, attr), args, kwargs))
            return self

        return add_step

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        pass

    def execute(self):
        """Run every recorded step in order and return the list of results."""
        return [step(*a, **kw) for step, a, kw in self.steps]


# NOTE: documentation on the ``mock.MockCallbacks`` subclasses below is kept
# in comments rather than docstrings so that their class namespaces stay
# exactly as they were.

# Fake pub/sub object that only tracks the set of subscribed channels.
class PubSub(mock.MockCallbacks):
    def __init__(self, ignore_subscribe_messages=False):
        # ``ignore_subscribe_messages`` is accepted but not used.
        self._subscribed_to = set()

    def close(self):
        self._subscribed_to = set()

    def subscribe(self, *args):
        self._subscribed_to.update(args)

    def unsubscribe(self, *args):
        self._subscribed_to.difference_update(args)

    def get_message(self, timeout=None):
        pass


# In-memory fake of the Redis client: plain keys, lists and sorted sets all
# live in ``self.keyspace``; expiry values are recorded in ``self.expiry``.
class Redis(mock.MockCallbacks):
    Connection = Connection
    Pipeline = Pipeline
    pubsub = PubSub

    def __init__(self, host=None, port=None, db=None, password=None, **kw):
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.keyspace = {}
        self.expiry = {}
        self.connection = self.Connection()

    def get(self, key):
        return self.keyspace.get(key)

    def mget(self, keys):
        return [self.get(key) for key in keys]

    def setex(self, key, expires, value):
        self.set(key, value)
        self.expire(key, expires)

    def set(self, key, value):
        self.keyspace[key] = value

    def expire(self, key, expires):
        self.expiry[key] = expires
        return expires

    def delete(self, key):
        return bool(self.keyspace.pop(key, None))

    def pipeline(self):
        return self.Pipeline(self)

    def _get_unsorted_list(self, key):
        # We simply store the values in append (rpush) order
        return self.keyspace.setdefault(key, list())

    def rpush(self, key, value):
        self._get_unsorted_list(key).append(value)

    def lrange(self, key, start, stop):
        return self._get_unsorted_list(key)[start:stop]

    def llen(self, key):
        return len(self._get_unsorted_list(key))

    def _get_sorted_set(self, key):
        # We store 2-tuples of (score, value) and sort after each append (zadd)
        return self.keyspace.setdefault(key, list())

    def zadd(self, key, mapping):
        # Store elements as 2-tuples with the score first so we can sort it
        # once the new items have been inserted
        fake_sorted_set = self._get_sorted_set(key)
        fake_sorted_set.extend(
            (score, value) for value, score in mapping.items()
        )
        fake_sorted_set.sort()

    def zrange(self, key, start, stop):
        # `stop` is inclusive in Redis so we use `stop + 1` unless that would
        # cause us to move from negative (right-most) indices to positive
        stop = stop + 1 if stop != -1 else None
        return [e[1] for e in self._get_sorted_set(key)[start:stop]]

    def zrangebyscore(self, key, min_, max_):
        # Elements are (score, value) tuples: both bounds must be checked
        # against the score (e[0]); only the value (e[1]) is returned.
        return [
            e[1] for e in self._get_sorted_set(key)
            if (min_ == "-inf" or e[0] >= min_) and
            (max_ == "+inf" or e[0] <= max_)
        ]

    def zcount(self, key, min_, max_):
        return len(self.zrangebyscore(key, min_, max_))


# Fake Sentinel that builds one fake ``Redis`` per (hostname, port) pair.
class Sentinel(mock.MockCallbacks):
    def __init__(self, sentinels, min_other_sentinels=0, sentinel_kwargs=None,
                 **connection_kwargs):
        self.sentinel_kwargs = sentinel_kwargs
        # ``sentinel_kwargs`` defaults to None, which cannot be unpacked
        # with ``**``; fall back to an empty mapping in that case.
        self.sentinels = [
            Redis(hostname, port, **(self.sentinel_kwargs or {}))
            for hostname, port in sentinels
        ]
        self.min_other_sentinels = min_other_sentinels
        self.connection_kwargs = connection_kwargs

    def master_for(self, service_name, redis_class):
        return random.choice(self.sentinels)


class redis(object):
    """Stand-in for the ``redis`` module, assigned to ``Backend.redis``."""

    StrictRedis = Redis

    class ConnectionPool(object):
        def __init__(self, **kwargs):
            pass

    class UnixDomainSocketConnection(object):
        def __init__(self, **kwargs):
            pass


class sentinel(object):
    """Stand-in for the sentinel module, assigned to ``Backend.sentinel``."""

    Sentinel = Sentinel


class test_RedisResultConsumer:
    """Tests for the result consumer of a backend using the fake ``redis``."""

    def get_backend(self):
        """Return a ``RedisBackend`` instance wired to the fake ``redis``."""
        from celery.backends.redis import RedisBackend

        class _RedisBackend(RedisBackend):
            redis = redis

        return _RedisBackend(app=self.app)

    def get_consumer(self):
        """Return the backend's consumer with the fake error type installed."""
        consumer = self.get_backend().result_consumer
        consumer._connection_errors = (ConnectionError,)
        return consumer

    @patch('celery.backends.asynchronous.BaseResultConsumer.on_after_fork')
    def test_on_after_fork(self, parent_method):
        consumer = self.get_consumer()
        consumer.start('none')
        consumer.on_after_fork()
        parent_method.assert_called_once()
        consumer.backend.client.connection_pool.reset.assert_called_once()
        consumer._pubsub.close.assert_called_once()
        # PubSub instance not initialized - exception would be raised
        # when calling .close()
        consumer._pubsub = None
        parent_method.reset_mock()
        consumer.backend.client.connection_pool.reset.reset_mock()
        consumer.on_after_fork()
        parent_method.assert_called_once()
        consumer.backend.client.connection_pool.reset.assert_called_once()

        # Continues on KeyError
        consumer._pubsub = Mock()
        consumer._pubsub.close = Mock(side_effect=KeyError)
        parent_method.reset_mock()
        consumer.backend.client.connection_pool.reset.reset_mock()
        consumer.on_after_fork()
        parent_method.assert_called_once()

    @patch('celery.backends.redis.ResultConsumer.cancel_for')
    @patch('celery.backends.asynchronous.BaseResultConsumer.on_state_change')
    def test_on_state_change(self, parent_method, cancel_for):
        consumer = self.get_consumer()
        meta = {'task_id': 'testing', 'status': states.SUCCESS}
        message = 'hello'
        consumer.on_state_change(meta, message)
        parent_method.assert_called_once_with(meta, message)
        cancel_for.assert_called_once_with(meta['task_id'])

        # Does not call cancel_for for other states
        meta = {'task_id': 'testing2', 'status': states.PENDING}
        parent_method.reset_mock()
        cancel_for.reset_mock()
        consumer.on_state_change(meta, message)
        parent_method.assert_called_once_with(meta, message)
        cancel_for.assert_not_called()

    def test_drain_events_before_start(self):
        consumer = self.get_consumer()
        # drain_events shouldn't crash when called before start
        consumer.drain_events(0.001)

    def test_consume_from_connection_error(self):
        """Both channels end up subscribed after one failed subscribe."""
        consumer = self.get_consumer()
        consumer.start('initial')
        consumer._pubsub.subscribe.side_effect = (ConnectionError(), None)
        consumer.consume_from('some-task')
        assert consumer._pubsub._subscribed_to == {
            b'celery-task-meta-initial', b'celery-task-meta-some-task'}

    def test_cancel_for_connection_error(self):
        """Only the initial channel remains when unsubscribe raises."""
        consumer = self.get_consumer()
        consumer.start('initial')
        consumer._pubsub.unsubscribe.side_effect = ConnectionError()
        consumer.consume_from('some-task')
        consumer.cancel_for('some-task')
        assert consumer._pubsub._subscribed_to == {b'celery-task-meta-initial'}

    @patch('celery.backends.redis.ResultConsumer.cancel_for')
    @patch('celery.backends.asynchronous.BaseResultConsumer.on_state_change')
    def test_drain_events_connection_error(self, parent_on_state_change,
                                           cancel_for):
        """The stored state is reported when ``get_message`` raises."""
        meta = {'task_id': 'initial', 'status': states.SUCCESS}
        consumer = self.get_consumer()
        consumer.start('initial')
        consumer.backend._set_with_state(
            b'celery-task-meta-initial', json.dumps(meta), states.SUCCESS)
        consumer._pubsub.get_message.side_effect = ConnectionError()
        consumer.drain_events()
        parent_on_state_change.assert_called_with(meta, None)
        assert consumer._pubsub._subscribed_to == {b'celery-task-meta-initial'}


class test_RedisBackend:
    """Tests for ``RedisBackend`` running against the fake ``redis``."""

    def get_backend(self):
        """Return a ``RedisBackend`` subclass wired to the fake ``redis``."""
        from celery.backends.redis import RedisBackend

        class _RedisBackend(RedisBackend):
            redis = redis

        return _RedisBackend

    def get_E_LOST(self):
        from celery.backends.redis import E_LOST
        return E_LOST

    def setup(self):
        self.Backend = self.get_backend()
        self.E_LOST = self.get_E_LOST()
        self.b = self.Backend(app=self.app)

    @pytest.mark.usefixtures('depends_on_current_app')
    @skip.unless_module('redis')
    def test_reduce(self):
        from celery.backends.redis import RedisBackend
        x = RedisBackend(app=self.app)
        assert loads(dumps(x))

    def test_no_redis(self):
        self.Backend.redis = None
        with pytest.raises(ImproperlyConfigured):
            self.Backend(app=self.app)

    def test_url(self):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'redis://:bosco@vandelay.com:123//1', app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30.0
        assert x.connparams['socket_connect_timeout'] == 100.0

    @skip.unless_module('redis')
    def test_timeouts_in_url_coerced(self):
        x = self.Backend(
            ('redis://:bosco@vandelay.com:123//1?'
             'socket_timeout=30&socket_connect_timeout=100'),
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30
        assert x.connparams['socket_connect_timeout'] == 100

    @skip.unless_module('redis')
    def test_socket_url(self):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
        )
        assert x.connparams
        assert x.connparams['path'] == '/tmp/redis.sock'
        assert (x.connparams['connection_class'] is
                redis.UnixDomainSocketConnection)
        assert 'host' not in x.connparams
        assert 'port' not in x.connparams
        assert x.connparams['socket_timeout'] == 30.0
        assert 'socket_connect_timeout' not in x.connparams
        assert 'socket_keepalive' not in x.connparams
        assert x.connparams['db'] == 3

    @skip.unless_module('redis')
    def test_backend_ssl(self):
        self.app.conf.redis_backend_use_ssl = {
            'ssl_cert_reqs': ssl.CERT_REQUIRED,
            'ssl_ca_certs': '/path/to/ca.crt',
            'ssl_certfile': '/path/to/client.crt',
            'ssl_keyfile': '/path/to/client.key',
        }
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'rediss://:bosco@vandelay.com:123//1', app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30.0
        assert x.connparams['socket_connect_timeout'] == 100.0
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
        assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'
        assert x.connparams['ssl_certfile'] == '/path/to/client.crt'
        assert x.connparams['ssl_keyfile'] == '/path/to/client.key'

        from redis.connection import SSLConnection
        assert x.connparams['connection_class'] is SSLConnection

    @skip.unless_module('redis')
    @pytest.mark.parametrize('cert_str', [
        "required",
        "CERT_REQUIRED",
    ])
    def test_backend_ssl_certreq_str(self, cert_str):
        self.app.conf.redis_backend_use_ssl = {
            'ssl_cert_reqs': cert_str,
            'ssl_ca_certs': '/path/to/ca.crt',
            'ssl_certfile': '/path/to/client.crt',
            'ssl_keyfile': '/path/to/client.key',
        }
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'rediss://:bosco@vandelay.com:123//1', app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30.0
        assert x.connparams['socket_connect_timeout'] == 100.0
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
        assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'
        assert x.connparams['ssl_certfile'] == '/path/to/client.crt'
        assert x.connparams['ssl_keyfile'] == '/path/to/client.key'

        from redis.connection import SSLConnection
        assert x.connparams['connection_class'] is SSLConnection

    @skip.unless_module('redis')
    @pytest.mark.parametrize('cert_str', [
        "required",
        "CERT_REQUIRED",
    ])
    def test_backend_ssl_url(self, cert_str):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=%s' % cert_str,
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30.0
        assert x.connparams['socket_connect_timeout'] == 100.0
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED

        from redis.connection import SSLConnection
        assert x.connparams['connection_class'] is SSLConnection

    @skip.unless_module('redis')
    @pytest.mark.parametrize('cert_str', [
        "none",
        "CERT_NONE",
    ])
    def test_backend_ssl_url_options(self, cert_str):
        x = self.Backend(
            (
                'rediss://:bosco@vandelay.com:123//1'
                '?ssl_cert_reqs={cert_str}'
                '&ssl_ca_certs=%2Fvar%2Fssl%2Fmyca.pem'
                '&ssl_certfile=%2Fvar%2Fssl%2Fredis-server-cert.pem'
                '&ssl_keyfile=%2Fvar%2Fssl%2Fprivate%2Fworker-key.pem'
            ).format(cert_str=cert_str),
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_NONE
        assert x.connparams['ssl_ca_certs'] == '/var/ssl/myca.pem'
        assert x.connparams['ssl_certfile'] == '/var/ssl/redis-server-cert.pem'
        assert x.connparams['ssl_keyfile'] == '/var/ssl/private/worker-key.pem'

    @skip.unless_module('redis')
    @pytest.mark.parametrize('cert_str', [
        "optional",
        "CERT_OPTIONAL",
    ])
    def test_backend_ssl_url_cert_none(self, cert_str):
        # NOTE: despite the name, the parameters and the assertion below
        # cover the "optional" certificate requirement.
        x = self.Backend(
            'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=%s' % cert_str,
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_OPTIONAL

        from redis.connection import SSLConnection
        assert x.connparams['connection_class'] is SSLConnection

    @skip.unless_module('redis')
    @pytest.mark.parametrize("uri", [
        'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=CERT_KITTY_CATS',
        'rediss://:bosco@vandelay.com:123//1'
    ])
    def test_backend_ssl_url_invalid(self, uri):
        with pytest.raises(ValueError):
            self.Backend(
                uri,
                app=self.app,
            )

    def test_compat_propertie(self):
        x = self.Backend(
            'redis://:bosco@vandelay.com:123//1', app=self.app,
        )
        with pytest.warns(CPendingDeprecationWarning):
            assert x.host == 'vandelay.com'
        with pytest.warns(CPendingDeprecationWarning):
            assert x.db == 1
        with pytest.warns(CPendingDeprecationWarning):
            assert x.port == 123
        with pytest.warns(CPendingDeprecationWarning):
            assert x.password == 'bosco'

    def test_conf_raises_KeyError(self):
        self.app.conf = AttributeDict({
            'result_serializer': 'json',
            'result_cache_max': 1,
            'result_expires': None,
            'accept_content': ['json'],
            'result_accept_content': ['json'],
        })
        self.Backend(app=self.app)

    @patch('celery.backends.redis.logger')
    def test_on_connection_error(self, logger):
        intervals = iter([10, 20, 30])
        exc = KeyError()
        assert self.b.on_connection_error(None, exc, intervals, 1) == 10
        logger.error.assert_called_with(
            self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
        assert self.b.on_connection_error(10, exc, intervals, 2) == 20
        logger.error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
        assert self.b.on_connection_error(10, exc, intervals, 3) == 30
        logger.error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')

    def test_incr(self):
        self.b.client = Mock(name='client')
        self.b.incr('foo')
        self.b.client.incr.assert_called_with('foo')

    def test_expire(self):
        self.b.client = Mock(name='client')
        self.b.expire('foo', 300)
        self.b.client.expire.assert_called_with('foo', 300)

    def test_apply_chord(self, unlock='celery.chord_unlock'):
        self.app.tasks[unlock] = Mock()
        header_result = self.app.GroupResult(
            uuid(),
            [self.app.AsyncResult(x) for x in range(3)],
        )
        self.b.apply_chord(header_result, None)
        assert self.app.tasks[unlock].apply_async.call_count == 0

    def test_unpack_chord_result(self):
        self.b.exception_to_python = Mock(name='etp')
        decode = Mock(name='decode')
        exc = KeyError()
        tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
        with pytest.raises(ChordError):
            self.b._unpack_chord_result(tup, decode)
        decode.assert_called_with(tup)
        self.b.exception_to_python.assert_called_with(exc)

        exc = ValueError()
        tup = decode.return_value = (2, 'id2', states.RETRY, exc)
        ret = self.b._unpack_chord_result(tup, decode)
        self.b.exception_to_python.assert_called_with(exc)
        assert ret is self.b.exception_to_python()

    def test_on_chord_part_return_no_gid_or_tid(self):
        request = Mock(name='request')
        request.id = request.group = request.group_index = None
        assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None

    def test_ConnectionPool(self):
        self.b.redis = Mock(name='redis')
        assert self.b._ConnectionPool is None
        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
        assert self.b.ConnectionPool is self.b.redis.ConnectionPool

    def test_expires_defaults_to_config(self):
        self.app.conf.result_expires = 10
        b = self.Backend(expires=None, app=self.app)
        assert b.expires == 10

    def test_expires_is_int(self):
        b = self.Backend(expires=48, app=self.app)
        assert b.expires == 48

    def test_add_to_chord(self):
        b = self.Backend('redis://', app=self.app)
        gid = uuid()
        b.add_to_chord(gid, 'sig')
        b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)

    def test_expires_is_None(self):
        b = self.Backend(expires=None, app=self.app)
        assert b.expires == self.app.conf.result_expires.total_seconds()

    def test_expires_is_timedelta(self):
        b = self.Backend(expires=timedelta(minutes=1), app=self.app)
        assert b.expires == 60

    def test_mget(self):
        assert self.b.mget(['a', 'b', 'c'])
        self.b.client.mget.assert_called_with(['a', 'b', 'c'])

    def test_set_no_expire(self):
        self.b.expires = None
        self.b._set_with_state('foo', 'bar', states.SUCCESS)

    def create_task(self, i):
        """Register and return a mock task that is part ``i`` of a chord.

        The task's request carries a fresh id, the group ``'group_id'``,
        ``group_index`` set to ``i`` and a chord signature whose
        ``chord_size`` is 10.
        """
        tid = uuid()
        task = Mock(name='task-{0}'.format(tid))
        task.name = 'foobarbaz'
        self.app.tasks['foobarbaz'] = task
        task.request.chord = signature(task)
        task.request.id = tid
        task.request.chord['chord_size'] = 10
        task.request.group = 'group_id'
        task.request.group_index = i
        return task

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return(self, restore):
        tasks = [self.create_task(i) for i in range(10)]
        random.shuffle(tasks)

        for i in range(10):
            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
            assert self.b.client.rpush.call_count
            self.b.client.rpush.reset_mock()
        assert self.b.client.lrange.call_count
        jkey = self.b.get_key_for_group('group_id', '.j')
        tkey = self.b.get_key_for_group('group_id', '.t')
        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
        self.b.client.expire.assert_has_calls([
            call(jkey, 86400), call(tkey, 86400),
        ])

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return__unordered(self, restore):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        tasks = [self.create_task(i) for i in range(10)]
        random.shuffle(tasks)

        for i in range(10):
            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
            assert self.b.client.rpush.call_count
            self.b.client.rpush.reset_mock()
        assert self.b.client.lrange.call_count
        jkey = self.b.get_key_for_group('group_id', '.j')
        tkey = self.b.get_key_for_group('group_id', '.t')
        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
        self.b.client.expire.assert_has_calls([
            call(jkey, 86400), call(tkey, 86400),
        ])

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return__ordered(self, restore):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        tasks = [self.create_task(i) for i in range(10)]
        random.shuffle(tasks)

        for i in range(10):
            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
            assert self.b.client.zadd.call_count
            self.b.client.zadd.reset_mock()
        assert self.b.client.zrangebyscore.call_count
        jkey = self.b.get_key_for_group('group_id', '.j')
        tkey = self.b.get_key_for_group('group_id', '.t')
        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
        self.b.client.expire.assert_has_calls([
            call(jkey, 86400), call(tkey, 86400),
        ])

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return_no_expiry(self, restore):
        old_expires = self.b.expires
        self.b.expires = None
        tasks = [self.create_task(i) for i in range(10)]

        for i in range(10):
            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
            assert self.b.client.rpush.call_count
            self.b.client.rpush.reset_mock()
        assert self.b.client.lrange.call_count
        jkey = self.b.get_key_for_group('group_id', '.j')
        tkey = self.b.get_key_for_group('group_id', '.t')
        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
        self.b.client.expire.assert_not_called()

        self.b.expires = old_expires

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return_no_expiry__unordered(self, restore):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        old_expires = self.b.expires
        self.b.expires = None
        tasks = [self.create_task(i) for i in range(10)]

        for i in range(10):
            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
            assert self.b.client.rpush.call_count
            self.b.client.rpush.reset_mock()
        assert self.b.client.lrange.call_count
        jkey = self.b.get_key_for_group('group_id', '.j')
        tkey = self.b.get_key_for_group('group_id', '.t')
        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
        self.b.client.expire.assert_not_called()

        self.b.expires = old_expires

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return_no_expiry__ordered(self, restore):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        old_expires = self.b.expires
        self.b.expires = None
        tasks = [self.create_task(i) for i in range(10)]

        for i in range(10):
            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
            assert self.b.client.zadd.call_count
            self.b.client.zadd.reset_mock()
        assert self.b.client.zrangebyscore.call_count
        jkey = self.b.get_key_for_group('group_id', '.j')
        tkey = self.b.get_key_for_group('group_id', '.t')
        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
        self.b.client.expire.assert_not_called()

        self.b.expires = old_expires

    def test_on_chord_part_return__success(self):
        with self.chord_context(2) as (_, request, callback):
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            callback.delay.assert_not_called()
            self.b.on_chord_part_return(request, states.SUCCESS, 20)
            callback.delay.assert_called_with([10, 20])

    def test_on_chord_part_return__success__unordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        with self.chord_context(2) as (_, request, callback):
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            callback.delay.assert_not_called()
            self.b.on_chord_part_return(request, states.SUCCESS, 20)
            callback.delay.assert_called_with([10, 20])

    def test_on_chord_part_return__success__ordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        with self.chord_context(2) as (_, request, callback):
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            callback.delay.assert_not_called()
            self.b.on_chord_part_return(request, states.SUCCESS, 20)
            callback.delay.assert_called_with([10, 20])

    def test_on_chord_part_return__callback_raises(self):
        with self.chord_context(1) as (_, request, callback):
            callback.delay.side_effect = KeyError(10)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__callback_raises__unordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        with self.chord_context(1) as (_, request, callback):
            callback.delay.side_effect = KeyError(10)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__callback_raises__ordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        with self.chord_context(1) as (_, request, callback):
            callback.delay.side_effect = KeyError(10)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__ChordError(self):
        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, ChordError())
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__ChordError__unordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, ChordError())
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__ChordError__ordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, ChordError())
            self.b.client.pipeline.return_value.zadd().zcount().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__other_error(self):
        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, RuntimeError())
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__other_error__unordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=False,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, RuntimeError())
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__other_error__ordered(self):
        self.app.conf.result_backend_transport_options = dict(
            result_chord_ordered=True,
        )

        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, RuntimeError())
            self.b.client.pipeline.return_value.zadd().zcount().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    @contextmanager
    def chord_context(self, size=1):
        """Yield ``(tasks, request, callback)`` for a chord of ``size`` parts.

        ``celery.backends.redis.maybe_signature`` is patched for the
        duration of the context so that it returns ``callback``, an ``add``
        signature whose ``delay`` is a mock and whose ``chord_size`` is
        ``size``.
        """
        with patch('celery.backends.redis.maybe_signature') as ms:
            tasks = [self.create_task(i) for i in range(size)]
            request = Mock(name='request')
            request.id = 'id1'
            request.group = 'gid1'
            request.group_index = None
            callback = ms.return_value = Signature('add')
            callback.id = 'id1'
            callback['chord_size'] = size
            callback.delay = Mock(name='callback.delay')
            yield tasks, request, callback

    def test_process_cleanup(self):
        self.b.process_cleanup()

    def test_get_set_forget(self):
        tid = uuid()
        self.b.store_result(tid, 42, states.SUCCESS)
        assert self.b.get_state(tid) == states.SUCCESS
        assert self.b.get_result(tid) == 42
        self.b.forget(tid)
        assert self.b.get_state(tid) == states.PENDING

    def test_set_expires(self):
        self.b = self.Backend(expires=512, app=self.app)
        tid = uuid()
        key = self.b.get_key_for_task(tid)
        self.b.store_result(tid, 42, states.SUCCESS)
        self.b.client.expire.assert_called_with(
            key, 512,
        )


class test_SentinelBackend:
    """Tests for ``SentinelBackend`` using the fake ``redis``/``sentinel``."""

    def get_backend(self):
        """Return a ``SentinelBackend`` subclass wired to the fakes."""
        from celery.backends.redis import SentinelBackend

        class _SentinelBackend(SentinelBackend):
            redis = redis
            sentinel = sentinel

        return _SentinelBackend

    def get_E_LOST(self):
        from celery.backends.redis import E_LOST
        return E_LOST

    def setup(self):
        self.Backend = self.get_backend()
        self.E_LOST = self.get_E_LOST()
        self.b = self.Backend(app=self.app)

    @pytest.mark.usefixtures('depends_on_current_app')
    @skip.unless_module('redis')
    def test_reduce(self):
        from celery.backends.redis import SentinelBackend
        x = SentinelBackend(app=self.app)
        assert loads(dumps(x))

    def test_no_redis(self):
        self.Backend.redis = None
        with pytest.raises(ImproperlyConfigured):
            self.Backend(app=self.app)

    def test_url(self):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'sentinel://:test@github.com:123/1;'
            'sentinel://:test@github.com:124/1',
            app=self.app,
        )
        assert x.connparams
        assert "host" not in x.connparams
        assert x.connparams['db'] == 1
        assert "port" not in x.connparams
        assert x.connparams['password'] == "test"
        assert len(x.connparams['hosts']) == 2
        expected_hosts = ["github.com", "github.com"]
        found_hosts = [cp['host'] for cp in x.connparams['hosts']]
        assert found_hosts == expected_hosts

        expected_ports = [123, 124]
        found_ports = [cp['port'] for cp in x.connparams['hosts']]
        assert found_ports == expected_ports

        expected_passwords = ["test", "test"]
        found_passwords = [cp['password'] for cp in x.connparams['hosts']]
        assert found_passwords == expected_passwords

        expected_dbs = [1, 1]
        found_dbs = [cp['db'] for cp in x.connparams['hosts']]
        assert found_dbs == expected_dbs

    def test_get_sentinel_instance(self):
        x = self.Backend(
            'sentinel://:test@github.com:123/1;'
            'sentinel://:test@github.com:124/1',
            app=self.app,
        )
        sentinel_instance = x._get_sentinel_instance(**x.connparams)
        assert sentinel_instance.sentinel_kwargs == {}
        assert sentinel_instance.connection_kwargs['db'] == 1
        assert sentinel_instance.connection_kwargs['password'] == "test"
        assert len(sentinel_instance.sentinels) == 2

    def test_get_pool(self):
        x = self.Backend(
            'sentinel://:test@github.com:123/1;'
            'sentinel://:test@github.com:124/1',
            app=self.app,
        )
        pool = x._get_pool(**x.connparams)
        assert pool