1# -*- coding: utf-8 -*- 2"""In-memory representation of cluster state. 3 4This module implements a data-structure used to keep 5track of the state of a cluster of workers and the tasks 6it is working on (by consuming events). 7 8For every event consumed the state is updated, 9so the state represents the state of the cluster 10at the time of the last event. 11 12Snapshots (:mod:`celery.events.snapshot`) can be used to 13take "pictures" of this state at regular intervals 14to for example, store that in a database. 15""" 16from __future__ import absolute_import, unicode_literals 17 18import bisect 19import sys 20import threading 21from collections import defaultdict 22from datetime import datetime 23from decimal import Decimal 24from itertools import islice 25from operator import itemgetter 26from time import time 27from weakref import WeakSet, ref 28 29from kombu.clocks import timetuple 30from kombu.utils.objects import cached_property 31 32from celery import states 33from celery.five import items, python_2_unicode_compatible, values 34from celery.utils.functional import LRUCache, memoize, pass1 35from celery.utils.log import get_logger 36 37try: 38 from collections.abc import Callable 39except ImportError: 40 # TODO: Remove this when we drop Python 2.7 support 41 from collections import Callable 42 43 44__all__ = ('Worker', 'Task', 'State', 'heartbeat_expires') 45 46# pylint: disable=redefined-outer-name 47# We cache globals and attribute lookups, so disable this warning. 48# pylint: disable=too-many-function-args 49# For some reason pylint thinks ._event is a method, when it's a property. 50 51#: Set if running PyPy 52PYPY = hasattr(sys, 'pypy_version_info') 53 54#: The window (in percentage) is added to the workers heartbeat 55#: frequency. If the time between updates exceeds this window, 56#: then the worker is considered to be offline. 
HEARTBEAT_EXPIRE_WINDOW = 200

#: Max drift between event timestamp and time of event received
#: before we alert that clocks may be unsynchronized.
HEARTBEAT_DRIFT_MAX = 16

DRIFT_WARNING = """\
Substantial drift from %s may mean clocks are out of sync. Current drift is
%s seconds. [orig: %s recv: %s]
"""

logger = get_logger(__name__)
warn = logger.warning

R_STATE = '<State: events={0.event_count} tasks={0.task_count}>'
R_WORKER = '<Worker: {0.hostname} ({0.status_string} clock:{0.clock})'
R_TASK = '<Task: {0.name}({0.uuid}) {0.state} clock:{0.clock}>'

#: Maps a task event suffix (e.g. ``'received'``) to the task state
#: that event implies.
TASK_EVENT_TO_STATE = {
    'sent': states.PENDING,
    'received': states.RECEIVED,
    'started': states.STARTED,
    'failed': states.FAILURE,
    'retried': states.RETRY,
    'succeeded': states.SUCCESS,
    'revoked': states.REVOKED,
    'rejected': states.REJECTED,
}


class CallableDefaultdict(defaultdict):
    """A :class:`~collections.defaultdict` that is also callable.

    Calling an instance delegates to the function passed as the first
    constructor argument.  This preserves backwards compatibility for
    ``State.tasks_by_type`` and ``State.tasks_by_worker``, which used
    to be methods but are now mapping indexes, so both forms work:

        >>> add_tasks = state.tasks_by_type['proj.tasks.add']

    and the legacy method-call form:

        >>> add_tasks = list(state.tasks_by_type(
        ...     'proj.tasks.add', reverse=True))
    """

    def __init__(self, fun, *args, **kwargs):
        super(CallableDefaultdict, self).__init__(*args, **kwargs)
        # ``fun`` implements the legacy method-call behavior.
        self.fun = fun

    def __call__(self, *args, **kwargs):
        return self.fun(*args, **kwargs)
Callable.register(CallableDefaultdict)  # noqa: E305


@memoize(maxsize=1000, keyfun=lambda a, _: a[0])
def _warn_drift(hostname, drift, local_received, timestamp):
    # we use memoize here so the warning is only logged once per hostname
    warn(DRIFT_WARNING, hostname, drift,
         datetime.fromtimestamp(local_received),
         datetime.fromtimestamp(timestamp))


