1# -*- coding: utf-8 -*-
2"""Message migration tools (Broker <-> Broker)."""
3from __future__ import absolute_import, print_function, unicode_literals
4
5import socket
6from functools import partial
7from itertools import cycle, islice
8
9from kombu import Queue, eventloop
10from kombu.common import maybe_declare
11from kombu.utils.encoding import ensure_bytes
12
13from celery.app import app_or_default
14from celery.five import python_2_unicode_compatible, string, string_t
15from celery.utils.nodenames import worker_direct
16from celery.utils.text import str_to_list
17
# Public API of this module.  Keep the ``move_direct*`` entries in sync with
# the module-level ``partial`` bindings of the same names.
__all__ = (
    'StopFiltering', 'State', 'republish', 'migrate_task',
    'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
    'start_filter', 'move_task_by_id', 'move_by_idmap',
    'move_by_taskmap', 'move_direct', 'move_direct_by_id',
    'move_direct_by_idmap', 'move_direct_by_taskmap',
)
24
# Progress line template used by ``filter_status``; formatted with
# ``state`` (a State) and ``body`` (the decoded task message body).
MOVING_PROGRESS_FMT = (
    'Moving task {state.filtered}/{state.strtotal}: '
    '{body[task]}[{body[id]}]'
)
29
30
class StopFiltering(Exception):
    """Raised from a filter callback to end consumption early.

    The filtering event loop catches it and returns the progress state,
    so it acts as a signal rather than an error.
    """
33
34
@python_2_unicode_compatible
class State(object):
    """Mutable progress counters for a filtering or migration run.

    The counters start as class-level defaults; incrementing them on an
    instance (``state.count += 1``) creates per-instance values.
    """

    #: Messages seen by the filterer.
    count = 0
    #: Messages the filter acted on (e.g. moved).
    filtered = 0
    #: Approximate number of messages found in the declared queues.
    total_apx = 0

    @property
    def strtotal(self):
        """Approximate total as text, or ``'?'`` when it is zero/unknown."""
        return string(self.total_apx) if self.total_apx else '?'

    def __repr__(self):
        if self.filtered:
            template = '^{0.filtered}'
        else:
            template = '{0.count}/{0.strtotal}'
        return template.format(self)
53
54
def republish(producer, message, exchange=None, routing_key=None,
              remove_props=None):
    """Republish message.

    Re-sends the raw (still encoded) body of ``message`` through
    ``producer``, keeping its headers, content type/encoding and
    remaining properties.

    Arguments:
        producer (kombu.Producer): Producer used to publish the copy.
        message (kombu.Message): The received message to copy.
        exchange (str): Destination exchange.  Defaults to the exchange
            the message was delivered from.
        routing_key (str): Destination routing key.  Defaults to the
            message's original routing key.
        remove_props (Sequence[str]): Property names that must not be
            forwarded as extra publish keyword arguments (they are passed
            explicitly instead).  Falls back to a built-in list when empty.
    """
    if not remove_props:
        remove_props = ['application_headers', 'content_type',
                        'content_encoding', 'headers']
    body = ensure_bytes(message.body)  # use raw message body.
    info = message.delivery_info
    # Work on copies so republishing does not mutate the received message.
    headers = dict(message.headers or {})
    props = dict(message.properties or {})
    exchange = info['exchange'] if exchange is None else exchange
    routing_key = info['routing_key'] if routing_key is None else routing_key
    ctype, enc = message.content_type, message.content_encoding
    # remove compression header, as this will be inserted again
    # when the message is recompressed.
    compression = headers.pop('compression', None)

    # The AMQP ``expiration`` property is a string of milliseconds, while
    # ``Producer.publish(expiration=...)`` takes seconds and converts them
    # back to milliseconds itself.  Forwarding the raw property through
    # ``**props`` would multiply a string (or crash), so convert it here.
    expiration = props.pop('expiration', None)
    if expiration is not None:
        expiration = float(expiration) / 1000.0

    for key in remove_props:
        props.pop(key, None)

    producer.publish(body, exchange=exchange,
                     routing_key=routing_key, compression=compression,
                     headers=headers, content_type=ctype,
                     content_encoding=enc, expiration=expiration, **props)
