1# -*- coding: utf-8 -*-
2"""The ``RPC`` result backend for AMQP brokers.
3
4RPC-style result backend, using reply-to and one queue per client.
5"""
6from __future__ import absolute_import, unicode_literals
7
8import time
9
10import kombu
11from kombu.common import maybe_declare
12from kombu.utils.compat import register_after_fork
13from kombu.utils.objects import cached_property
14
15from celery import states
16from celery._state import current_task, task_join_will_block
17from celery.five import items, range
18
19from . import base
20from .asynchronous import AsyncBackendMixin, BaseResultConsumer
21
# Names exported by ``from <this module> import *``.
__all__ = ('BacklogLimitExceeded', 'RPCBackend')
23
# Error text used when chords are requested with this backend.  The text
# starts and ends with a newline, so users of the constant strip it.
E_NO_CHORD_SUPPORT = (
    '\n'
    'The "rpc" result backend does not support chords!\n'
    '\n'
    'Note that a group chained with a task is also upgraded to be a chord,\n'
    'as this pattern requires synchronization.\n'
    '\n'
    'Result backends that supports chords: Redis, Database, Memcached, '
    'and more.\n'
)
32
33
class BacklogLimitExceeded(Exception):
    """Raised when a task has more queued state messages than the limit.

    Signals that there is too much state history to fast-forward through.
    """
36
37
def _on_after_fork_cleanup_backend(backend):
    """After-fork hook: ask *backend* to reset its per-process state.

    Delegates to ``backend._after_fork()`` and discards whatever it returns.
    """
    reset = backend._after_fork
    reset()
40
41
class ResultConsumer(BaseResultConsumer):
    """Result consumer that reads task replies over its own connection.

    The connection and the kombu consumer are created by :meth:`start`
    (which :meth:`consume_from` calls on first use); until then both
    attributes are ``None``.
    """

    #: Consumer class instantiated by :meth:`start`.
    Consumer = kombu.Consumer

    # Both stay ``None`` until :meth:`start` runs and are reset to ``None``
    # by :meth:`on_after_fork`.
    _connection = None
    _consumer = None

    def __init__(self, *args, **kwargs):
        """Initialize the base consumer and borrow the backend's binding factory."""
        super(ResultConsumer, self).__init__(*args, **kwargs)
        # Queue bindings are always resolved through the owning backend.
        self._create_binding = self.backend._create_binding

    def start(self, initial_task_id, no_ack=True, **kwargs):
        """Open a connection and start consuming the queue for a task.

        Arguments:
            initial_task_id (str): task id used to look up the first queue.
            no_ack (bool): passed through to the consumer.
            **kwargs: accepted and ignored.
        """
        self._connection = self.app.connection()
        initial_queue = self._create_binding(initial_task_id)
        self._consumer = self.Consumer(
            self._connection.default_channel, [initial_queue],
            callbacks=[self.on_state_change], no_ack=no_ack,
            accept=self.accept)
        self._consumer.consume()

    def drain_events(self, timeout=None):
        """Wait for events on the connection.

        Without a connection there is nothing to drain, so this sleeps for
        ``timeout`` seconds instead (only when a truthy timeout is given).
        """
        if self._connection:
            return self._connection.drain_events(timeout=timeout)
        elif timeout:
            time.sleep(timeout)

    def stop(self):
        """Cancel the consumer and close the connection.

        Safe to call when the consumer was never started, or after
        :meth:`on_after_fork` reset it: missing parts are skipped instead of
        raising ``AttributeError``.  The connection is closed even if
        cancelling the consumer raises.
        """
        try:
            if self._consumer is not None:
                self._consumer.cancel()
        finally:
            if self._connection is not None:
                self._connection.close()

    def on_after_fork(self):
        """Forget the consumer and connection held before the fork."""
        self._consumer = None
        if self._connection is not None:
            # The connection is collected rather than closed here.
            self._connection.collect()
            self._connection = None

    def consume_from(self, task_id):
        """Make sure the queue for ``task_id`` is being consumed.

        Starts the consumer on first use; afterwards the queue is added only
        if the consumer is not already consuming from it.
        """
        if self._consumer is None:
            return self.start(task_id)
        queue = self._create_binding(task_id)
        if not self._consumer.consuming_from(queue):
            self._consumer.add_queue(queue)
            self._consumer.consume()

    def cancel_for(self, task_id):
        """Stop consuming the queue for ``task_id`` (no-op if not started)."""
        if self._consumer:
            self._consumer.cancel_by_queue(self._create_binding(task_id).name)