def heartbeat_expires(timestamp, freq=60,
                      expire_window=HEARTBEAT_EXPIRE_WINDOW,
                      Decimal=Decimal, float=float, isinstance=isinstance):
    """Return time when heartbeat expires."""
    # some json implementations returns decimal.Decimal objects,
    # which aren't compatible with float.
    freq = float(freq) if isinstance(freq, Decimal) else freq
    if isinstance(timestamp, Decimal):
        timestamp = float(timestamp)
    # expire_window is a percentage, so the default of 200 means
    # "expires after 2 x freq seconds".
    return timestamp + (freq * (expire_window / 1e2))


def _depickle_task(cls, fields):
    # Used by Task.__reduce__: rebuild a task from its as_dict() form.
    return cls(**fields)


def with_unique_field(attr):
    # Class decorator defining __eq__/__ne__/__hash__ in terms of a
    # single uniquely-identifying attribute
    # (e.g. Worker.hostname, Task.uuid).

    def _decorate_cls(cls):

        def __eq__(this, other):
            if isinstance(other, this.__class__):
                return getattr(this, attr) == getattr(other, attr)
            return NotImplemented
        cls.__eq__ = __eq__

        def __ne__(this, other):
            # Derived from __eq__ for Python 2 compatibility.
            res = this.__eq__(other)
            return True if res is NotImplemented else not res
        cls.__ne__ = __ne__

        def __hash__(this):
            return hash(getattr(this, attr))
        cls.__hash__ = __hash__

        return cls
    return _decorate_cls


@with_unique_field('hostname')
@python_2_unicode_compatible
class Worker(object):
    """Worker State."""

    # Max number of timestamps kept in the ``heartbeats`` list.
    heartbeat_max = 4
    # Percentage window used by :func:`heartbeat_expires`.
    expire_window = HEARTBEAT_EXPIRE_WINDOW

    _fields = ('hostname', 'pid', 'freq', 'heartbeats', 'clock',
               'active', 'processed', 'loadavg', 'sw_ident',
               'sw_ver', 'sw_sys')
    if not PYPY:  # pragma: no cover
        __slots__ = _fields + ('event', '__dict__', '__weakref__')

    def __init__(self, hostname=None, pid=None, freq=60,
                 heartbeats=None, clock=0, active=None, processed=None,
                 loadavg=None, sw_ident=None, sw_ver=None, sw_sys=None):
        self.hostname = hostname
        self.pid = pid
        self.freq = freq
        self.heartbeats = [] if heartbeats is None else heartbeats
        self.clock = clock or 0
        self.active = active
        self.processed = processed
        self.loadavg = loadavg
        self.sw_ident = sw_ident
        self.sw_ver = sw_ver
        self.sw_sys = sw_sys
        # per-instance event handler closure (see _create_event_handler).
        self.event = self._create_event_handler()

    def __reduce__(self):
        return self.__class__, (self.hostname, self.pid, self.freq,
                                self.heartbeats, self.clock, self.active,
                                self.processed, self.loadavg, self.sw_ident,
                                self.sw_ver, self.sw_sys)

    def _create_event_handler(self):
        # Build the closure used as ``self.event``.  Attribute lookups are
        # hoisted into locals here, and builtins are bound as keyword
        # defaults below, purely as a speed optimization.
        _set = object.__setattr__
        hbmax = self.heartbeat_max
        heartbeats = self.heartbeats
        hb_pop = self.heartbeats.pop
        hb_append = self.heartbeats.append

        def event(type_, timestamp=None,
                  local_received=None, fields=None,
                  max_drift=HEARTBEAT_DRIFT_MAX, items=items, abs=abs, int=int,
                  insort=bisect.insort, len=len):
            # Apply extra event fields directly as worker attributes.
            fields = fields or {}
            for k, v in items(fields):
                _set(self, k, v)
            if type_ == 'offline':
                # worker went away: forget all recorded heartbeats.
                heartbeats[:] = []
            else:
                if not local_received or not timestamp:
                    return
                drift = abs(int(local_received) - int(timestamp))
                if drift > max_drift:
                    # warn (once per hostname, via memoize) that clocks
                    # may be out of sync.
                    _warn_drift(self.hostname, drift,
                                local_received, timestamp)
                if local_received:  # pragma: no cover
                    hearts = len(heartbeats)
                    # cap the list at heartbeat_max, dropping the oldest.
                    if hearts > hbmax - 1:
                        hb_pop(0)
                    # common case: newest heartbeat appends at the end;
                    # otherwise insert keeping the list sorted.
                    if hearts and local_received > heartbeats[-1]:
                        hb_append(local_received)
                    else:
                        insort(heartbeats, local_received)
        return event

    def update(self, f, **kw):
        # Set attributes from mapping *f*, optionally overridden by **kw.
        for k, v in items(dict(f, **kw) if kw else f):
            setattr(self, k, v)

    def __repr__(self):
        return R_WORKER.format(self)

    @property
    def status_string(self):
        return 'ONLINE' if self.alive else 'OFFLINE'

    @property
    def heartbeat_expires(self):
        # Expiry deadline derived from the most recent heartbeat.
        return heartbeat_expires(self.heartbeats[-1],
                                 self.freq, self.expire_window)

    @property
    def alive(self, nowfun=time):
        return bool(self.heartbeats and nowfun() < self.heartbeat_expires)

    @property
    def id(self):
        return '{0.hostname}.{0.pid}'.format(self)