78
79
def migrate_task(producer, body_, message, queues=None):
    """Migrate single task message.

    The message is republished to the exchange/routing key that
    ``queues`` maps its original exchange/routing key to.  Unmapped values
    become ``None``, which makes ``republish`` keep the original ones.
    ``body_`` is unused; it only exists to match the consumer callback
    signature.
    """
    mapping = queues if queues is not None else {}
    delivery = message.delivery_info
    target_exchange = mapping.get(delivery['exchange'])
    target_routing_key = mapping.get(delivery['routing_key'])
    republish(
        producer, message,
        exchange=target_exchange,
        routing_key=target_routing_key,
    )
87
88
def filter_callback(callback, tasks):
    """Wrap a ``(body, message)`` callback to only fire for ``tasks``.

    When ``tasks`` is empty/falsy every message passes through.
    Otherwise the wrapped callback runs only if ``body['task']`` is in
    ``tasks``, and the wrapper returns ``None`` for all other messages.
    """
    def filtered(body, message):
        if not tasks or body['task'] in tasks:
            return callback(body, message)
        return None
    return filtered
97
98
def migrate_tasks(source, dest, migrate=migrate_task, app=None,
                  queues=None, **kwargs):
    """Migrate tasks from one broker to another.

    Arguments:
        source (kombu.Connection): Connection messages are consumed from.
        dest (kombu.Connection): Connection messages are republished to.
        migrate (Callable): Per-message handler, called as
            ``migrate(producer, body, message, queues=mapping)``.
        app (celery.Celery): App to use (defaults to the current app).
        queues: Queue rename spec, normalized by :func:`prepare_queues`.
        **kwargs: Forwarded to :func:`start_filter`.

    Returns:
        State: progress state returned by :func:`start_filter`.
    """
    app = app_or_default(app)
    queue_map = prepare_queues(queues)
    producer = app.amqp.Producer(dest, auto_declare=False)
    handler = partial(migrate, producer, queues=queue_map)

    def on_declare_queue(queue):
        # Declare a copy of the source queue on the destination broker.
        # The name is renamed via ``queue_map``, and so are the routing key
        # and exchange name when they equal the original queue name.
        renamed = queue_map.get(queue.name, queue.name)
        target = queue(producer.channel)
        target.name = renamed
        if target.routing_key == queue.name:
            target.routing_key = queue_map.get(queue.name,
                                               target.routing_key)
        if target.exchange.name == queue.name:
            target.exchange.name = renamed
        target.declare()

    return start_filter(app, source, handler, queues=queue_map,
                        on_declare_queue=on_declare_queue, **kwargs)
119
120
def _maybe_queue(app, q):
    """Resolve ``q`` to a queue object.

    Strings are looked up by name in ``app.amqp.queues``; anything else
    is returned unchanged.
    """
    if not isinstance(q, string_t):
        return q
    return app.amqp.queues[q]
