1"""Async I/O backend support utilities."""
2from __future__ import absolute_import, unicode_literals
3
4import socket
5import threading
6from collections import deque
7from time import sleep
8from weakref import WeakKeyDictionary
9
10from kombu.utils.compat import detect_environment
11
12from celery import states
13from celery.exceptions import TimeoutError
14from celery.five import Empty, monotonic
15from celery.utils.threads import THREAD_TIMEOUT_MAX
16
17__all__ = (
18    'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
19    'register_drainer',
20)
21
#: Registry mapping drainer name -> drainer class,
#: populated through the :func:`register_drainer` decorator.
drainers = {}


def register_drainer(name):
    """Decorator used to register a new result drainer type."""
    def _register(cls):
        drainers[name] = cls
        return cls
    return _register
31
32
@register_drainer('default')
class Drainer(object):
    """Result draining service.

    Blocking drainer used when no green event loop is detected:
    it simply calls the consumer's ``drain_events`` in a loop.
    """

    def __init__(self, result_consumer):
        self.result_consumer = result_consumer

    def start(self):
        # No background machinery needed for the blocking drainer.
        pass

    def stop(self):
        pass

    def drain_events_until(self, p, timeout=None, interval=1,
                           on_interval=None, wait=None):
        """Generator draining events until promise ``p`` is ready."""
        drain = wait or self.result_consumer.drain_events
        started_at = monotonic()

        while True:
            # The deadline covers total elapsed time, which may span
            # several individual wait() calls of ``interval`` seconds each.
            if timeout and monotonic() - started_at >= timeout:
                raise socket.timeout()
            try:
                yield self.wait_for(p, drain, timeout=interval)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break

    def wait_for(self, p, wait, timeout=None):
        wait(timeout=timeout)
65
66
class greenletDrainer(Drainer):
    """Drainer running the drain loop in a spawned greenlet.

    Subclasses supply :attr:`spawn` for their particular green
    threading library.
    """

    #: Spawn function provided by the concrete subclass.
    spawn = None
    #: The greenlet running :meth:`run`, set by :meth:`start`.
    _g = None

    def __init__(self, *args, **kwargs):
        super(greenletDrainer, self).__init__(*args, **kwargs)
        self._started = threading.Event()
        self._stopped = threading.Event()
        self._shutdown = threading.Event()

    def run(self):
        self._started.set()
        consumer = self.result_consumer
        while not self._stopped.is_set():
            try:
                consumer.drain_events(timeout=1)
            except socket.timeout:
                pass
        # Let stop() know the loop has fully exited.
        self._shutdown.set()

    def start(self):
        if self._started.is_set():
            return
        self._g = self.spawn(self.run)
        # Block until the greenlet has actually entered run().
        self._started.wait()

    def stop(self):
        self._stopped.set()
        self._shutdown.wait(THREAD_TIMEOUT_MAX)
94
95
@register_drainer('eventlet')
class eventletDrainer(greenletDrainer):
    """greenletDrainer flavor using eventlet greenthreads."""

    def spawn(self, func):
        from eventlet import sleep, spawn
        greenthread = spawn(func)
        sleep(0)  # yield control so the drain loop can start running.
        return greenthread

    def wait_for(self, p, wait, timeout=None):
        self.start()
        if not p.ready:
            # NOTE(review): relies on eventlet's private _exit_event
            # attribute — verify against the pinned eventlet version.
            self._g._exit_event.wait(timeout=timeout)
109
110
@register_drainer('gevent')
class geventDrainer(greenletDrainer):
    """greenletDrainer flavor using gevent greenlets."""

    def spawn(self, func):
        import gevent
        greenlet = gevent.spawn(func)
        gevent.sleep(0)  # yield control so the drain loop can start running.
        return greenlet

    def wait_for(self, p, wait, timeout=None):
        import gevent
        self.start()
        if not p.ready:
            gevent.wait([self._g], timeout=timeout)