@with_unique_field('uuid')
@python_2_unicode_compatible
class Task(object):
    """Task State."""

    # All event-derived fields default to None at class level; events
    # fill them in on the instance as they arrive.
    name = received = sent = started = succeeded = failed = retried = \
        revoked = rejected = args = kwargs = eta = expires = retries = \
        worker = result = exception = timestamp = runtime = traceback = \
        exchange = routing_key = root_id = parent_id = client = None
    state = states.PENDING
    clock = 0

    _fields = (
        'uuid', 'name', 'state', 'received', 'sent', 'started', 'rejected',
        'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
        'eta', 'expires', 'retries', 'worker', 'result', 'exception',
        'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
        'clock', 'client', 'root', 'root_id', 'parent', 'parent_id',
        'children',
    )
    if not PYPY:  # pragma: no cover
        __slots__ = ('__dict__', '__weakref__')

    #: How to merge out of order events.
    #: Disorder is detected by logical ordering (e.g., :event:`task-received`
    #: must've happened before a :event:`task-failed` event).
    #:
    #: A merge rule consists of a state and a list of fields to keep from
    #: that state. ``(RECEIVED, ('name', 'args')``, means the name and args
    #: fields are always taken from the RECEIVED state, and any values for
    #: these fields received before or after is simply ignored.
    merge_rules = {
        states.RECEIVED: (
            'name', 'args', 'kwargs', 'parent_id',
            'root_id', 'retries', 'eta', 'expires',
        ),
    }

    #: meth:`info` displays these fields by default.
    _info_fields = (
        'args', 'kwargs', 'retries', 'result', 'eta', 'runtime',
        'expires', 'exception', 'exchange', 'routing_key',
        'root_id', 'parent_id',
    )

    def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs):
        self.uuid = uuid
        self.cluster_state = cluster_state
        if self.cluster_state is not None:
            # Resolve child ids against tasks known to the cluster state;
            # a WeakSet so children vanish when evicted from the LRU cache.
            self.children = WeakSet(
                self.cluster_state.tasks.get(task_id)
                for task_id in children or ()
                if task_id in self.cluster_state.tasks
            )
        else:
            self.children = WeakSet()
        self._serializer_handlers = {
            'children': self._serializable_children,
            'root': self._serializable_root,
            'parent': self._serializable_parent,
        }
        if kwargs:
            self.__dict__.update(kwargs)

    def event(self, type_, timestamp=None, local_received=None, fields=None,
              precedence=states.precedence, items=items,
              setattr=setattr, task_event_to_state=TASK_EVENT_TO_STATE.get,
              RETRY=states.RETRY):
        # Update this task's state from an event of type *type_*
        # (e.g. 'received', 'started', or a custom event suffix).
        fields = fields or {}

        # using .get is faster than catching KeyError in this case.
        state = task_event_to_state(type_)
        if state is not None:
            # sets, for example, self.succeeded to the timestamp.
            setattr(self, type_, timestamp)
        else:
            state = type_.upper()  # custom state

        # note that precedence here is reversed
        # see implementation in celery.states.state.__lt__
        if state != RETRY and self.state != RETRY and \
                precedence(state) > precedence(self.state):
            # this state logically happens-before the current state, so merge.
            keep = self.merge_rules.get(state)
            if keep is not None:
                fields = {
                    k: v for k, v in items(fields) if k in keep
                }
        else:
            fields.update(state=state, timestamp=timestamp)

        # update current state with info from this event.
        self.__dict__.update(fields)

    def info(self, fields=None, extra=None):
        """Information about this task suitable for on-screen display."""
        extra = [] if not extra else extra
        fields = self._info_fields if fields is None else fields

        def _keys():
            # only yield fields that actually have a value.
            for key in list(fields) + list(extra):
                value = getattr(self, key, None)
                if value is not None:
                    yield key, value

        return dict(_keys())

    def __repr__(self):
        return R_TASK.format(self)

    def as_dict(self):
        # Serializable snapshot of all _fields; weak relations
        # (children/root/parent) are flattened to plain ids.
        get = object.__getattribute__
        handler = self._serializer_handlers.get
        return {
            k: handler(k, pass1)(get(self, k)) for k in self._fields
        }

    def _serializable_children(self, value):
        return [task.id for task in self.children]

    def _serializable_root(self, value):
        return self.root_id

    def _serializable_parent(self, value):
        return self.parent_id

    def __reduce__(self):
        return _depickle_task, (self.__class__, self.as_dict())

    @property
    def id(self):
        return self.uuid

    @property
    def origin(self):
        return self.client if self.worker is None else self.worker.id

    @property
    def ready(self):
        return self.state in states.READY_STATES

    @cached_property
    def parent(self):
        # issue github.com/mher/flower/issues/648
        try:
            return self.parent_id and self.cluster_state.tasks.data[self.parent_id]
        except KeyError:
            return None

    @cached_property
    def root(self):
        # issue github.com/mher/flower/issues/648
        try:
            return self.root_id and self.cluster_state.tasks.data[self.root_id]
        except KeyError:
            return None