125
126
def move(predicate, connection=None, exchange=None, routing_key=None,
         source=None, app=None, callback=None, limit=None, transform=None,
         **kwargs):
    """Find tasks by filtering them and move the tasks to a new queue.

    Arguments:
        predicate (Callable): Filter function used to decide the messages
            to move.  Must accept the standard signature of ``(body, message)``
            used by Kombu consumer callbacks.  If the predicate wants the
            message to be moved it must return either:

                1) a tuple of ``(exchange, routing_key)``, or

                2) a :class:`~kombu.entity.Queue` instance, or

                3) any other true value means the specified
                    ``exchange`` and ``routing_key`` arguments will be used.
        connection (kombu.Connection): Custom connection to use.
        source (List[Union[str, kombu.Queue]]): Optional list of source
            queues to use instead of the default (queues
            in :setting:`task_queues`).  This list can also contain
            :class:`~kombu.entity.Queue` instances.
        exchange (str, kombu.Exchange): Default destination exchange.
        routing_key (str): Default destination routing key.
        limit (int): Limit number of messages to filter.
        callback (Callable): Callback called after message moved,
            with signature ``(state, body, message)``.
        transform (Callable): Optional function to transform the return
            value (destination) of the filter function.

    Also supports the same keyword arguments as :func:`start_filter`.
    If ``state`` is given it is used both for the callback and for the
    underlying filter, otherwise a fresh :class:`State` is created.

    Returns:
        State: progress state, including ``filtered`` (messages moved),
        ``count`` (messages seen) and ``total_apx``.

    To demonstrate, the :func:`move_task_by_id` operation can be implemented
    like this:

    .. code-block:: python

        def is_wanted_task(body, message):
            if body['id'] == wanted_id:
                return Queue('foo', exchange=Exchange('foo'),
                             routing_key='foo')

        move(is_wanted_task)

    or with a transform:

    .. code-block:: python

        def transform(value):
            if isinstance(value, string_t):
                return Queue(value, Exchange(value), value)
            return value

        move(is_wanted_task, transform=transform)

    Note:
        The predicate may also return a tuple of ``(exchange, routing_key)``
        to specify the destination to where the task should be moved,
        or a :class:`~kombu.entity.Queue` instance.
        Any other true value means that the task will be moved to the
        default exchange/routing_key.
    """
    app = app_or_default(app)
    queues = [_maybe_queue(app, queue) for queue in source or []] or None
    # Share a single State with the underlying Filterer.  Otherwise the
    # state handed to ``callback`` never sees ``count``/``total_apx``
    # (``strtotal`` is always '?'), and the returned state has filtered=0.
    state = kwargs.pop('state', None)
    if state is None:
        state = State()
    with app.connection_or_acquire(connection, pool=False) as conn:
        producer = app.amqp.Producer(conn)

        def on_task(body, message):
            ret = predicate(body, message)
            if not ret:
                return
            if transform:
                ret = transform(ret)
            if isinstance(ret, Queue):
                # Make sure the destination exists before publishing to it.
                maybe_declare(ret, conn.default_channel)
                ex, rk = ret.exchange.name, ret.routing_key
            else:
                ex, rk = expand_dest(ret, exchange, routing_key)
            republish(producer, message,
                      exchange=ex, routing_key=rk)
            message.ack()

            state.filtered += 1
            if callback:
                callback(state, body, message)
            if limit and state.filtered >= limit:
                raise StopFiltering()

        return start_filter(app, conn, on_task, consume_from=queues,
                            state=state, **kwargs)
216
217
def expand_dest(ret, exchange, routing_key):
    """Resolve a predicate result into an ``(exchange, routing_key)`` pair.

    Arguments:
        ret (Any): Value returned by a :func:`move` predicate.
        exchange: Default exchange used when ``ret`` is not a destination.
        routing_key: Default routing key used when ``ret`` is not a
            destination.

    Returns:
        Tuple: ``ret`` unpacked when it is a 2-item tuple/list, otherwise
        ``(exchange, routing_key)``.
    """
    # Only explicit 2-item sequences name a destination.  Plain unpacking
    # would also "succeed" for a 2-character string or a 2-key dict, which
    # per the move() contract must select the default destination.
    if isinstance(ret, (tuple, list)) and len(ret) == 2:
        ex, rk = ret
        return ex, rk
    return exchange, routing_key
224
225
def task_id_eq(task_id, body, message):
    """Return true if the task id in ``body`` equals ``task_id``."""
    return task_id == body['id']
229
230
def task_id_in(ids, body, message):
    """Return true if the task id in ``body`` is a member of ``ids``."""
    current_id = body['id']
    return current_id in ids