125
126
class AsyncBackendMixin(object):
    """Mixin for backends that enables the async API."""

    def _collect_into(self, result, bucket):
        # Tell the consumer to deposit messages for ``result``
        # into ``bucket`` as they arrive.
        self.result_consumer.buckets[result] = bucket

    def iter_native(self, result, no_ack=True, **kwargs):
        """Iterate over ``(task_id, meta)`` tuples for a result set.

        Results are yielded as they are consumed, which is not
        necessarily the order of the set.
        """
        self._ensure_not_eager()

        results = result.results
        if not results:
            # PEP 479 (Python 3.7+): raising StopIteration inside a
            # generator is turned into RuntimeError, so a bare return
            # must be used to end the iteration.
            return

        # we tell the result consumer to put consumed results
        # into these buckets.
        bucket = deque()
        for node in results:
            if not hasattr(node, '_cache'):
                bucket.append(node)
            elif node._cache:
                # already resolved; drain it without waiting.
                bucket.append(node)
            else:
                self._collect_into(node, bucket)

        for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
            while bucket:
                node = bucket.popleft()
                if not hasattr(node, '_cache'):
                    yield node.id, node.children
                else:
                    yield node.id, node._cache
        # drain anything deposited after the last wait iteration.
        while bucket:
            node = bucket.popleft()
            yield node.id, node._cache

    def add_pending_result(self, result, weak=False, start_drainer=True):
        """Register ``result`` to be resolved by the consumer.

        If its message was already received it is resolved immediately
        from the message buffer instead.
        """
        if start_drainer:
            self.result_consumer.drainer.start()
        try:
            self._maybe_resolve_from_buffer(result)
        except Empty:
            self._add_pending_result(result.id, result, weak=weak)
        return result

    def _maybe_resolve_from_buffer(self, result):
        # raises Empty when no buffered message exists for this id.
        result._maybe_set_cache(self._pending_messages.take(result.id))

    def _add_pending_result(self, task_id, result, weak=False):
        concrete, weak_ = self._pending_results
        # task_id == result.id for all callers; skip if already tracked
        # in either mapping.
        if task_id not in weak_ and result.id not in concrete:
            (weak_ if weak else concrete)[task_id] = result
            self.result_consumer.consume_from(task_id)

    def add_pending_results(self, results, weak=False):
        """Register many results at once, starting the drainer once."""
        self.result_consumer.drainer.start()
        return [self.add_pending_result(result, weak=weak, start_drainer=False)
                for result in results]

    def remove_pending_result(self, result):
        self._remove_pending_result(result.id)
        self.on_result_fulfilled(result)
        return result

    def _remove_pending_result(self, task_id):
        for mapping in self._pending_results:
            mapping.pop(task_id, None)

    def on_result_fulfilled(self, result):
        # stop consuming messages for this id.
        self.result_consumer.cancel_for(result.id)

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        """Block until ``result`` is ready, then return/raise its value."""
        self._ensure_not_eager()
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        return self.result_consumer._wait_for_pending(
            result, timeout=timeout,
            on_interval=on_interval, on_message=on_message,
            **kwargs
        )

    @property
    def is_async(self):
        return True
216
217
class BaseResultConsumer(object):
    """Manager responsible for consuming result messages."""

    def __init__(self, backend, app, accept,
                 pending_results, pending_messages):
        self.backend = backend
        self.app = app
        self.accept = accept
        self._pending_results = pending_results
        self._pending_messages = pending_messages
        self.on_message = None
        self.buckets = WeakKeyDictionary()
        # Pick the drainer implementation matching the detected
        # event-loop environment (default/eventlet/gevent).
        self.drainer = drainers[detect_environment()](self)

    def start(self, initial_task_id, **kwargs):
        raise NotImplementedError()

    def stop(self):
        pass

    def drain_events(self, timeout=None):
        raise NotImplementedError()

    def consume_from(self, task_id):
        raise NotImplementedError()

    def cancel_for(self, task_id):
        raise NotImplementedError()

    def _after_fork(self):
        # Drop state inherited from the parent process.
        self.buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

    def on_after_fork(self):
        pass

    def drain_events_until(self, p, timeout=None, on_interval=None):
        return self.drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        """Generator draining events until ``result`` is ready."""
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        # Temporarily install the caller's on_message callback,
        # restoring the previous one when done.
        prev_on_m = self.on_message
        self.on_message = on_message
        try:
            for _ in self.drain_events_until(
                    result.on_ready, timeout=timeout,
                    on_interval=on_interval):
                yield
                sleep(0)
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = prev_on_m

    def on_wait_for_pending(self, result, timeout=None, **kwargs):
        pass

    def on_out_of_band_result(self, message):
        self.on_state_change(message.payload, message)

    def _get_pending_result(self, task_id):
        # _pending_results holds multiple mappings (concrete + weak).
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
            except KeyError:
                pass
        raise KeyError(task_id)

    def on_state_change(self, meta, message):
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            self._on_ready_meta(meta)
        sleep(0)

    def _on_ready_meta(self, meta):
        # Deliver a ready result meta to whoever is waiting for it.
        task_id = meta['task_id']
        try:
            result = self._get_pending_result(task_id)
        except KeyError:
            # received this result before it was added to
            # _pending_results: park it in the message buffer.
            self._pending_messages.put(task_id, meta)
            return
        result._maybe_set_cache(meta)
        try:
            # the result is fulfilled; its bucket is no longer needed.
            bucket = self.buckets.pop(result)
        except KeyError:
            pass
        else:
            # hand the result to the waiter via the bucket.
            bucket.append(result)
313