class State(object):
    """Records clusters state."""

    Worker = Worker
    Task = Task
    event_count = 0
    task_count = 0
    # Events retained per task in the heap; see _create_dispatcher.
    heap_multiplier = 4

    def __init__(self, callback=None,
                 workers=None, tasks=None, taskheap=None,
                 max_workers_in_memory=5000, max_tasks_in_memory=10000,
                 on_node_join=None, on_node_leave=None,
                 tasks_by_type=None, tasks_by_worker=None):
        self.event_callback = callback
        self.workers = (LRUCache(max_workers_in_memory)
                        if workers is None else workers)
        self.tasks = (LRUCache(max_tasks_in_memory)
                      if tasks is None else tasks)
        self._taskheap = [] if taskheap is None else taskheap
        self.max_workers_in_memory = max_workers_in_memory
        self.max_tasks_in_memory = max_tasks_in_memory
        self.on_node_join = on_node_join
        self.on_node_leave = on_node_leave
        self._mutex = threading.Lock()
        self.handlers = {}
        self._seen_types = set()
        self._tasks_to_resolve = {}
        self.rebuild_taskheap()

        # type: Mapping[TaskName, WeakSet[Task]]
        self.tasks_by_type = CallableDefaultdict(
            self._tasks_by_type, WeakSet)
        self.tasks_by_type.update(
            _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))

        # type: Mapping[Hostname, WeakSet[Task]]
        self.tasks_by_worker = CallableDefaultdict(
            self._tasks_by_worker, WeakSet)
        self.tasks_by_worker.update(
            _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))

    @cached_property
    def _event(self):
        # Build the dispatcher closure lazily, once per State instance.
        return self._create_dispatcher()

    def freeze_while(self, fun, *args, **kwargs):
        # Run *fun* while holding the state mutex, optionally clearing
        # all state afterwards (``clear_after=True``).
        clear_after = kwargs.pop('clear_after', False)
        with self._mutex:
            try:
                return fun(*args, **kwargs)
            finally:
                if clear_after:
                    self._clear()

    def clear_tasks(self, ready=True):
        with self._mutex:
            return self._clear_tasks(ready)

    def _clear_tasks(self, ready=True):
        if ready:
            # keep tasks that are still in progress.
            in_progress = {
                uuid: task for uuid, task in self.itertasks()
                if task.state not in states.READY_STATES
            }
            self.tasks.clear()
            self.tasks.update(in_progress)
        else:
            self.tasks.clear()
        self._taskheap[:] = []

    def _clear(self, ready=True):
        self.workers.clear()
        self._clear_tasks(ready)
        self.event_count = 0
        self.task_count = 0

    def clear(self, ready=True):
        with self._mutex:
            return self._clear(ready)

    def get_or_create_worker(self, hostname, **kwargs):
        """Get or create worker by hostname.

        Returns:
            Tuple: of ``(worker, was_created)`` pairs.
        """
        try:
            worker = self.workers[hostname]
            if kwargs:
                worker.update(kwargs)
            return worker, False
        except KeyError:
            worker = self.workers[hostname] = self.Worker(
                hostname, **kwargs)
            return worker, True

    def get_or_create_task(self, uuid):
        """Get or create task by uuid."""
        try:
            return self.tasks[uuid], False
        except KeyError:
            task = self.tasks[uuid] = self.Task(uuid, cluster_state=self)
            return task, True

    def event(self, event):
        with self._mutex:
            return self._event(event)

    def task_event(self, type_, fields):
        """Deprecated, use :meth:`event`."""
        return self._event(dict(fields, type='-'.join(['task', type_])))[0]

    def worker_event(self, type_, fields):
        """Deprecated, use :meth:`event`."""
        return self._event(dict(fields, type='-'.join(['worker', type_])))[0]

    def _create_dispatcher(self):
        # noqa: C901
        # pylint: disable=too-many-statements
        # This code is highly optimized, but not for reusability.
        get_handler = self.handlers.__getitem__
        event_callback = self.event_callback
        wfields = itemgetter('hostname', 'timestamp', 'local_received')
        tfields = itemgetter('uuid', 'hostname', 'timestamp',
                             'local_received', 'clock')
        taskheap = self._taskheap
        th_append = taskheap.append
        th_pop = taskheap.pop
        # Removing events from task heap is an O(n) operation,
        # so easier to just account for the common number of events
        # for each task (PENDING->RECEIVED->STARTED->final)
        #: an O(n) operation
        max_events_in_heap = self.max_tasks_in_memory * self.heap_multiplier
        add_type = self._seen_types.add
        on_node_join, on_node_leave = self.on_node_join, self.on_node_leave
        tasks, Task = self.tasks, self.Task
        workers, Worker = self.workers, self.Worker
        # avoid updating LRU entry at getitem
        get_worker, get_task = workers.data.__getitem__, tasks.data.__getitem__

        get_task_by_type_set = self.tasks_by_type.__getitem__
        get_task_by_worker_set = self.tasks_by_worker.__getitem__

        def _event(event,
                   timetuple=timetuple, KeyError=KeyError,
                   insort=bisect.insort, created=True):
            self.event_count += 1
            if event_callback:
                event_callback(self, event)
            group, _, subject = event['type'].partition('-')
            try:
                handler = get_handler(group)
            except KeyError:
                pass
            else:
                # a custom handler is registered for this group:
                # delegate instead of using the built-in logic below.
                return handler(subject, event), subject

            if group == 'worker':
                try:
                    hostname, timestamp, local_received = wfields(event)
                except KeyError:
                    pass
                else:
                    is_offline = subject == 'offline'
                    try:
                        worker, created = get_worker(hostname), False
                    except KeyError:
                        if is_offline:
                            # don't re-register a worker we only ever
                            # saw go offline.
                            worker, created = Worker(hostname), False
                        else:
                            worker = workers[hostname] = Worker(hostname)
                    worker.event(subject, timestamp, local_received, event)
                    if on_node_join and (created or subject == 'online'):
                        on_node_join(worker)
                    if on_node_leave and is_offline:
                        on_node_leave(worker)
                        workers.pop(hostname, None)
                    return (worker, created), subject
            elif group == 'task':
                (uuid, hostname, timestamp,
                 local_received, clock) = tfields(event)
                # task-sent event is sent by client, not worker
                is_client_event = subject == 'sent'
                try:
                    task, task_created = get_task(uuid), False
                except KeyError:
                    task = tasks[uuid] = Task(uuid, cluster_state=self)
                    task_created = True
                if is_client_event:
                    task.client = hostname
                else:
                    try:
                        worker = get_worker(hostname)
                    except KeyError:
                        worker = workers[hostname] = Worker(hostname)
                    task.worker = worker
                    if worker is not None and local_received:
                        worker.event(None, local_received, timestamp)

                origin = hostname if is_client_event else worker.id

                # remove oldest event if exceeding the limit.
                heaps = len(taskheap)
                if heaps + 1 > max_events_in_heap:
                    th_pop(0)

                # most events will be dated later than the previous.
                timetup = timetuple(clock, timestamp, origin, ref(task))
                if heaps and timetup > taskheap[-1]:
                    th_append(timetup)
                else:
                    insort(taskheap, timetup)

                if subject == 'received':
                    self.task_count += 1
                task.event(subject, timestamp, local_received, event)
                task_name = task.name
                if task_name is not None:
                    add_type(task_name)
                    if task_created:  # add to tasks_by_type index
                        get_task_by_type_set(task_name).add(task)
                        get_task_by_worker_set(hostname).add(task)
                if task.parent_id:
                    try:
                        parent_task = self.tasks[task.parent_id]
                    except KeyError:
                        # parent not seen yet: remember this task so it
                        # can be attached when the parent arrives.
                        self._add_pending_task_child(task)
                    else:
                        parent_task.children.add(task)
                try:
                    _children = self._tasks_to_resolve.pop(uuid)
                except KeyError:
                    pass
                else:
                    # children that arrived before this task: adopt them.
                    task.children.update(_children)

                return (task, task_created), subject
        return _event

    def _add_pending_task_child(self, task):
        # Queue *task* under its parent id until the parent is seen
        # (resolved inside the dispatcher's _event closure).
        try:
            ch = self._tasks_to_resolve[task.parent_id]
        except KeyError:
            ch = self._tasks_to_resolve[task.parent_id] = WeakSet()
        ch.add(task)

    def rebuild_taskheap(self, timetuple=timetuple):
        # Recreate the event heap from the current task set (used when
        # restoring a pickled State, where the heap isn't serialized).
        heap = self._taskheap[:] = [
            timetuple(t.clock, t.timestamp, t.origin, ref(t))
            for t in values(self.tasks)
        ]
        heap.sort()

    def itertasks(self, limit=None):
        # Iterate over (uuid, Task) pairs, optionally stopping at *limit*.
        for index, row in enumerate(items(self.tasks)):
            yield row
            if limit and index + 1 >= limit:
                break

    def tasks_by_time(self, limit=None, reverse=True):
        """Generator yielding tasks ordered by time.

        Yields:
            Tuples of ``(uuid, Task)``.
        """
        _heap = self._taskheap
        if reverse:
            _heap = reversed(_heap)

        seen = set()
        for evtup in islice(_heap, 0, limit):
            # evtup[3] is a weakref to the task; it may be dead.
            task = evtup[3]()
            if task is not None:
                uuid = task.uuid
                if uuid not in seen:
                    yield uuid, task
                    seen.add(uuid)
    tasks_by_timestamp = tasks_by_time

    def _tasks_by_type(self, name, limit=None, reverse=True):
        """Get all tasks by type.

        This is slower than accessing :attr:`tasks_by_type`,
        but will be ordered by time.

        Returns:
            Generator: giving ``(uuid, Task)`` pairs.
        """
        return islice(
            ((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse)
             if task.name == name),
            0, limit,
        )

    def _tasks_by_worker(self, hostname, limit=None, reverse=True):
        """Get all tasks by worker.

        Slower than accessing :attr:`tasks_by_worker`, but ordered by time.
        """
        return islice(
            ((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse)
             if task.worker.hostname == hostname),
            0, limit,
        )

    def task_types(self):
        """Return a list of all seen task types."""
        return sorted(self._seen_types)

    def alive_workers(self):
        """Return a list of (seemingly) alive workers."""
        return (w for w in values(self.workers) if w.alive)

    def __repr__(self):
        return R_STATE.format(self)

    def __reduce__(self):
        # taskheap is passed as None; it's rebuilt by __init__ via
        # rebuild_taskheap().
        return self.__class__, (
            self.event_callback, self.workers, self.tasks, None,
            self.max_workers_in_memory, self.max_tasks_in_memory,
            self.on_node_join, self.on_node_leave,
            _serialize_Task_WeakSet_Mapping(self.tasks_by_type),
            _serialize_Task_WeakSet_Mapping(self.tasks_by_worker),
        )


def _serialize_Task_WeakSet_Mapping(mapping):
    # Flatten {name: WeakSet[Task]} into a picklable {name: [task ids]}.
    return {name: [t.id for t in tasks] for name, tasks in items(mapping)}


def _deserialize_Task_WeakSet_Mapping(mapping, tasks):
    # Inverse of _serialize_Task_WeakSet_Mapping; ids no longer present
    # in *tasks* are silently dropped.
    return {name: WeakSet(tasks[i] for i in ids if i in tasks)
            for name, ids in items(mapping or {})}