90
91
class RPCBackend(base.Backend, AsyncBackendMixin):
    """Result backend that publishes task results as AMQP reply messages.

    Results are published to the ``reply_to`` routing key found on the task
    request, and read back from a queue named after :attr:`oid`.
    """

    Exchange = kombu.Exchange
    Producer = kombu.Producer
    ResultConsumer = ResultConsumer

    #: Exception raised when there are too many messages for a task id.
    BacklogLimitExceeded = BacklogLimitExceeded

    persistent = False
    supports_autoexpire = True
    supports_native_join = True

    #: Retry policy passed to ``publish`` when sending a result.
    retry_policy = {
        'max_retries': 20,
        'interval_start': 0,
        'interval_step': 1,
        'interval_max': 1,
    }

    class Consumer(kombu.Consumer):
        """Consumer with automatic queue declaration turned off."""

        auto_declare = False

    class Queue(kombu.Queue):
        """Queue whose declaration is never cached."""

        can_cache_declaration = False

    def __init__(self, app, connection=None, exchange=None, exchange_type=None,
                 persistent=None, serializer=None, auto_delete=True, **kwargs):
        """Set up exchange, serializer and the result consumer.

        Arguments left unset fall back to the app configuration
        (``result_exchange``, ``result_exchange_type``,
        ``result_serializer``).
        """
        super(RPCBackend, self).__init__(app, **kwargs)
        conf = self.app.conf
        self._connection = connection
        # Replies received for task ids nobody is waiting on (yet).
        self._out_of_band = {}
        self.persistent = self.prepare_persistent(persistent)
        if self.persistent:
            self.delivery_mode = 2
        else:
            self.delivery_mode = 1
        self.exchange = self._create_exchange(
            exchange or conf.result_exchange,
            exchange_type or conf.result_exchange_type,
            self.delivery_mode,
        )
        self.serializer = serializer or conf.result_serializer
        self.auto_delete = auto_delete
        self.result_consumer = self.ResultConsumer(
            self, self.app, self.accept,
            self._pending_results, self._pending_messages,
        )
        if register_after_fork is not None:
            register_after_fork(self, _on_after_fork_cleanup_backend)

    def _after_fork(self):
        """Clear pending results and reset the consumer in a child process."""
        self._pending_results.clear()
        self.result_consumer._after_fork()

    def _create_exchange(self, name, type='direct', delivery_mode=2):
        """Return the exchange results are published to.

        All arguments are ignored: the anonymous exchange is used, which
        routes directly to the queue.
        """
        return self.Exchange(None)

    def _create_binding(self, task_id):
        """Return the queue binding used to receive the result of a task.

        The same binding (:attr:`binding`) serves every task, so ``task_id``
        is not used.
        """
        return self.binding

    def ensure_chords_allowed(self):
        """Always raise: this backend does not support chords."""
        raise NotImplementedError(E_NO_CHORD_SUPPORT.strip())

    def on_task_call(self, producer, task_id):
        """Declare the reply queue before a task message is sent.

        Called every time a task is sent when using this backend.
        """
        # Inside the prefork pool (task_join_will_block) the queue is known
        # to be declared already, so the declaration is skipped.
        if task_join_will_block():
            return
        maybe_declare(self.binding(producer.channel), retry=True)

    def destination_for(self, task_id, request):
        """Get the destination for result by task id.

        Returns:
            Tuple[str, str]: tuple of ``(reply_to, correlation_id)``.

        Raises:
            RuntimeError: if no request was given and none can be read
                from the current task.
        """
        # Backends didn't always receive the `request`, so old code that
        # relies on current_task must still be supported.
        try:
            req = request or current_task.request
        except AttributeError:
            raise RuntimeError(
                'RPC backend missing task request for {0!r}'.format(task_id))
        reply_to = req.reply_to
        correlation_id = req.correlation_id or task_id
        return reply_to, correlation_id

    def on_reply_declare(self, task_id):
        """Return what to pass as ``declare=`` when publishing a result.

        By default nothing needs declaring, so this returns ``None``.
        """
        return None

    def on_result_fulfilled(self, result):
        """Hook called once a result has been received; does nothing here.

        Other backends usually cancel the queue at this point, but there is
        only one queue per process so nothing has to be cancelled.
        """
        return None

    def as_uri(self, include_password=True):
        """Return the backend URI, which is always ``'rpc://'``."""
        return 'rpc://'

    def store_result(self, task_id, result, state,
                     traceback=None, request=None, **kwargs):
        """Send task return value and state.

        Returns ``result`` after publishing, or ``None`` without publishing
        anything when the request carries no reply destination.
        """
        routing_key, correlation_id = self.destination_for(task_id, request)
        if routing_key:
            with self.app.amqp.producer_pool.acquire(block=True) as producer:
                meta = self._to_result(
                    task_id, state, result, traceback, request)
                producer.publish(
                    meta,
                    exchange=self.exchange,
                    routing_key=routing_key,
                    correlation_id=correlation_id,
                    serializer=self.serializer,
                    retry=True, retry_policy=self.retry_policy,
                    declare=self.on_reply_declare(task_id),
                    delivery_mode=self.delivery_mode,
                )
            return result
        return None

    def _to_result(self, task_id, state, result, traceback, request):
        """Build the message body describing a task's state."""
        encoded = self.encode_result(result, state)
        children = self.current_task_children(request)
        return {
            'task_id': task_id,
            'status': state,
            'result': encoded,
            'traceback': traceback,
            'children': children,
        }

    def on_out_of_band_result(self, task_id, message):
        """Buffer a reply for a task id that nobody is waiting on.

        Called when a reply is received but we have no idea what to do with
        it.  Since the result is not pending it goes into a separate buffer:
        it will probably become pending later.
        """
        if self.result_consumer:
            self.result_consumer.on_out_of_band_result(message)
        self._out_of_band[task_id] = message

    def get_task_meta(self, task_id, backlog_limit=1000):
        """Return the latest known state of a task.

        A buffered out-of-band reply is used first.  Otherwise the queue is
        polled with basic_get: only the newest message per task id is kept,
        older ones are acked, and messages for other task ids are handed to
        :meth:`on_out_of_band_result`.  With no new message the cached state
        is returned, or a ``PENDING`` placeholder if nothing is cached.

        Raises:
            BacklogLimitExceeded: propagated from :meth:`_slurp_from_queue`
                when ``backlog_limit`` messages are read without emptying
                the queue.
        """
        buffered = self._out_of_band.pop(task_id, None)
        if buffered:
            return self._set_cache_by_message(task_id, buffered)

        newest_by_id = {}
        backlog = self._slurp_from_queue(task_id, self.accept, backlog_limit)
        for message in backlog:
            message_tid = self._get_message_task_id(message)
            superseded = newest_by_id.get(message_tid)
            newest_by_id[message_tid] = message
            if superseded:
                # backends aren't expected to keep history,
                # so everything except the most recent state is deleted.
                superseded.ack()

        latest = newest_by_id.pop(task_id, None)
        for other_tid, other_message in items(newest_by_id):
            self.on_out_of_band_result(other_tid, other_message)

        if not latest:
            # no new state, use previous
            try:
                return self._cache[task_id]
            except KeyError:
                # result probably pending.
                return {'status': states.PENDING, 'result': None}
        latest.requeue()
        return self._set_cache_by_message(task_id, latest)
    poll = get_task_meta  # XXX compat

    def _set_cache_by_message(self, task_id, message):
        """Decode ``message``, cache its meta under ``task_id`` and return it."""
        meta = self.meta_from_decoded(message.payload)
        self._cache[task_id] = meta
        return meta

    def _slurp_from_queue(self, task_id, accept,
                          limit=1000, no_ack=False):
        """Yield messages from the result queue until it is empty.

        Raises:
            BacklogLimitExceeded: if ``limit`` messages were read and the
                queue never came back empty.
        """
        with self.app.pool.acquire_channel(block=True) as (_, channel):
            queue = self._create_binding(task_id)(channel)
            queue.declare()

            for _ in range(limit):
                message = queue.get(accept=accept, no_ack=no_ack)
                if not message:
                    # Queue drained before the limit was reached.
                    return
                yield message
            raise self.BacklogLimitExceeded(task_id)

    def _get_message_task_id(self, message):
        """Return the task id a reply message belongs to."""
        try:
            # try the property first so the payload
            # doesn't have to be deserialized.
            properties = message.properties
            return properties['correlation_id']
        except (AttributeError, KeyError):
            # message sent by old Celery version, need to deserialize.
            return message.payload['task_id']

    def revive(self, channel):
        """Do nothing; present for interface compatibility."""
        return None

    def reload_task_result(self, task_id):
        """Not supported by this backend."""
        raise NotImplementedError(
            'reload_task_result is not supported by this backend.')

    def reload_group_result(self, task_id):
        """Reload group result, even if it has been previously fetched.

        Not supported by this backend.
        """
        raise NotImplementedError(
            'reload_group_result is not supported by this backend.')

    def save_group(self, group_id, result):
        """Not supported by this backend."""
        raise NotImplementedError(
            'save_group is not supported by this backend.')

    def restore_group(self, group_id, cache=True):
        """Not supported by this backend."""
        raise NotImplementedError(
            'restore_group is not supported by this backend.')

    def delete_group(self, group_id):
        """Not supported by this backend."""
        raise NotImplementedError(
            'delete_group is not supported by this backend.')

    def __reduce__(self, args=(), kwargs=None):
        """Support pickling by recording this backend's constructor arguments."""
        state = dict(kwargs or {})
        state.update(
            connection=self._connection,
            exchange=self.exchange.name,
            exchange_type=self.exchange.type,
            persistent=self.persistent,
            serializer=self.serializer,
            auto_delete=self.auto_delete,
            expires=self.expires,
        )
        return super(RPCBackend, self).__reduce__(args, state)

    @property
    def binding(self):
        """Queue results are received on; a new object on every access."""
        queue_name = self.oid
        return self.Queue(
            queue_name, self.exchange, queue_name,
            durable=False,
            auto_delete=True,
            expires=self.expires,
        )

    @cached_property
    def oid(self):
        """App OID (cached): the name of the queue we receive results on."""
        return self.app.oid
347