1# -*- coding: utf-8 -*-
2"""The old AMQP result backend, deprecated and replaced by the RPC backend."""
3from __future__ import absolute_import, unicode_literals
4
5import socket
6from collections import deque
7from operator import itemgetter
8
9from kombu import Consumer, Exchange, Producer, Queue
10
11from celery import states
12from celery.exceptions import TimeoutError
13from celery.five import monotonic, range
14from celery.utils import deprecated
15from celery.utils.log import get_logger
16
17from .base import BaseBackend
18
19__all__ = ('BacklogLimitExceeded', 'AMQPBackend')
20
21logger = get_logger(__name__)
22
23
24class BacklogLimitExceeded(Exception):
25    """Too much state history to fast-forward."""
26
27
28def repair_uuid(s):
29    # Historically the dashes in UUIDS are removed from AMQ entity names,
30    # but there's no known reason to.  Hopefully we'll be able to fix
31    # this in v4.0.
32    return '%s-%s-%s-%s-%s' % (s[:8], s[8:12], s[12:16], s[16:20], s[20:])
33
34
35class NoCacheQueue(Queue):
36    can_cache_declaration = False
37
38
39class AMQPBackend(BaseBackend):
40    """The AMQP result backend.
41
42    Deprecated: Please use the RPC backend or a persistent backend.
43    """
44
45    Exchange = Exchange
46    Queue = NoCacheQueue
47    Consumer = Consumer
48    Producer = Producer
49
50    BacklogLimitExceeded = BacklogLimitExceeded
51
52    persistent = True
53    supports_autoexpire = True
54    supports_native_join = True
55
56    retry_policy = {
57        'max_retries': 20,
58        'interval_start': 0,
59        'interval_step': 1,
60        'interval_max': 1,
61    }
62
63    def __init__(self, app, connection=None, exchange=None, exchange_type=None,
64                 persistent=None, serializer=None, auto_delete=True, **kwargs):
65        deprecated.warn(
66            'The AMQP result backend', deprecation='4.0', removal='5.0',
67            alternative='Please use RPC backend or a persistent backend.')
68        super(AMQPBackend, self).__init__(app, **kwargs)
69        conf = self.app.conf
70        self._connection = connection
71        self.persistent = self.prepare_persistent(persistent)
72        self.delivery_mode = 2 if self.persistent else 1
73        exchange = exchange or conf.result_exchange
74        exchange_type = exchange_type or conf.result_exchange_type
75        self.exchange = self._create_exchange(
76            exchange, exchange_type, self.delivery_mode,
77        )
78        self.serializer = serializer or conf.result_serializer
79        self.auto_delete = auto_delete
80
81    def _create_exchange(self, name, type='direct', delivery_mode=2):
82        return self.Exchange(name=name,
83                             type=type,
84                             delivery_mode=delivery_mode,
85                             durable=self.persistent,
86                             auto_delete=False)
87
88    def _create_binding(self, task_id):
89        name = self.rkey(task_id)
90        return self.Queue(
91            name=name,
92            exchange=self.exchange,
93            routing_key=name,
94            durable=self.persistent,
95            auto_delete=self.auto_delete,
96            expires=self.expires,
97        )
98
99    def revive(self, channel):
100        pass
101
102    def rkey(self, task_id):
103        return task_id.replace('-', '')
104
105    def destination_for(self, task_id, request):
106        if request:
107            return self.rkey(task_id), request.correlation_id or task_id
108        return self.rkey(task_id), task_id
109
110    def store_result(self, task_id, result, state,
111                     traceback=None, request=None, **kwargs):
112        """Send task return value and state."""
113        routing_key, correlation_id = self.destination_for(task_id, request)
114        if not routing_key:
115            return
116
117        payload = {'task_id': task_id, 'status': state,
118                   'result': self.encode_result(result, state),
119                   'traceback': traceback,
120                   'children': self.current_task_children(request)}
121        if self.app.conf.find_value_for_key('extended', 'result'):
122            payload['name'] = getattr(request, 'task_name', None)
123            payload['args'] = getattr(request, 'args', None)
124            payload['kwargs'] = getattr(request, 'kwargs', None)
125            payload['worker'] = getattr(request, 'hostname', None)
126            payload['retries'] = getattr(request, 'retries', None)
127            payload['queue'] = request.delivery_info.get('routing_key')\
128                if hasattr(request, 'delivery_info') \
129                and request.delivery_info else None
130
131        with self.app.amqp.producer_pool.acquire(block=True) as producer:
132            producer.publish(
133                payload,
134                exchange=self.exchange,
135                routing_key=routing_key,
136                correlation_id=correlation_id,
137                serializer=self.serializer,
138                retry=True, retry_policy=self.retry_policy,
139                declare=self.on_reply_declare(task_id),
140                delivery_mode=self.delivery_mode,
141            )
142
143    def on_reply_declare(self, task_id):
144        return [self._create_binding(task_id)]
145
146    def wait_for(self, task_id, timeout=None, cache=True,
147                 no_ack=True, on_interval=None,
148                 READY_STATES=states.READY_STATES,
149                 PROPAGATE_STATES=states.PROPAGATE_STATES,
150                 **kwargs):
151        cached_meta = self._cache.get(task_id)
152        if cache and cached_meta and \
153                cached_meta['status'] in READY_STATES:
154            return cached_meta
155        try:
156            return self.consume(task_id, timeout=timeout, no_ack=no_ack,
157                                on_interval=on_interval)
158        except socket.timeout:
159            raise TimeoutError('The operation timed out.')
160
161    def get_task_meta(self, task_id, backlog_limit=1000):
162        # Polling and using basic_get
163        with self.app.pool.acquire_channel(block=True) as (_, channel):
164            binding = self._create_binding(task_id)(channel)
165            binding.declare()
166
167            prev = latest = acc = None
168            for i in range(backlog_limit):  # spool ffwd
169                acc = binding.get(
170                    accept=self.accept, no_ack=False,
171                )
172                if not acc:  # no more messages
173                    break
174                if acc.payload['task_id'] == task_id:
175                    prev, latest = latest, acc
176                if prev:
177                    # backends are not expected to keep history,
178                    # so we delete everything except the most recent state.
179                    prev.ack()
180                    prev = None
181            else:
182                raise self.BacklogLimitExceeded(task_id)
183
184            if latest:
185                payload = self._cache[task_id] = self.meta_from_decoded(
186                    latest.payload)
187                latest.requeue()
188                return payload
189            else:
190                # no new state, use previous
191                try:
192                    return self._cache[task_id]
193                except KeyError:
194                    # result probably pending.
195                    return {'status': states.PENDING, 'result': None}
196    poll = get_task_meta  # XXX compat
197
198    def drain_events(self, connection, consumer,
199                     timeout=None, on_interval=None, now=monotonic, wait=None):
200        wait = wait or connection.drain_events
201        results = {}
202
203        def callback(meta, message):
204            if meta['status'] in states.READY_STATES:
205                results[meta['task_id']] = self.meta_from_decoded(meta)
206
207        consumer.callbacks[:] = [callback]
208        time_start = now()
209
210        while 1:
211            # Total time spent may exceed a single call to wait()
212            if timeout and now() - time_start >= timeout:
213                raise socket.timeout()
214            try:
215                wait(timeout=1)
216            except socket.timeout:
217                pass
218            if on_interval:
219                on_interval()
220            if results:  # got event on the wanted channel.
221                break
222        self._cache.update(results)
223        return results
224
225    def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
226        wait = self.drain_events
227        with self.app.pool.acquire_channel(block=True) as (conn, channel):
228            binding = self._create_binding(task_id)
229            with self.Consumer(channel, binding,
230                               no_ack=no_ack, accept=self.accept) as consumer:
231                while 1:
232                    try:
233                        return wait(
234                            conn, consumer, timeout, on_interval)[task_id]
235                    except KeyError:
236                        continue
237
238    def _many_bindings(self, ids):
239        return [self._create_binding(task_id) for task_id in ids]
240
241    def get_many(self, task_ids, timeout=None, no_ack=True,
242                 on_message=None, on_interval=None,
243                 now=monotonic, getfields=itemgetter('status', 'task_id'),
244                 READY_STATES=states.READY_STATES,
245                 PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
246        with self.app.pool.acquire_channel(block=True) as (conn, channel):
247            ids = set(task_ids)
248            cached_ids = set()
249            mark_cached = cached_ids.add
250            for task_id in ids:
251                try:
252                    cached = self._cache[task_id]
253                except KeyError:
254                    pass
255                else:
256                    if cached['status'] in READY_STATES:
257                        yield task_id, cached
258                        mark_cached(task_id)
259            ids.difference_update(cached_ids)
260            results = deque()
261            push_result = results.append
262            push_cache = self._cache.__setitem__
263            decode_result = self.meta_from_decoded
264
265            def _on_message(message):
266                body = decode_result(message.decode())
267                if on_message is not None:
268                    on_message(body)
269                state, uid = getfields(body)
270                if state in READY_STATES:
271                    push_result(body) \
272                        if uid in task_ids else push_cache(uid, body)
273
274            bindings = self._many_bindings(task_ids)
275            with self.Consumer(channel, bindings, on_message=_on_message,
276                               accept=self.accept, no_ack=no_ack):
277                wait = conn.drain_events
278                popleft = results.popleft
279                while ids:
280                    wait(timeout=timeout)
281                    while results:
282                        state = popleft()
283                        task_id = state['task_id']
284                        ids.discard(task_id)
285                        push_cache(task_id, state)
286                        yield task_id, state
287                    if on_interval:
288                        on_interval()
289
290    def reload_task_result(self, task_id):
291        raise NotImplementedError(
292            'reload_task_result is not supported by this backend.')
293
294    def reload_group_result(self, task_id):
295        """Reload group result, even if it has been previously fetched."""
296        raise NotImplementedError(
297            'reload_group_result is not supported by this backend.')
298
299    def save_group(self, group_id, result):
300        raise NotImplementedError(
301            'save_group is not supported by this backend.')
302
303    def restore_group(self, group_id, cache=True):
304        raise NotImplementedError(
305            'restore_group is not supported by this backend.')
306
307    def delete_group(self, group_id):
308        raise NotImplementedError(
309            'delete_group is not supported by this backend.')
310
311    def __reduce__(self, args=(), kwargs=None):
312        kwargs = kwargs if kwargs else {}
313        kwargs.update(
314            connection=self._connection,
315            exchange=self.exchange.name,
316            exchange_type=self.exchange.type,
317            persistent=self.persistent,
318            serializer=self.serializer,
319            auto_delete=self.auto_delete,
320            expires=self.expires,
321        )
322        return super(AMQPBackend, self).__reduce__(args, kwargs)
323
324    def as_uri(self, include_password=True):
325        return 'amqp://'
326