1"""Async I/O backend support utilities.""" 2from __future__ import absolute_import, unicode_literals 3 4import socket 5import threading 6from collections import deque 7from time import sleep 8from weakref import WeakKeyDictionary 9 10from kombu.utils.compat import detect_environment 11 12from celery import states 13from celery.exceptions import TimeoutError 14from celery.five import Empty, monotonic 15from celery.utils.threads import THREAD_TIMEOUT_MAX 16 17__all__ = ( 18 'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer', 19 'register_drainer', 20) 21 22drainers = {} 23 24 25def register_drainer(name): 26 """Decorator used to register a new result drainer type.""" 27 def _inner(cls): 28 drainers[name] = cls 29 return cls 30 return _inner 31 32 33@register_drainer('default') 34class Drainer(object): 35 """Result draining service.""" 36 37 def __init__(self, result_consumer): 38 self.result_consumer = result_consumer 39 40 def start(self): 41 pass 42 43 def stop(self): 44 pass 45 46 def drain_events_until(self, p, timeout=None, interval=1, on_interval=None, wait=None): 47 wait = wait or self.result_consumer.drain_events 48 time_start = monotonic() 49 50 while 1: 51 # Total time spent may exceed a single call to wait() 52 if timeout and monotonic() - time_start >= timeout: 53 raise socket.timeout() 54 try: 55 yield self.wait_for(p, wait, timeout=interval) 56 except socket.timeout: 57 pass 58 if on_interval: 59 on_interval() 60 if p.ready: # got event on the wanted channel. 61 break 62 63 def wait_for(self, p, wait, timeout=None): 64 wait(timeout=timeout) 65 66 67class greenletDrainer(Drainer): 68 spawn = None 69 _g = None 70 71 def __init__(self, *args, **kwargs): 72 super(greenletDrainer, self).__init__(*args, **kwargs) 73 self._started = threading.Event() 74 self._stopped = threading.Event() 75 self._shutdown = threading.Event() 76 77 def run(self): 78 self._started.set() 79 while not self._stopped.is_set(): 80 try: 81 self.result_consumer.drain_events(timeout=1) 82 except socket.timeout: 83 pass 84 self._shutdown.set() 85 86 def start(self): 87 if not self._started.is_set(): 88 self._g = self.spawn(self.run) 89 self._started.wait() 90 91 def stop(self): 92 self._stopped.set() 93 self._shutdown.wait(THREAD_TIMEOUT_MAX) 94 95 96@register_drainer('eventlet') 97class eventletDrainer(greenletDrainer): 98 99 def spawn(self, func): 100 from eventlet import spawn, sleep 101 g = spawn(func) 102 sleep(0) 103 return g 104 105 def wait_for(self, p, wait, timeout=None): 106 self.start() 107 if not p.ready: 108 self._g._exit_event.wait(timeout=timeout) 109 110 111@register_drainer('gevent') 112class geventDrainer(greenletDrainer): 113 114 def spawn(self, func): 115 import gevent 116 g = gevent.spawn(func) 117 gevent.sleep(0) 118 return g 119 120 def wait_for(self, p, wait, timeout=None): 121 import gevent 122 self.start() 123 if not p.ready: 124 gevent.wait([self._g], timeout=timeout) 125 126 127class AsyncBackendMixin(object): 128 """Mixin for backends that enables the async API.""" 129 130 def _collect_into(self, result, bucket): 131 self.result_consumer.buckets[result] = bucket 132 133 def iter_native(self, result, no_ack=True, **kwargs): 134 self._ensure_not_eager() 135 136 results = result.results 137 if not results: 138 raise StopIteration() 139 140 # we tell the result consumer to put consumed results 141 # into these buckets. 142 bucket = deque() 143 for node in results: 144 if not hasattr(node, '_cache'): 145 bucket.append(node) 146 elif node._cache: 147 bucket.append(node) 148 else: 149 self._collect_into(node, bucket) 150 151 for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs): 152 while bucket: 153 node = bucket.popleft() 154 if not hasattr(node, '_cache'): 155 yield node.id, node.children 156 else: 157 yield node.id, node._cache 158 while bucket: 159 node = bucket.popleft() 160 yield node.id, node._cache 161 162 def add_pending_result(self, result, weak=False, start_drainer=True): 163 if start_drainer: 164 self.result_consumer.drainer.start() 165 try: 166 self._maybe_resolve_from_buffer(result) 167 except Empty: 168 self._add_pending_result(result.id, result, weak=weak) 169 return result 170 171 def _maybe_resolve_from_buffer(self, result): 172 result._maybe_set_cache(self._pending_messages.take(result.id)) 173 174 def _add_pending_result(self, task_id, result, weak=False): 175 concrete, weak_ = self._pending_results 176 if task_id not in weak_ and result.id not in concrete: 177 (weak_ if weak else concrete)[task_id] = result 178 self.result_consumer.consume_from(task_id) 179 180 def add_pending_results(self, results, weak=False): 181 self.result_consumer.drainer.start() 182 return [self.add_pending_result(result, weak=weak, start_drainer=False) 183 for result in results] 184 185 def remove_pending_result(self, result): 186 self._remove_pending_result(result.id) 187 self.on_result_fulfilled(result) 188 return result 189 190 def _remove_pending_result(self, task_id): 191 for mapping in self._pending_results: 192 mapping.pop(task_id, None) 193 194 def on_result_fulfilled(self, result): 195 self.result_consumer.cancel_for(result.id) 196 197 def wait_for_pending(self, result, 198 callback=None, propagate=True, **kwargs): 199 self._ensure_not_eager() 200 for _ in self._wait_for_pending(result, **kwargs): 201 pass 202 return result.maybe_throw(callback=callback, propagate=propagate) 203 204 def _wait_for_pending(self, result, 205 timeout=None, on_interval=None, on_message=None, 206 **kwargs): 207 return self.result_consumer._wait_for_pending( 208 result, timeout=timeout, 209 on_interval=on_interval, on_message=on_message, 210 **kwargs 211 ) 212 213 @property 214 def is_async(self): 215 return True 216 217 218class BaseResultConsumer(object): 219 """Manager responsible for consuming result messages.""" 220 221 def __init__(self, backend, app, accept, 222 pending_results, pending_messages): 223 self.backend = backend 224 self.app = app 225 self.accept = accept 226 self._pending_results = pending_results 227 self._pending_messages = pending_messages 228 self.on_message = None 229 self.buckets = WeakKeyDictionary() 230 self.drainer = drainers[detect_environment()](self) 231 232 def start(self, initial_task_id, **kwargs): 233 raise NotImplementedError() 234 235 def stop(self): 236 pass 237 238 def drain_events(self, timeout=None): 239 raise NotImplementedError() 240 241 def consume_from(self, task_id): 242 raise NotImplementedError() 243 244 def cancel_for(self, task_id): 245 raise NotImplementedError() 246 247 def _after_fork(self): 248 self.buckets.clear() 249 self.buckets = WeakKeyDictionary() 250 self.on_message = None 251 self.on_after_fork() 252 253 def on_after_fork(self): 254 pass 255 256 def drain_events_until(self, p, timeout=None, on_interval=None): 257 return self.drainer.drain_events_until( 258 p, timeout=timeout, on_interval=on_interval) 259 260 def _wait_for_pending(self, result, 261 timeout=None, on_interval=None, on_message=None, 262 **kwargs): 263 self.on_wait_for_pending(result, timeout=timeout, **kwargs) 264 prev_on_m, self.on_message = self.on_message, on_message 265 try: 266 for _ in self.drain_events_until( 267 result.on_ready, timeout=timeout, 268 on_interval=on_interval): 269 yield 270 sleep(0) 271 except socket.timeout: 272 raise TimeoutError('The operation timed out.') 273 finally: 274 self.on_message = prev_on_m 275 276 def on_wait_for_pending(self, result, timeout=None, **kwargs): 277 pass 278 279 def on_out_of_band_result(self, message): 280 self.on_state_change(message.payload, message) 281 282 def _get_pending_result(self, task_id): 283 for mapping in self._pending_results: 284 try: 285 return mapping[task_id] 286 except KeyError: 287 pass 288 raise KeyError(task_id) 289 290 def on_state_change(self, meta, message): 291 if self.on_message: 292 self.on_message(meta) 293 if meta['status'] in states.READY_STATES: 294 task_id = meta['task_id'] 295 try: 296 result = self._get_pending_result(task_id) 297 except KeyError: 298 # send to buffer in case we received this result 299 # before it was added to _pending_results. 300 self._pending_messages.put(task_id, meta) 301 else: 302 result._maybe_set_cache(meta) 303 buckets = self.buckets 304 try: 305 # remove bucket for this result, since it's fulfilled 306 bucket = buckets.pop(result) 307 except KeyError: 308 pass 309 else: 310 # send to waiter via bucket 311 bucket.append(result) 312 sleep(0) 313