from __future__ import absolute_import, unicode_literals

import pickle
from contextlib import contextmanager
from datetime import timedelta
from pickle import dumps, loads

import pytest
from billiard.einfo import ExceptionInfo
from case import Mock, mock

from celery import states, uuid
from celery.app.task import Context
from celery.backends.amqp import AMQPBackend
from celery.five import Empty, Queue, range
from celery.result import AsyncResult


class SomeClass(object):
    """Plain user-defined object used to check that results round-trip
    through the pickle serializer."""

    def __init__(self, data):
        self.data = data


class test_AMQPBackend:
    """Tests for :class:`celery.backends.amqp.AMQPBackend`."""

    def setup(self):
        self.app.conf.result_cache_max = 100

    def create_backend(self, **opts):
        """Return an ``AMQPBackend`` bound to ``self.app``.

        Defaults to ``serializer='pickle'`` and ``persistent=True``; any
        keyword arguments override those defaults and are forwarded to the
        ``AMQPBackend`` constructor.
        """
        opts = dict({'serializer': 'pickle', 'persistent': True}, **opts)
        return AMQPBackend(self.app, **opts)

    def test_destination_for(self):
        b = self.create_backend()
        request = Mock()
        assert b.destination_for('id', request) == (
            b.rkey('id'), request.correlation_id,
        )

    def test_store_result__no_routing_key(self):
        # Only verifies that store_result does not raise when
        # destination_for yields neither a routing key nor a correlation id.
        b = self.create_backend()
        b.destination_for = Mock()
        b.destination_for.return_value = None, None
        b.store_result('id', None, states.SUCCESS)

    def test_mark_as_done(self):
        tb1 = self.create_backend(max_cached_results=1)
        tb2 = self.create_backend(max_cached_results=1)

        tid = uuid()

        tb1.mark_as_done(tid, 42)
        assert tb2.get_state(tid) == states.SUCCESS
        assert tb2.get_result(tid) == 42
        assert tb2._cache.get(tid)
        # A repeat lookup after the entry has been cached must still
        # return the stored value (not merely something truthy).
        assert tb2.get_result(tid) == 42

    @pytest.mark.usefixtures('depends_on_current_app')
    def test_pickleable(self):
        assert loads(dumps(self.create_backend()))

    def test_revive(self):
        tb = self.create_backend()
        tb.revive(None)

    def test_is_pickled(self):
        tb1 = self.create_backend()
        tb2 = self.create_backend()

        tid2 = uuid()
        result = {'foo': 'baz', 'bar': SomeClass(12345)}
        tb1.mark_as_done(tid2, result)
        # A result containing a custom class instance is serialized properly.
        rindb = tb2.get_result(tid2)
        assert rindb.get('foo') == 'baz'
        assert rindb.get('bar').data == 12345

    def test_mark_as_failure(self):
        tb1 = self.create_backend()
        tb2 = self.create_backend()

        tid3 = uuid()
        try:
            raise KeyError('foo')
        except KeyError as exception:
            einfo = ExceptionInfo()
            tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
            assert tb2.get_state(tid3) == states.FAILURE
            assert isinstance(tb2.get_result(tid3), KeyError)
            assert tb2.get_traceback(tid3) == einfo.traceback

    def test_repair_uuid(self):
        from celery.backends.amqp import repair_uuid
        for _ in range(10):
            tid = uuid()
            assert repair_uuid(tid.replace('-', '')) == tid

    def test_expires_is_int(self):
        b = self.create_backend(expires=48)
        q = b._create_binding('x1y2z3')
        assert q.expires == 48

    def test_expires_is_float(self):
        b = self.create_backend(expires=48.3)
        q = b._create_binding('x1y2z3')
        assert q.expires == 48.3

    def test_expires_is_timedelta(self):
        b = self.create_backend(expires=timedelta(minutes=1))
        q = b._create_binding('x1y2z3')
        assert q.expires == 60

    @mock.sleepdeprived()
    def test_store_result_retries(self):
        # Lists so the nested publish() can mutate the counters
        # (no ``nonlocal`` on Python 2).
        iterations = [0]
        stop_raising_at = [5]

        def publish(*args, **kwargs):
            if iterations[0] > stop_raising_at[0]:
                return
            iterations[0] += 1
            raise KeyError('foo')

        backend = AMQPBackend(self.app)
        from celery.app.amqp import Producer
        prod, Producer.publish = Producer.publish, publish
        try:
            # The injected KeyError must propagate both with unlimited
            # (None) and with bounded max_retries.
            with pytest.raises(KeyError):
                backend.retry_policy['max_retries'] = None
                backend.store_result('foo', 'bar', 'STARTED')

            with pytest.raises(KeyError):
                backend.retry_policy['max_retries'] = 10
                backend.store_result('foo', 'bar', 'STARTED')
        finally:
            # Always restore the class attribute patched above.
            Producer.publish = prod

    def test_poll_no_messages(self):
        b = self.create_backend()
        assert b.get_task_meta(uuid())['status'] == states.PENDING

    @contextmanager
    def _result_context(self):
        """Yield ``(results, backend, Message)`` for polling tests.

        ``results`` is an in-memory :class:`Queue` that the fake binding's
        ``get()`` pops from; ``backend`` is an ``AMQPBackend`` subclass whose
        ``Queue`` attribute is replaced by that fake binding (and whose
        ``_republish`` is a ``Mock``); ``Message`` builds fake messages with
        a pickled payload and ack/requeue counters.
        """
        results = Queue()

        class Message(object):
            acked = 0
            requeued = 0

            def __init__(self, **merge):
                self.payload = dict({'status': states.STARTED,
                                     'result': None}, **merge)
                self.properties = {'correlation_id': merge.get('task_id')}
                self.body = pickle.dumps(self.payload)
                self.content_type = 'application/x-python-serialize'
                self.content_encoding = 'binary'

            def ack(self, *args, **kwargs):
                self.acked += 1

            def requeue(self, *args, **kwargs):
                self.requeued += 1

        class MockBinding(object):

            def __init__(self, *args, **kwargs):
                self.channel = Mock()

            def __call__(self, *args, **kwargs):
                return self

            def declare(self):
                pass

            def get(self, no_ack=False, accept=None):
                # Non-blocking pop; returns None when the queue is empty.
                try:
                    m = results.get(block=False)
                    if m:
                        m.accept = accept
                    return m
                except Empty:
                    pass

            def is_bound(self):
                return True

        class MockBackend(AMQPBackend):
            Queue = MockBinding

        backend = MockBackend(self.app, max_cached_results=100)
        backend._republish = Mock()

        yield results, backend, Message

    def test_backlog_limit_exceeded(self):
        with self._result_context() as (results, backend, Message):
            for _ in range(1001):
                results.put(Message(task_id='id', status=states.RECEIVED))
            with pytest.raises(backend.BacklogLimitExceeded):
                backend.get_task_meta('id')

    def test_poll_result(self):
        with self._result_context() as (results, backend, Message):
            tid = uuid()
            # FFWD's to the latest state.
            state_messages = [
                Message(task_id=tid, status=states.RECEIVED, seq=1),
                Message(task_id=tid, status=states.STARTED, seq=2),
                Message(task_id=tid, status=states.FAILURE, seq=3),
            ]
            for state_message in state_messages:
                results.put(state_message)
            r1 = backend.get_task_meta(tid)
            # FFWDs to the last state.
            assert r1['status'] == states.FAILURE
            assert r1['seq'] == 3

            # Caches last known state.
            tid = uuid()
            results.put(Message(task_id=tid))
            backend.get_task_meta(tid)
            assert tid in backend._cache, 'Caches last known state'

            assert state_messages[-1].requeued

            # Returns cache if no new states.
            results.queue.clear()
            assert not results.qsize()
            backend._cache[tid] = 'hello'
            # returns cache if no new states.
            assert backend.get_task_meta(tid) == 'hello'

    def test_drain_events_decodes_exceptions_in_meta(self):
        # With the json serializer, a stored exception must come back from
        # AsyncResult.get() as an exception with the same class name and
        # message.
        tid = uuid()
        b = self.create_backend(serializer='json')
        b.store_result(tid, RuntimeError('aap'), states.FAILURE)
        result = AsyncResult(tid, backend=b)

        with pytest.raises(Exception) as excinfo:
            result.get()

        assert excinfo.value.__class__.__name__ == 'RuntimeError'
        assert str(excinfo.value) == 'aap'

    def test_no_expires(self):
        self.app.conf.result_expires = None
        b = self.create_backend(expires=None)
        q = b._create_binding('foo')
        assert q.expires is None

    def test_process_cleanup(self):
        self.create_backend().process_cleanup()

    def test_reload_task_result(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().reload_task_result('x')

    def test_reload_group_result(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().reload_group_result('x')

    def test_save_group(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().save_group('x', 'x')

    def test_restore_group(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().restore_group('x')

    def test_delete_group(self):
        with pytest.raises(NotImplementedError):
            self.create_backend().delete_group('x')


class test_AMQPBackend_result_extended:
    """AMQPBackend with ``result_extended`` enabled: the stored meta also
    carries request details (args, kwargs, name, queue, retries, worker)."""

    def setup(self):
        self.app.conf.result_extended = True

    def test_store_result(self):
        b = AMQPBackend(self.app)
        tid = uuid()

        request = Context(args=(1, 2, 3), kwargs={'foo': 'bar'},
                          task_name='mytask', retries=2,
                          hostname='celery@worker_1',
                          delivery_info={'routing_key': 'celery'})

        b.store_result(tid, {'fizz': 'buzz'}, states.SUCCESS, request=request)

        meta = b.get_task_meta(tid)
        assert meta == {
            'args': [1, 2, 3],
            'children': [],
            'kwargs': {'foo': 'bar'},
            'name': 'mytask',
            'queue': 'celery',
            'result': {'fizz': 'buzz'},
            'retries': 2,
            'status': 'SUCCESS',
            'task_id': tid,
            'traceback': None,
            'worker': 'celery@worker_1',
        }