# -*- coding: utf-8 -*-
"""Sending/Receiving Messages (Kombu integration)."""
from __future__ import absolute_import, unicode_literals

import numbers
from collections import namedtuple
from datetime import timedelta
from weakref import WeakValueDictionary

from kombu import Connection, Consumer, Exchange, Producer, Queue, pools
from kombu.common import Broadcast
from kombu.utils.functional import maybe_list
from kombu.utils.objects import cached_property

from celery import signals
from celery.five import PY3, items, string_t
from celery.local import try_import
from celery.utils.nodenames import anon_nodename
from celery.utils.saferepr import saferepr
from celery.utils.text import indent as textindent
from celery.utils.time import maybe_make_aware

from . import routes as _routes

try:
    from collections.abc import Mapping
except ImportError:
    # TODO: Remove this when we drop Python 2.7 support
    from collections import Mapping

__all__ = ('AMQP', 'Queues', 'task_message')

#: earliest date supported by time.mktime.
INT_MIN = -2147483648

# json in Python 2.7 borks if dict contains byte keys.
JSON_NEEDS_UNICODE_KEYS = not PY3 and not try_import('simplejson')

#: Human readable queue declaration.
QUEUE_FORMAT = """
.> {0.name:<16} exchange={0.exchange.name}({0.exchange.type}) \
key={0.routing_key}
"""

task_message = namedtuple('task_message',
                          ('headers', 'properties', 'body', 'sent_event'))


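# Helper used when JSON_NEEDS_UNICODE_KEYS is true: decodes byte-string dict
# keys to text so Python 2's json encoders accept them (values are untouched).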
def utf8dict(d, encoding='utf-8'):
    return {k.decode(encoding) if isinstance(k, bytes) else k: v
            for k, v in items(d)}


class Queues(dict):
    """Queue name ⇒ declaration mapping.

    Arguments:
        queues (Iterable): Initial list/tuple or dict of queues.
        create_missing (bool): By default any unknown queues will be
            added automatically, but if this flag is disabled then
            looking up an unknown queue will raise :exc:`KeyError`.
        ha_policy (Sequence, str): Default HA policy for queues with none set.
        max_priority (int): Default x-max-priority for queues with none set.
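
    Example:
        An illustrative sketch; the queue and exchange names below are
        made up and not part of any default configuration:

            >>> from kombu import Exchange, Queue
            >>> queues = Queues([Queue('celery', Exchange('celery'), 'celery')])
            >>> 'celery' in queues
            True
            >>> q = queues['video']   # auto-added because create_missing=True
            >>> q.name
            'video'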
    """

    #: If set, this is a subset of queues to consume from.
    #: The rest of the queues are then used for routing only.
    _consume_from = None

    def __init__(self, queues=None, default_exchange=None,
                 create_missing=True, ha_policy=None, autoexchange=None,
                 max_priority=None, default_routing_key=None):
        dict.__init__(self)
        self.aliases = WeakValueDictionary()
        self.default_exchange = default_exchange
        self.default_routing_key = default_routing_key
        self.create_missing = create_missing
        self.ha_policy = ha_policy
        self.autoexchange = Exchange if autoexchange is None else autoexchange
        self.max_priority = max_priority
        if queues is not None and not isinstance(queues, Mapping):
            queues = {q.name: q for q in queues}
        for name, q in items(queues or {}):
            self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q)

    def __getitem__(self, name):
        try:
            return self.aliases[name]
        except KeyError:
            return dict.__getitem__(self, name)

    def __setitem__(self, name, queue):
        if self.default_exchange and not queue.exchange:
            queue.exchange = self.default_exchange
        dict.__setitem__(self, name, queue)
        if queue.alias:
            self.aliases[queue.alias] = queue

    def __missing__(self, name):
        if self.create_missing:
            return self.add(self.new_missing(name))
        raise KeyError(name)

    def add(self, queue, **kwargs):
        """Add new queue.

        The first argument can either be a :class:`kombu.Queue` instance,
        or the name of a queue.  If the former, the remaining keyword
        arguments are ignored and options are taken from the queue
        instance.

        Arguments:
            queue (kombu.Queue, str): Queue to add.
            exchange (kombu.Exchange, str):
                if queue is str, specifies exchange name.
            routing_key (str): if queue is str, specifies binding key.
            exchange_type (str): if queue is str, specifies type of exchange.
            **options (Any): Additional declaration options used when
                queue is a str.
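
        Example:
            A sketch of both call styles; the queue, exchange and routing
            key names here are hypothetical:

                >>> queues = Queues()
                >>> q1 = queues.add(Queue('images', Exchange('media'),
                ...                       routing_key='media.images'))
                >>> q2 = queues.add('videos', exchange='media',
                ...                 routing_key='media.videos',
                ...                 exchange_type='topic')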
        """
        if not isinstance(queue, Queue):
            return self.add_compat(queue, **kwargs)
        return self._add(queue)

    def add_compat(self, name, **options):
        # docs used to use binding_key as routing key
        options.setdefault('routing_key', options.get('binding_key'))
        if options['routing_key'] is None:
            options['routing_key'] = name
        return self._add(Queue.from_dict(name, **options))

    def _add(self, queue):
        if queue.exchange is None or queue.exchange.name == '':
            queue.exchange = self.default_exchange
        if not queue.routing_key:
            queue.routing_key = self.default_routing_key
        if self.ha_policy:
            if queue.queue_arguments is None:
                queue.queue_arguments = {}
            self._set_ha_policy(queue.queue_arguments)
        if self.max_priority is not None:
            if queue.queue_arguments is None:
                queue.queue_arguments = {}
            self._set_max_priority(queue.queue_arguments)
        self[queue.name] = queue
        return queue

    def _set_ha_policy(self, args):
        policy = self.ha_policy
        if isinstance(policy, (list, tuple)):
            return args.update({'ha-mode': 'nodes',
                                'ha-params': list(policy)})
        args['ha-mode'] = policy

    def _set_max_priority(self, args):
        if 'x-max-priority' not in args and self.max_priority is not None:
            return args.update({'x-max-priority': self.max_priority})

    def format(self, indent=0, indent_first=True):
        """Format routing table into string for log dumps."""
        active = self.consume_from
        if not active:
            return ''
        info = [QUEUE_FORMAT.strip().format(q)
                for _, q in sorted(items(active))]
        if indent_first:
            return textindent('\n'.join(info), indent)
        return info[0] + '\n' + textindent('\n'.join(info[1:]), indent)

    def select_add(self, queue, **kwargs):
        """Add new task queue that'll be consumed from.

        The queue will be active even when a subset has been selected
        using the :option:`celery worker -Q` option.
        """
        q = self.add(queue, **kwargs)
        if self._consume_from is not None:
            self._consume_from[q.name] = q
        return q

    def select(self, include):
        """Select a subset of currently defined queues to consume from.

        Arguments:
            include (Sequence[str], str): Names of queues to consume from.
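
        Example:
            Illustrative only; the queue names are made up:

                >>> queues = Queues([Queue('images'), Queue('videos'),
                ...                  Queue('default')])
                >>> queues.select(['images', 'videos'])
                >>> sorted(queues.consume_from)
                ['images', 'videos']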
        """
        if include:
            self._consume_from = {
                name: self[name] for name in maybe_list(include)
            }

    def deselect(self, exclude):
        """Deselect queues so that they won't be consumed from.

        Arguments:
            exclude (Sequence[str], str): Names of queues to avoid
                consuming from.
        """
        if exclude:
            exclude = maybe_list(exclude)
            if self._consume_from is None:
                # using all queues
                return self.select(k for k in self if k not in exclude)
            # using selection
            for queue in exclude:
                self._consume_from.pop(queue, None)

    def new_missing(self, name):
        return Queue(name, self.autoexchange(name), name)

    @property
    def consume_from(self):
        if self._consume_from is not None:
            return self._consume_from
        return self


class AMQP(object):
    """App AMQP API: app.amqp."""

    Connection = Connection
    Consumer = Consumer
    Producer = Producer

    #: compat alias to Connection
    BrokerConnection = Connection

    queues_cls = Queues

    #: Cached and prepared routing table.
    _rtable = None

    #: Underlying producer pool instance automatically
    #: set by the :attr:`producer_pool`.
    _producer_pool = None

    # Exchange class/function used when defining automatic queues.
    # For example, you can use ``autoexchange = lambda n: None`` to use the
    # AMQP default exchange: a shortcut to bypass routing
    # and instead send directly to the queue named in the routing key.
    autoexchange = None

    #: Max size of positional argument representation used for
    #: logging purposes.
    argsrepr_maxsize = 1024

    #: Max size of keyword argument representation used for logging purposes.
    kwargsrepr_maxsize = 1024

    def __init__(self, app):
        self.app = app
        self.task_protocols = {
            1: self.as_task_v1,
            2: self.as_task_v2,
        }
        self.app._conf.bind_to(self._handle_conf_update)

    @cached_property
    def create_task_message(self):
        return self.task_protocols[self.app.conf.task_protocol]

    @cached_property
    def send_task_message(self):
        return self._create_task_sender()

    def Queues(self, queues, create_missing=None, ha_policy=None,
               autoexchange=None, max_priority=None):
        # Create new :class:`Queues` instance, using queue defaults
        # from the current configuration.
        conf = self.app.conf
        default_routing_key = conf.task_default_routing_key
        if create_missing is None:
            create_missing = conf.task_create_missing_queues
        if ha_policy is None:
            ha_policy = conf.task_queue_ha_policy
        if max_priority is None:
            max_priority = conf.task_queue_max_priority
        if not queues and conf.task_default_queue:
            queues = (Queue(conf.task_default_queue,
                            exchange=self.default_exchange,
                            routing_key=default_routing_key),)
        autoexchange = (self.autoexchange if autoexchange is None
                        else autoexchange)
        return self.queues_cls(
            queues, self.default_exchange, create_missing,
            ha_policy, autoexchange, max_priority, default_routing_key,
        )

    def Router(self, queues=None, create_missing=None):
        """Return the current task router."""
        return _routes.Router(self.routes, queues or self.queues,
                              self.app.either('task_create_missing_queues',
                                              create_missing), app=self.app)

    def flush_routes(self):
        self._rtable = _routes.prepare(self.app.conf.task_routes)

    def TaskConsumer(self, channel, queues=None, accept=None, **kw):
        if accept is None:
            accept = self.app.conf.accept_content
        return self.Consumer(
            channel, accept=accept,
            queues=queues or list(self.queues.consume_from.values()),
            **kw
        )

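    # Build a protocol 2 task message: task metadata (name, id, eta/expires,
    # retries, argsrepr/kwargsrepr, root/parent ids, ...) travels in the
    # message headers, the body is the tuple ``(args, kwargs, embed)`` where
    # ``embed`` carries callbacks, errbacks, chain and chord, and
    # ``sent_event`` is only populated when ``create_sent_event`` is true.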
    def as_task_v2(self, task_id, name, args=None, kwargs=None,
                   countdown=None, eta=None, group_id=None, group_index=None,
                   expires=None, retries=0, chord=None,
                   callbacks=None, errbacks=None, reply_to=None,
                   time_limit=None, soft_time_limit=None,
                   create_sent_event=False, root_id=None, parent_id=None,
                   shadow=None, chain=None, now=None, timezone=None,
                   origin=None, argsrepr=None, kwargsrepr=None):
        args = args or ()
        kwargs = kwargs or {}
        if not isinstance(args, (list, tuple)):
            raise TypeError('task args must be a list or tuple')
        if not isinstance(kwargs, Mapping):
            raise TypeError('task keyword arguments must be a mapping')
        if countdown:  # convert countdown to ETA
            self._verify_seconds(countdown, 'countdown')
            now = now or self.app.now()
            timezone = timezone or self.app.timezone
            eta = maybe_make_aware(
                now + timedelta(seconds=countdown), tz=timezone,
            )
        if isinstance(expires, numbers.Real):
            self._verify_seconds(expires, 'expires')
            now = now or self.app.now()
            timezone = timezone or self.app.timezone
            expires = maybe_make_aware(
                now + timedelta(seconds=expires), tz=timezone,
            )
        if not isinstance(eta, string_t):
            eta = eta and eta.isoformat()
        # If we retry a task `expires` will already be ISO8601-formatted.
        if not isinstance(expires, string_t):
            expires = expires and expires.isoformat()

        if argsrepr is None:
            argsrepr = saferepr(args, self.argsrepr_maxsize)
        if kwargsrepr is None:
            kwargsrepr = saferepr(kwargs, self.kwargsrepr_maxsize)

        if JSON_NEEDS_UNICODE_KEYS:  # pragma: no cover
            if callbacks:
                callbacks = [utf8dict(callback) for callback in callbacks]
            if errbacks:
                errbacks = [utf8dict(errback) for errback in errbacks]
            if chord:
                chord = utf8dict(chord)

        if not root_id:  # empty root_id defaults to task_id
            root_id = task_id

        return task_message(
            headers={
                'lang': 'py',
                'task': name,
                'id': task_id,
                'shadow': shadow,
                'eta': eta,
                'expires': expires,
                'group': group_id,
                'group_index': group_index,
                'retries': retries,
                'timelimit': [time_limit, soft_time_limit],
                'root_id': root_id,
                'parent_id': parent_id,
                'argsrepr': argsrepr,
                'kwargsrepr': kwargsrepr,
                'origin': origin or anon_nodename()
            },
            properties={
                'correlation_id': task_id,
                'reply_to': reply_to or '',
            },
            body=(
                args, kwargs, {
                    'callbacks': callbacks,
                    'errbacks': errbacks,
                    'chain': chain,
                    'chord': chord,
                },
            ),
            sent_event={
                'uuid': task_id,
                'root_id': root_id,
                'parent_id': parent_id,
                'name': name,
                'args': argsrepr,
                'kwargs': kwargsrepr,
                'retries': retries,
                'eta': eta,
                'expires': expires,
            } if create_sent_event else None,
        )

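    # Build a legacy protocol 1 task message: everything (name, id, args,
    # kwargs, eta/expires, group, chord, ...) lives in a single flat body
    # dict, and the message headers are left empty.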
    def as_task_v1(self, task_id, name, args=None, kwargs=None,
                   countdown=None, eta=None, group_id=None, group_index=None,
                   expires=None, retries=0,
                   chord=None, callbacks=None, errbacks=None, reply_to=None,
                   time_limit=None, soft_time_limit=None,
                   create_sent_event=False, root_id=None, parent_id=None,
                   shadow=None, now=None, timezone=None,
                   **compat_kwargs):
        args = args or ()
        kwargs = kwargs or {}
        utc = self.utc
        if not isinstance(args, (list, tuple)):
            raise TypeError('task args must be a list or tuple')
        if not isinstance(kwargs, Mapping):
            raise TypeError('task keyword arguments must be a mapping')
        if countdown:  # convert countdown to ETA
            self._verify_seconds(countdown, 'countdown')
            now = now or self.app.now()
            eta = now + timedelta(seconds=countdown)
        if isinstance(expires, numbers.Real):
            self._verify_seconds(expires, 'expires')
            now = now or self.app.now()
            expires = now + timedelta(seconds=expires)
        eta = eta and eta.isoformat()
        expires = expires and expires.isoformat()

        if JSON_NEEDS_UNICODE_KEYS:  # pragma: no cover
            if callbacks:
                callbacks = [utf8dict(callback) for callback in callbacks]
            if errbacks:
                errbacks = [utf8dict(errback) for errback in errbacks]
            if chord:
                chord = utf8dict(chord)

        return task_message(
            headers={},
            properties={
                'correlation_id': task_id,
                'reply_to': reply_to or '',
            },
            body={
                'task': name,
                'id': task_id,
                'args': args,
                'kwargs': kwargs,
                'group': group_id,
                'group_index': group_index,
                'retries': retries,
                'eta': eta,
                'expires': expires,
                'utc': utc,
                'callbacks': callbacks,
                'errbacks': errbacks,
                'timelimit': (time_limit, soft_time_limit),
                'taskset': group_id,
                'chord': chord,
            },
            sent_event={
                'uuid': task_id,
                'name': name,
                'args': saferepr(args),
                'kwargs': saferepr(kwargs),
                'retries': retries,
                'eta': eta,
                'expires': expires,
            } if create_sent_event else None,
        )

    def _verify_seconds(self, s, what):
        if s < INT_MIN:
            raise ValueError('%s is out of range: %r' % (what, s))
        return s

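    # Build the send_task_message closure.  Configuration defaults, the
    # queue map and the signal senders are looked up once here and bound
    # into the closure so that each publish avoids repeated config and
    # attribute lookups.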
    def _create_task_sender(self):
        default_retry = self.app.conf.task_publish_retry
        default_policy = self.app.conf.task_publish_retry_policy
        default_delivery_mode = self.app.conf.task_default_delivery_mode
        default_queue = self.default_queue
        queues = self.queues
        send_before_publish = signals.before_task_publish.send
        before_receivers = signals.before_task_publish.receivers
        send_after_publish = signals.after_task_publish.send
        after_receivers = signals.after_task_publish.receivers

        send_task_sent = signals.task_sent.send   # XXX compat
        sent_receivers = signals.task_sent.receivers

        default_evd = self._event_dispatcher
        default_exchange = self.default_exchange

        default_rkey = self.app.conf.task_default_routing_key
        default_serializer = self.app.conf.task_serializer
        default_compressor = self.app.conf.task_compression

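        # The actual sender: resolve the destination (queue, exchange,
        # routing key and declarations), fire the before/after publish
        # signals, publish the message, and optionally emit a task-sent
        # monitoring event.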
        def send_task_message(producer, name, message,
                              exchange=None, routing_key=None, queue=None,
                              event_dispatcher=None,
                              retry=None, retry_policy=None,
                              serializer=None, delivery_mode=None,
                              compression=None, declare=None,
                              headers=None, exchange_type=None, **kwargs):
            retry = default_retry if retry is None else retry
            headers2, properties, body, sent_event = message
            if headers:
                headers2.update(headers)
            if kwargs:
                properties.update(kwargs)

            qname = queue
            if queue is None and exchange is None:
                queue = default_queue
            if queue is not None:
                if isinstance(queue, string_t):
                    qname, queue = queue, queues[queue]
                else:
                    qname = queue.name

            if delivery_mode is None:
                try:
                    delivery_mode = queue.exchange.delivery_mode
                except AttributeError:
                    pass
                delivery_mode = delivery_mode or default_delivery_mode

            if exchange_type is None:
                try:
                    exchange_type = queue.exchange.type
                except AttributeError:
                    exchange_type = 'direct'

            # Use the anonymous exchange (routing directly to the queue name)
            # when the exchange or routing key isn't set and the exchange
            # type is direct.
            if (not exchange or not routing_key) and exchange_type == 'direct':
                exchange, routing_key = '', qname
            elif exchange is None:
                # Exchange not given and not a direct exchange: fall back to
                # the queue's exchange and routing key (or the configured
                # defaults).
                exchange = queue.exchange.name or default_exchange
                routing_key = routing_key or queue.routing_key or default_rkey
            if declare is None and queue and not isinstance(queue, Broadcast):
                declare = [queue]

            # merge default and custom policy
            retry = default_retry if retry is None else retry
            _rp = (dict(default_policy, **retry_policy) if retry_policy
                   else default_policy)

            if before_receivers:
                send_before_publish(
                    sender=name, body=body,
                    exchange=exchange, routing_key=routing_key,
                    declare=declare, headers=headers2,
                    properties=properties, retry_policy=retry_policy,
                )
            ret = producer.publish(
                body,
                exchange=exchange,
                routing_key=routing_key,
                serializer=serializer or default_serializer,
                compression=compression or default_compressor,
                retry=retry, retry_policy=_rp,
                delivery_mode=delivery_mode, declare=declare,
                headers=headers2,
                **properties
            )
            if after_receivers:
                send_after_publish(sender=name, body=body, headers=headers2,
                                   exchange=exchange, routing_key=routing_key)
            if sent_receivers:  # XXX deprecated
                if isinstance(body, tuple):  # protocol version 2
                    send_task_sent(
                        sender=name, task_id=headers2['id'], task=name,
                        args=body[0], kwargs=body[1],
                        eta=headers2['eta'], taskset=headers2['group'],
                    )
                else:  # protocol version 1
                    send_task_sent(
                        sender=name, task_id=body['id'], task=name,
                        args=body['args'], kwargs=body['kwargs'],
                        eta=body['eta'], taskset=body['taskset'],
                    )
            if sent_event:
                evd = event_dispatcher or default_evd
                exname = exchange
                if isinstance(exname, Exchange):
                    exname = exname.name
                sent_event.update({
                    'queue': qname,
                    'exchange': exname,
                    'routing_key': routing_key,
                })
                evd.publish('task-sent', sent_event,
                            producer, retry=retry, retry_policy=retry_policy)
            return ret
        return send_task_message

    @cached_property
    def default_queue(self):
        return self.queues[self.app.conf.task_default_queue]

    @cached_property
    def queues(self):
        """Queue name ⇒ declaration mapping."""
        return self.Queues(self.app.conf.task_queues)

    @queues.setter  # noqa
    def queues(self, queues):
        return self.Queues(queues)

    @property
    def routes(self):
        if self._rtable is None:
            self.flush_routes()
        return self._rtable

    @cached_property
    def router(self):
        return self.Router()

    @router.setter
    def router(self, value):
        return value

    @property
    def producer_pool(self):
        if self._producer_pool is None:
            self._producer_pool = pools.producers[
                self.app.connection_for_write()]
            self._producer_pool.limit = self.app.pool.limit
        return self._producer_pool
    publisher_pool = producer_pool  # compat alias

    @cached_property
    def default_exchange(self):
        return Exchange(self.app.conf.task_default_exchange,
                        self.app.conf.task_default_exchange_type)

    @cached_property
    def utc(self):
        return self.app.conf.enable_utc

    @cached_property
    def _event_dispatcher(self):
        # We call Dispatcher.publish with a custom producer,
        # so we don't need the dispatcher to be enabled.
        return self.app.events.Dispatcher(enabled=False)

    def _handle_conf_update(self, *args, **kwargs):
        if ('task_routes' in kwargs or 'task_routes' in args):
            self.flush_routes()
            self.router = self.Router()
        return