1# -*- coding: utf-8 -*- 2"""Message migration tools (Broker <-> Broker).""" 3from __future__ import absolute_import, print_function, unicode_literals 4 5import socket 6from functools import partial 7from itertools import cycle, islice 8 9from kombu import Queue, eventloop 10from kombu.common import maybe_declare 11from kombu.utils.encoding import ensure_bytes 12 13from celery.app import app_or_default 14from celery.five import python_2_unicode_compatible, string, string_t 15from celery.utils.nodenames import worker_direct 16from celery.utils.text import str_to_list 17 18__all__ = ( 19 'StopFiltering', 'State', 'republish', 'migrate_task', 20 'migrate_tasks', 'move', 'task_id_eq', 'task_id_in', 21 'start_filter', 'move_task_by_id', 'move_by_idmap', 22 'move_by_taskmap', 'move_direct', 'move_direct_by_id', 23) 24 25MOVING_PROGRESS_FMT = """\ 26Moving task {state.filtered}/{state.strtotal}: \ 27{body[task]}[{body[id]}]\ 28""" 29 30 31class StopFiltering(Exception): 32 """Semi-predicate used to signal filter stop.""" 33 34 35@python_2_unicode_compatible 36class State(object): 37 """Migration progress state.""" 38 39 count = 0 40 filtered = 0 41 total_apx = 0 42 43 @property 44 def strtotal(self): 45 if not self.total_apx: 46 return '?' 47 return string(self.total_apx) 48 49 def __repr__(self): 50 if self.filtered: 51 return '^{0.filtered}'.format(self) 52 return '{0.count}/{0.strtotal}'.format(self) 53 54 55def republish(producer, message, exchange=None, routing_key=None, 56 remove_props=None): 57 """Republish message.""" 58 if not remove_props: 59 remove_props = ['application_headers', 'content_type', 60 'content_encoding', 'headers'] 61 body = ensure_bytes(message.body) # use raw message body. 62 info, headers, props = (message.delivery_info, 63 message.headers, message.properties) 64 exchange = info['exchange'] if exchange is None else exchange 65 routing_key = info['routing_key'] if routing_key is None else routing_key 66 ctype, enc = message.content_type, message.content_encoding 67 # remove compression header, as this will be inserted again 68 # when the message is recompressed. 69 compression = headers.pop('compression', None) 70 71 for key in remove_props: 72 props.pop(key, None) 73 74 producer.publish(ensure_bytes(body), exchange=exchange, 75 routing_key=routing_key, compression=compression, 76 headers=headers, content_type=ctype, 77 content_encoding=enc, **props) 78 79 80def migrate_task(producer, body_, message, queues=None): 81 """Migrate single task message.""" 82 info = message.delivery_info 83 queues = {} if queues is None else queues 84 republish(producer, message, 85 exchange=queues.get(info['exchange']), 86 routing_key=queues.get(info['routing_key'])) 87 88 89def filter_callback(callback, tasks): 90 91 def filtered(body, message): 92 if tasks and body['task'] not in tasks: 93 return 94 95 return callback(body, message) 96 return filtered 97 98 99def migrate_tasks(source, dest, migrate=migrate_task, app=None, 100 queues=None, **kwargs): 101 """Migrate tasks from one broker to another.""" 102 app = app_or_default(app) 103 queues = prepare_queues(queues) 104 producer = app.amqp.Producer(dest, auto_declare=False) 105 migrate = partial(migrate, producer, queues=queues) 106 107 def on_declare_queue(queue): 108 new_queue = queue(producer.channel) 109 new_queue.name = queues.get(queue.name, queue.name) 110 if new_queue.routing_key == queue.name: 111 new_queue.routing_key = queues.get(queue.name, 112 new_queue.routing_key) 113 if new_queue.exchange.name == queue.name: 114 new_queue.exchange.name = queues.get(queue.name, queue.name) 115 new_queue.declare() 116 117 return start_filter(app, source, migrate, queues=queues, 118 on_declare_queue=on_declare_queue, **kwargs) 119 120 121def _maybe_queue(app, q): 122 if isinstance(q, string_t): 123 return app.amqp.queues[q] 124 return q 125 126 127def move(predicate, connection=None, exchange=None, routing_key=None, 128 source=None, app=None, callback=None, limit=None, transform=None, 129 **kwargs): 130 """Find tasks by filtering them and move the tasks to a new queue. 131 132 Arguments: 133 predicate (Callable): Filter function used to decide the messages 134 to move. Must accept the standard signature of ``(body, message)`` 135 used by Kombu consumer callbacks. If the predicate wants the 136 message to be moved it must return either: 137 138 1) a tuple of ``(exchange, routing_key)``, or 139 140 2) a :class:`~kombu.entity.Queue` instance, or 141 142 3) any other true value means the specified 143 ``exchange`` and ``routing_key`` arguments will be used. 144 connection (kombu.Connection): Custom connection to use. 145 source: List[Union[str, kombu.Queue]]: Optional list of source 146 queues to use instead of the default (queues 147 in :setting:`task_queues`). This list can also contain 148 :class:`~kombu.entity.Queue` instances. 149 exchange (str, kombu.Exchange): Default destination exchange. 150 routing_key (str): Default destination routing key. 151 limit (int): Limit number of messages to filter. 152 callback (Callable): Callback called after message moved, 153 with signature ``(state, body, message)``. 154 transform (Callable): Optional function to transform the return 155 value (destination) of the filter function. 156 157 Also supports the same keyword arguments as :func:`start_filter`. 158 159 To demonstrate, the :func:`move_task_by_id` operation can be implemented 160 like this: 161 162 .. code-block:: python 163 164 def is_wanted_task(body, message): 165 if body['id'] == wanted_id: 166 return Queue('foo', exchange=Exchange('foo'), 167 routing_key='foo') 168 169 move(is_wanted_task) 170 171 or with a transform: 172 173 .. code-block:: python 174 175 def transform(value): 176 if isinstance(value, string_t): 177 return Queue(value, Exchange(value), value) 178 return value 179 180 move(is_wanted_task, transform=transform) 181 182 Note: 183 The predicate may also return a tuple of ``(exchange, routing_key)`` 184 to specify the destination to where the task should be moved, 185 or a :class:`~kombu.entity.Queue` instance. 186 Any other true value means that the task will be moved to the 187 default exchange/routing_key. 188 """ 189 app = app_or_default(app) 190 queues = [_maybe_queue(app, queue) for queue in source or []] or None 191 with app.connection_or_acquire(connection, pool=False) as conn: 192 producer = app.amqp.Producer(conn) 193 state = State() 194 195 def on_task(body, message): 196 ret = predicate(body, message) 197 if ret: 198 if transform: 199 ret = transform(ret) 200 if isinstance(ret, Queue): 201 maybe_declare(ret, conn.default_channel) 202 ex, rk = ret.exchange.name, ret.routing_key 203 else: 204 ex, rk = expand_dest(ret, exchange, routing_key) 205 republish(producer, message, 206 exchange=ex, routing_key=rk) 207 message.ack() 208 209 state.filtered += 1 210 if callback: 211 callback(state, body, message) 212 if limit and state.filtered >= limit: 213 raise StopFiltering() 214 215 return start_filter(app, conn, on_task, consume_from=queues, **kwargs) 216 217 218def expand_dest(ret, exchange, routing_key): 219 try: 220 ex, rk = ret 221 except (TypeError, ValueError): 222 ex, rk = exchange, routing_key 223 return ex, rk 224 225 226def task_id_eq(task_id, body, message): 227 """Return true if task id equals task_id'.""" 228 return body['id'] == task_id 229 230 231def task_id_in(ids, body, message): 232 """Return true if task id is member of set ids'.""" 233 return body['id'] in ids 234 235 236def prepare_queues(queues): 237 if isinstance(queues, string_t): 238 queues = queues.split(',') 239 if isinstance(queues, list): 240 queues = dict(tuple(islice(cycle(q.split(':')), None, 2)) 241 for q in queues) 242 if queues is None: 243 queues = {} 244 return queues 245 246 247class Filterer(object): 248 249 def __init__(self, app, conn, filter, 250 limit=None, timeout=1.0, 251 ack_messages=False, tasks=None, queues=None, 252 callback=None, forever=False, on_declare_queue=None, 253 consume_from=None, state=None, accept=None, **kwargs): 254 self.app = app 255 self.conn = conn 256 self.filter = filter 257 self.limit = limit 258 self.timeout = timeout 259 self.ack_messages = ack_messages 260 self.tasks = set(str_to_list(tasks) or []) 261 self.queues = prepare_queues(queues) 262 self.callback = callback 263 self.forever = forever 264 self.on_declare_queue = on_declare_queue 265 self.consume_from = [ 266 _maybe_queue(self.app, q) 267 for q in consume_from or list(self.queues) 268 ] 269 self.state = state or State() 270 self.accept = accept 271 272 def start(self): 273 # start migrating messages. 274 with self.prepare_consumer(self.create_consumer()): 275 try: 276 for _ in eventloop(self.conn, # pragma: no cover 277 timeout=self.timeout, 278 ignore_timeouts=self.forever): 279 pass 280 except socket.timeout: 281 pass 282 except StopFiltering: 283 pass 284 return self.state 285 286 def update_state(self, body, message): 287 self.state.count += 1 288 if self.limit and self.state.count >= self.limit: 289 raise StopFiltering() 290 291 def ack_message(self, body, message): 292 message.ack() 293 294 def create_consumer(self): 295 return self.app.amqp.TaskConsumer( 296 self.conn, 297 queues=self.consume_from, 298 accept=self.accept, 299 ) 300 301 def prepare_consumer(self, consumer): 302 filter = self.filter 303 update_state = self.update_state 304 ack_message = self.ack_message 305 if self.tasks: 306 filter = filter_callback(filter, self.tasks) 307 update_state = filter_callback(update_state, self.tasks) 308 ack_message = filter_callback(ack_message, self.tasks) 309 consumer.register_callback(filter) 310 consumer.register_callback(update_state) 311 if self.ack_messages: 312 consumer.register_callback(self.ack_message) 313 if self.callback is not None: 314 callback = partial(self.callback, self.state) 315 if self.tasks: 316 callback = filter_callback(callback, self.tasks) 317 consumer.register_callback(callback) 318 self.declare_queues(consumer) 319 return consumer 320 321 def declare_queues(self, consumer): 322 # declare all queues on the new broker. 323 for queue in consumer.queues: 324 if self.queues and queue.name not in self.queues: 325 continue 326 if self.on_declare_queue is not None: 327 self.on_declare_queue(queue) 328 try: 329 _, mcount, _ = queue( 330 consumer.channel).queue_declare(passive=True) 331 if mcount: 332 self.state.total_apx += mcount 333 except self.conn.channel_errors: 334 pass 335 336 337def start_filter(app, conn, filter, limit=None, timeout=1.0, 338 ack_messages=False, tasks=None, queues=None, 339 callback=None, forever=False, on_declare_queue=None, 340 consume_from=None, state=None, accept=None, **kwargs): 341 """Filter tasks.""" 342 return Filterer( 343 app, conn, filter, 344 limit=limit, 345 timeout=timeout, 346 ack_messages=ack_messages, 347 tasks=tasks, 348 queues=queues, 349 callback=callback, 350 forever=forever, 351 on_declare_queue=on_declare_queue, 352 consume_from=consume_from, 353 state=state, 354 accept=accept, 355 **kwargs).start() 356 357 358def move_task_by_id(task_id, dest, **kwargs): 359 """Find a task by id and move it to another queue. 360 361 Arguments: 362 task_id (str): Id of task to find and move. 363 dest: (str, kombu.Queue): Destination queue. 364 transform (Callable): Optional function to transform the return 365 value (destination) of the filter function. 366 **kwargs (Any): Also supports the same keyword 367 arguments as :func:`move`. 368 """ 369 return move_by_idmap({task_id: dest}, **kwargs) 370 371 372def move_by_idmap(map, **kwargs): 373 """Move tasks by matching from a ``task_id: queue`` mapping. 374 375 Where ``queue`` is a queue to move the task to. 376 377 Example: 378 >>> move_by_idmap({ 379 ... '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'), 380 ... 'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'), 381 ... '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')}, 382 ... queues=['hipri']) 383 """ 384 def task_id_in_map(body, message): 385 return map.get(message.properties['correlation_id']) 386 387 # adding the limit means that we don't have to consume any more 388 # when we've found everything. 389 return move(task_id_in_map, limit=len(map), **kwargs) 390 391 392def move_by_taskmap(map, **kwargs): 393 """Move tasks by matching from a ``task_name: queue`` mapping. 394 395 ``queue`` is the queue to move the task to. 396 397 Example: 398 >>> move_by_taskmap({ 399 ... 'tasks.add': Queue('name'), 400 ... 'tasks.mul': Queue('name'), 401 ... }) 402 """ 403 def task_name_in_map(body, message): 404 return map.get(body['task']) # <- name of task 405 406 return move(task_name_in_map, **kwargs) 407 408 409def filter_status(state, body, message, **kwargs): 410 print(MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs)) 411 412 413move_direct = partial(move, transform=worker_direct) 414move_direct_by_id = partial(move_task_by_id, transform=worker_direct) 415move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct) 416move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct) 417