234
235
236def prepare_queues(queues):
237    if isinstance(queues, string_t):
238        queues = queues.split(',')
239    if isinstance(queues, list):
240        queues = dict(tuple(islice(cycle(q.split(':')), None, 2))
241                      for q in queues)
242    if queues is None:
243        queues = {}
244    return queues
245
246
class Filterer(object):
    """Consume task messages from ``conn`` and run them through ``filter``.

    Consumer callbacks are registered in this order:

    1. the ``filter`` itself,
    2. a state updater that counts messages and enforces ``limit``,
    3. an optional acknowledger (``ack_messages``),
    4. an optional user ``callback(state, body, message)``.

    When ``tasks`` is given, every one of these only fires for messages
    whose ``body['task']`` is in that set.

    Arguments:
        app (celery.Celery): App whose AMQP helpers build the consumer.
        conn (kombu.Connection): Connection to consume from.
        filter (Callable): ``(body, message)`` callback run per message.
        limit (int): Stop after this many (matching) messages.
        timeout (float): Timeout passed to the kombu event loop.
        ack_messages (bool): Acknowledge each (matching) message.
        tasks: Task names to restrict to, converted via ``str_to_list``
            (presumably a list or comma-separated string).
        queues: Spec accepted by :func:`prepare_queues`.  When non-empty,
            only queues named in it are passed to ``declare_queues``.
        callback (Callable): Called as ``callback(state, body, message)``.
        forever (bool): Keep consuming through event-loop timeouts.
        on_declare_queue (Callable): Called with each queue before it is
            passively declared.
        consume_from (Sequence): Queues (names or Queue objects) to consume
            from.  Defaults to the keys of ``queues``.
        state (State): Existing progress state to update.
        accept (Sequence[str]): Content types the consumer accepts.
    """

    def __init__(self, app, conn, filter,
                 limit=None, timeout=1.0,
                 ack_messages=False, tasks=None, queues=None,
                 callback=None, forever=False, on_declare_queue=None,
                 consume_from=None, state=None, accept=None, **kwargs):
        self.app = app
        self.conn = conn
        self.filter = filter
        self.limit = limit
        self.timeout = timeout
        self.ack_messages = ack_messages
        self.tasks = set(str_to_list(tasks) or [])
        self.queues = prepare_queues(queues)
        self.callback = callback
        self.forever = forever
        self.on_declare_queue = on_declare_queue
        self.consume_from = [
            _maybe_queue(self.app, q)
            for q in consume_from or list(self.queues)
        ]
        self.state = state or State()
        self.accept = accept

    def start(self):
        """Consume until stopped and return the progress :class:`State`."""
        with self.prepare_consumer(self.create_consumer()):
            try:
                for _ in eventloop(self.conn,  # pragma: no cover
                                   timeout=self.timeout,
                                   ignore_timeouts=self.forever):
                    pass
            except (socket.timeout, StopFiltering):
                # Normal termination: either no message arrived within
                # ``timeout`` (timeouts are only ignored when ``forever``),
                # or a callback/limit requested a stop.
                pass
        return self.state

    def update_state(self, body, message):
        """Count a message and raise StopFiltering once ``limit`` is hit."""
        self.state.count += 1
        if self.limit and self.state.count >= self.limit:
            raise StopFiltering()

    def ack_message(self, body, message):
        """Acknowledge ``message``."""
        message.ack()

    def create_consumer(self):
        """Build the task consumer for ``consume_from`` on ``conn``."""
        return self.app.amqp.TaskConsumer(
            self.conn,
            queues=self.consume_from,
            accept=self.accept,
        )

    def prepare_consumer(self, consumer):
        """Register callbacks on ``consumer``, declare queues, return it."""
        filter = self.filter
        update_state = self.update_state
        ack_message = self.ack_message
        if self.tasks:
            filter = filter_callback(filter, self.tasks)
            update_state = filter_callback(update_state, self.tasks)
            ack_message = filter_callback(ack_message, self.tasks)
        consumer.register_callback(filter)
        consumer.register_callback(update_state)
        if self.ack_messages:
            # Register the (possibly task-filtered) acknowledger; using
            # ``self.ack_message`` here would ack non-matching messages too.
            consumer.register_callback(ack_message)
        if self.callback is not None:
            callback = partial(self.callback, self.state)
            if self.tasks:
                callback = filter_callback(callback, self.tasks)
            consumer.register_callback(callback)
        self.declare_queues(consumer)
        return consumer

    def declare_queues(self, consumer):
        """Run ``on_declare_queue`` for the consumer's queues, and tally them.

        Each queue (restricted to ``self.queues`` when that is non-empty)
        is passed to ``on_declare_queue``, then passively declared on the
        consumer's channel.  Its message count is added to
        ``state.total_apx``.  Channel errors are ignored, so a missing
        queue only leaves the total incomplete.
        """
        for queue in consumer.queues:
            if self.queues and queue.name not in self.queues:
                continue
            if self.on_declare_queue is not None:
                self.on_declare_queue(queue)
            try:
                _, mcount, _ = queue(
                    consumer.channel).queue_declare(passive=True)
                if mcount:
                    self.state.total_apx += mcount
            except self.conn.channel_errors:
                pass
335
336
def start_filter(app, conn, filter, limit=None, timeout=1.0,
                 ack_messages=False, tasks=None, queues=None,
                 callback=None, forever=False, on_declare_queue=None,
                 consume_from=None, state=None, accept=None, **kwargs):
    """Filter tasks.

    Convenience wrapper that builds a :class:`Filterer` from the given
    options and runs it.  See :class:`Filterer` for the meaning of each
    argument.

    Returns:
        State: the progress state once consumption stops.
    """
    filterer = Filterer(
        app, conn, filter,
        accept=accept,
        ack_messages=ack_messages,
        callback=callback,
        consume_from=consume_from,
        forever=forever,
        limit=limit,
        on_declare_queue=on_declare_queue,
        queues=queues,
        state=state,
        tasks=tasks,
        timeout=timeout,
        **kwargs
    )
    return filterer.start()
356
357
def move_task_by_id(task_id, dest, **kwargs):
    """Find a task by id and move it to another queue.

    Arguments:
        task_id (str): Id of task to find and move.
        dest (str, kombu.Queue): Destination queue.
        transform (Callable): Optional function to transform the return
            value (destination) of the filter function.
        **kwargs (Any): Also supports the same keyword
            arguments as :func:`move`.
    """
    single_entry_map = {task_id: dest}
    return move_by_idmap(single_entry_map, **kwargs)
370
371
def move_by_idmap(map, **kwargs):
    """Move tasks by matching from a ``task_id: queue`` mapping.

    Where ``queue`` is a queue to move the task to.  The task id of each
    message is read from its ``correlation_id`` property.  Messages without
    one are left alone.

    Consumption stops once ``len(map)`` messages have been moved, unless
    the caller passes an explicit ``limit``.  Other keyword arguments are
    forwarded to :func:`move`.

    Example:
        >>> move_by_idmap({
        ...     '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'),
        ...     'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'),
        ...     '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')},
        ...   queues=['hipri'])
    """
    def task_id_in_map(body, message):
        # A message lacking correlation_id cannot match any task id; skip
        # it instead of aborting the whole move with a KeyError.
        task_id = message.properties.get('correlation_id')
        if task_id is None:
            return None
        return map.get(task_id)

    # adding the limit means that we don't have to consume any more
    # when we've found everything (a caller-supplied limit wins).
    kwargs.setdefault('limit', len(map))
    return move(task_id_in_map, **kwargs)
390
391
def move_by_taskmap(map, **kwargs):
    """Move tasks by matching from a ``task_name: queue`` mapping.

    ``queue`` is the queue to move the task to.  The task name is taken
    from ``body['task']``.  Keyword arguments are forwarded to
    :func:`move`.

    Example:
        >>> move_by_taskmap({
        ...     'tasks.add': Queue('name'),
        ...     'tasks.mul': Queue('name'),
        ... })
    """
    lookup = map.get

    def task_name_in_map(body, message):
        task_name = body['task']
        return lookup(task_name)

    return move(task_name_in_map, **kwargs)
407
408
def filter_status(state, body, message, **kwargs):
    """Print a one-line progress report for a moved task message.

    Matches the ``callback(state, body, message)`` signature used by
    :func:`move`.  Extra keyword arguments are passed to the format string.
    """
    line = MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs)
    print(line)
411
412
# Variants of the move helpers with ``transform=worker_direct`` pre-bound.
# The destination value produced by the predicate/map is first passed
# through ``worker_direct`` before being used.  Presumably this turns a
# worker node name into that worker's direct queue; see
# celery.utils.nodenames to confirm.
move_direct = partial(move, transform=worker_direct)
move_direct_by_id = partial(move_task_by_id, transform=worker_direct)
move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct)
move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct)
417