# -*- coding: utf-8 -*-
"""In-memory representation of cluster state.

This module implements a data structure used to keep
track of the state of a cluster of workers and the tasks
it is working on (by consuming events).

For every event consumed the state is updated,
so the state represents the state of the cluster
at the time of the last event.

Snapshots (:mod:`celery.events.snapshot`) can be used to
take "pictures" of this state at regular intervals
and, for example, store them in a database.
"""
from __future__ import absolute_import, unicode_literals

import bisect
import sys
import threading
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from itertools import islice
from operator import itemgetter
from time import time
from weakref import WeakSet, ref

from kombu.clocks import timetuple
from kombu.utils.objects import cached_property

from celery import states
from celery.five import items, python_2_unicode_compatible, values
from celery.utils.functional import LRUCache, memoize, pass1
from celery.utils.log import get_logger

try:
    from collections.abc import Callable
except ImportError:
    # TODO: Remove this when we drop Python 2.7 support
    from collections import Callable


__all__ = ('Worker', 'Task', 'State', 'heartbeat_expires')

# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.
# pylint: disable=too-many-function-args
# For some reason pylint thinks ._event is a method, when it's a property.

#: Set if running PyPy
PYPY = hasattr(sys, 'pypy_version_info')

#: The window (as a percentage of the heartbeat frequency) used to decide
#: when a worker is considered offline: if the time since the last
#: heartbeat exceeds ``freq * (HEARTBEAT_EXPIRE_WINDOW / 100)`` seconds,
#: the worker is presumed offline.
HEARTBEAT_EXPIRE_WINDOW = 200

#: Max drift between event timestamp and time of event received
#: before we alert that clocks may be unsynchronized.
HEARTBEAT_DRIFT_MAX = 16

DRIFT_WARNING = """\
Substantial drift from %s may mean clocks are out of sync.  Current drift is
%s seconds.  [orig: %s recv: %s]
"""

logger = get_logger(__name__)
warn = logger.warning

R_STATE = '<State: events={0.event_count} tasks={0.task_count}>'
R_WORKER = '<Worker: {0.hostname} ({0.status_string} clock:{0.clock})>'
R_TASK = '<Task: {0.name}({0.uuid}) {0.state} clock:{0.clock}>'

#: Mapping of task event names to task state.
TASK_EVENT_TO_STATE = {
    'sent': states.PENDING,
    'received': states.RECEIVED,
    'started': states.STARTED,
    'failed': states.FAILURE,
    'retried': states.RETRY,
    'succeeded': states.SUCCESS,
    'revoked': states.REVOKED,
    'rejected': states.REJECTED,
}


class CallableDefaultdict(defaultdict):
    """:class:`~collections.defaultdict` with configurable __call__.

    We use this for backwards compatibility in State.tasks_by_type
    etc., which used to be methods but are now indexes instead.

    So you can do::

        >>> add_tasks = state.tasks_by_type['proj.tasks.add']

    while still supporting the method call::

        >>> add_tasks = list(state.tasks_by_type(
        ...     'proj.tasks.add', reverse=True))
    """

    def __init__(self, fun, *args, **kwargs):
        self.fun = fun
        super(CallableDefaultdict, self).__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.fun(*args, **kwargs)


Callable.register(CallableDefaultdict)  # noqa: E305


@memoize(maxsize=1000, keyfun=lambda a, _: a[0])
def _warn_drift(hostname, drift, local_received, timestamp):
    # we use memoize here so the warning is only logged once per hostname
    warn(DRIFT_WARNING, hostname, drift,
         datetime.fromtimestamp(local_received),
         datetime.fromtimestamp(timestamp))


def heartbeat_expires(timestamp, freq=60,
                      expire_window=HEARTBEAT_EXPIRE_WINDOW,
                      Decimal=Decimal, float=float, isinstance=isinstance):
    """Return the time when the heartbeat expires."""
    # some JSON implementations return decimal.Decimal objects,
    # which aren't compatible with float.
    freq = float(freq) if isinstance(freq, Decimal) else freq
    if isinstance(timestamp, Decimal):
        timestamp = float(timestamp)
    return timestamp + (freq * (expire_window / 1e2))


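# A quick sketch of the expiry arithmetic above (timestamp hypothetical):
# with the default 200% window, a worker heartbeating every 60 seconds
# is presumed offline 120 seconds after its last heartbeat::
#
#     >>> heartbeat_expires(1000.0, freq=60)  # 1000 + 60 * (200 / 100)
#     1120.0

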
def _depickle_task(cls, fields):
    return cls(**fields)


def with_unique_field(attr):

    def _decorate_cls(cls):

        def __eq__(this, other):
            if isinstance(other, this.__class__):
                return getattr(this, attr) == getattr(other, attr)
            return NotImplemented
        cls.__eq__ = __eq__

        def __ne__(this, other):
            res = this.__eq__(other)
            return True if res is NotImplemented else not res
        cls.__ne__ = __ne__

        def __hash__(this):
            return hash(getattr(this, attr))
        cls.__hash__ = __hash__

        return cls
    return _decorate_cls


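# For example, ``Worker`` below is decorated with
# ``@with_unique_field('hostname')``, so instances compare and hash by
# hostname alone (a minimal sketch; the hostnames are hypothetical)::
#
#     >>> Worker(hostname='celery@a') == Worker(hostname='celery@a')
#     True
#     >>> Worker(hostname='celery@a') == Worker(hostname='celery@b')
#     False

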
@with_unique_field('hostname')
@python_2_unicode_compatible
class Worker(object):
    """Worker State."""

    heartbeat_max = 4
    expire_window = HEARTBEAT_EXPIRE_WINDOW

    _fields = ('hostname', 'pid', 'freq', 'heartbeats', 'clock',
               'active', 'processed', 'loadavg', 'sw_ident',
               'sw_ver', 'sw_sys')
    if not PYPY:  # pragma: no cover
        __slots__ = _fields + ('event', '__dict__', '__weakref__')

    def __init__(self, hostname=None, pid=None, freq=60,
                 heartbeats=None, clock=0, active=None, processed=None,
                 loadavg=None, sw_ident=None, sw_ver=None, sw_sys=None):
        self.hostname = hostname
        self.pid = pid
        self.freq = freq
        self.heartbeats = [] if heartbeats is None else heartbeats
        self.clock = clock or 0
        self.active = active
        self.processed = processed
        self.loadavg = loadavg
        self.sw_ident = sw_ident
        self.sw_ver = sw_ver
        self.sw_sys = sw_sys
        self.event = self._create_event_handler()

    def __reduce__(self):
        return self.__class__, (self.hostname, self.pid, self.freq,
                                self.heartbeats, self.clock, self.active,
                                self.processed, self.loadavg, self.sw_ident,
                                self.sw_ver, self.sw_sys)

    def _create_event_handler(self):
        _set = object.__setattr__
        hbmax = self.heartbeat_max
        heartbeats = self.heartbeats
        hb_pop = self.heartbeats.pop
        hb_append = self.heartbeats.append

        def event(type_, timestamp=None,
                  local_received=None, fields=None,
                  max_drift=HEARTBEAT_DRIFT_MAX, items=items, abs=abs, int=int,
                  insort=bisect.insort, len=len):
            fields = fields or {}
            for k, v in items(fields):
                _set(self, k, v)
            if type_ == 'offline':
                heartbeats[:] = []
            else:
                if not local_received or not timestamp:
                    return
                drift = abs(int(local_received) - int(timestamp))
                if drift > max_drift:
                    _warn_drift(self.hostname, drift,
                                local_received, timestamp)
                if local_received:  # pragma: no cover
                    hearts = len(heartbeats)
                    if hearts > hbmax - 1:
                        hb_pop(0)
                    if hearts and local_received > heartbeats[-1]:
                        hb_append(local_received)
                    else:
                        insort(heartbeats, local_received)
        return event

    def update(self, f, **kw):
        for k, v in items(dict(f, **kw) if kw else f):
            setattr(self, k, v)

    def __repr__(self):
        return R_WORKER.format(self)

    @property
    def status_string(self):
        return 'ONLINE' if self.alive else 'OFFLINE'

    @property
    def heartbeat_expires(self):
        return heartbeat_expires(self.heartbeats[-1],
                                 self.freq, self.expire_window)

    @property
    def alive(self, nowfun=time):
        return bool(self.heartbeats and nowfun() < self.heartbeat_expires)

    @property
    def id(self):
        return '{0.hostname}.{0.pid}'.format(self)


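# A minimal sketch of driving a Worker's event handler by hand (the
# hostname, and using ``time()`` for both timestamps, are hypothetical;
# in practice events arrive via :mod:`celery.events`)::
#
#     >>> w = Worker(hostname='celery@example', freq=60)
#     >>> now = time()
#     >>> w.event('heartbeat', timestamp=now, local_received=now)
#     >>> w.alive  # expires freq * (expire_window / 100) seconds later
#     True

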
@with_unique_field('uuid')
@python_2_unicode_compatible
class Task(object):
    """Task State."""

    name = received = sent = started = succeeded = failed = retried = \
        revoked = rejected = args = kwargs = eta = expires = retries = \
        worker = result = exception = timestamp = runtime = traceback = \
        exchange = routing_key = root_id = parent_id = client = None
    state = states.PENDING
    clock = 0

    _fields = (
        'uuid', 'name', 'state', 'received', 'sent', 'started', 'rejected',
        'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
        'eta', 'expires', 'retries', 'worker', 'result', 'exception',
        'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
        'clock', 'client', 'root', 'root_id', 'parent', 'parent_id',
        'children',
    )
    if not PYPY:  # pragma: no cover
        __slots__ = ('__dict__', '__weakref__')

    #: How to merge out of order events.
    #: Disorder is detected by logical ordering (e.g., :event:`task-received`
    #: must've happened before a :event:`task-failed` event).
    #:
    #: A merge rule consists of a state and a list of fields to keep from
    #: that state.  ``(RECEIVED, ('name', 'args'))`` means the name and args
    #: fields are always taken from the RECEIVED state, and any values for
    #: these fields received before or after are simply ignored.
    merge_rules = {
        states.RECEIVED: (
            'name', 'args', 'kwargs', 'parent_id',
            'root_id', 'retries', 'eta', 'expires',
        ),
    }
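
    # A sketch of the out-of-order merge (uuid, name and timestamps are
    # hypothetical): if ``task-failed`` arrives before ``task-received``,
    # the late RECEIVED event only contributes the fields listed above,
    # while the FAILURE state and its timestamp are kept::
    #
    #     >>> t = Task(uuid='id1')
    #     >>> t.event('failed', timestamp=100.0)
    #     >>> t.event('received', timestamp=90.0, fields={'name': 'proj.add'})
    #     >>> t.state, t.name
    #     ('FAILURE', 'proj.add')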

    #: :meth:`info` displays these fields by default.
    _info_fields = (
        'args', 'kwargs', 'retries', 'result', 'eta', 'runtime',
        'expires', 'exception', 'exchange', 'routing_key',
        'root_id', 'parent_id',
    )

    def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs):
        self.uuid = uuid
        self.cluster_state = cluster_state
        if self.cluster_state is not None:
            self.children = WeakSet(
                self.cluster_state.tasks.get(task_id)
                for task_id in children or ()
                if task_id in self.cluster_state.tasks
            )
        else:
            self.children = WeakSet()
        self._serializer_handlers = {
            'children': self._serializable_children,
            'root': self._serializable_root,
            'parent': self._serializable_parent,
        }
        if kwargs:
            self.__dict__.update(kwargs)

    def event(self, type_, timestamp=None, local_received=None, fields=None,
              precedence=states.precedence, items=items,
              setattr=setattr, task_event_to_state=TASK_EVENT_TO_STATE.get,
              RETRY=states.RETRY):
        fields = fields or {}

        # using .get is faster than catching KeyError in this case.
        state = task_event_to_state(type_)
        if state is not None:
            # sets, for example, self.succeeded to the timestamp.
            setattr(self, type_, timestamp)
        else:
            state = type_.upper()  # custom state

        # note that precedence here is reversed
        # see implementation in celery.states.state.__lt__
        if state != RETRY and self.state != RETRY and \
                precedence(state) > precedence(self.state):
            # this state logically happens-before the current state, so merge.
            keep = self.merge_rules.get(state)
            if keep is not None:
                fields = {
                    k: v for k, v in items(fields) if k in keep
                }
        else:
            fields.update(state=state, timestamp=timestamp)

        # update current state with info from this event.
        self.__dict__.update(fields)

    def info(self, fields=None, extra=None):
        """Information about this task suitable for on-screen display."""
        extra = [] if not extra else extra
        fields = self._info_fields if fields is None else fields

        def _keys():
            for key in list(fields) + list(extra):
                value = getattr(self, key, None)
                if value is not None:
                    yield key, value

        return dict(_keys())

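    # A small sketch of :meth:`info` (the field values are hypothetical)::
    #
    #     >>> t = Task(uuid='id1', name='proj.add', args='(2, 2)')
    #     >>> t.info(fields=['name', 'args'])
    #     {'name': 'proj.add', 'args': '(2, 2)'}
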
    def __repr__(self):
        return R_TASK.format(self)

    def as_dict(self):
        get = object.__getattribute__
        handler = self._serializer_handlers.get
        return {
            k: handler(k, pass1)(get(self, k)) for k in self._fields
        }

    def _serializable_children(self, value):
        return [task.id for task in self.children]

    def _serializable_root(self, value):
        return self.root_id

    def _serializable_parent(self, value):
        return self.parent_id

    def __reduce__(self):
        return _depickle_task, (self.__class__, self.as_dict())

    @property
    def id(self):
        return self.uuid

    @property
    def origin(self):
        return self.client if self.worker is None else self.worker.id

    @property
    def ready(self):
        return self.state in states.READY_STATES

    @cached_property
    def parent(self):
        # issue github.com/mher/flower/issues/648
        try:
            return self.parent_id and self.cluster_state.tasks.data[self.parent_id]
        except KeyError:
            return None

    @cached_property
    def root(self):
        # issue github.com/mher/flower/issues/648
        try:
            return self.root_id and self.cluster_state.tasks.data[self.root_id]
        except KeyError:
            return None


class State(object):
    """Records cluster state."""

    Worker = Worker
    Task = Task
    event_count = 0
    task_count = 0
    heap_multiplier = 4

    def __init__(self, callback=None,
                 workers=None, tasks=None, taskheap=None,
                 max_workers_in_memory=5000, max_tasks_in_memory=10000,
                 on_node_join=None, on_node_leave=None,
                 tasks_by_type=None, tasks_by_worker=None):
        self.event_callback = callback
        self.workers = (LRUCache(max_workers_in_memory)
                        if workers is None else workers)
        self.tasks = (LRUCache(max_tasks_in_memory)
                      if tasks is None else tasks)
        self._taskheap = [] if taskheap is None else taskheap
        self.max_workers_in_memory = max_workers_in_memory
        self.max_tasks_in_memory = max_tasks_in_memory
        self.on_node_join = on_node_join
        self.on_node_leave = on_node_leave
        self._mutex = threading.Lock()
        self.handlers = {}
        self._seen_types = set()
        self._tasks_to_resolve = {}
        self.rebuild_taskheap()

        # type: Mapping[TaskName, WeakSet[Task]]
        self.tasks_by_type = CallableDefaultdict(
            self._tasks_by_type, WeakSet)
        self.tasks_by_type.update(
            _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))

        # type: Mapping[Hostname, WeakSet[Task]]
        self.tasks_by_worker = CallableDefaultdict(
            self._tasks_by_worker, WeakSet)
        self.tasks_by_worker.update(
            _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))

    @cached_property
    def _event(self):
        return self._create_dispatcher()

    def freeze_while(self, fun, *args, **kwargs):
        clear_after = kwargs.pop('clear_after', False)
        with self._mutex:
            try:
                return fun(*args, **kwargs)
            finally:
                if clear_after:
                    self._clear()

    def clear_tasks(self, ready=True):
        with self._mutex:
            return self._clear_tasks(ready)

    def _clear_tasks(self, ready=True):
        if ready:
            in_progress = {
                uuid: task for uuid, task in self.itertasks()
                if task.state not in states.READY_STATES
            }
            self.tasks.clear()
            self.tasks.update(in_progress)
        else:
            self.tasks.clear()
        self._taskheap[:] = []

    def _clear(self, ready=True):
        self.workers.clear()
        self._clear_tasks(ready)
        self.event_count = 0
        self.task_count = 0

    def clear(self, ready=True):
        with self._mutex:
            return self._clear(ready)

    def get_or_create_worker(self, hostname, **kwargs):
        """Get or create worker by hostname.

        Returns:
            Tuple: of ``(worker, was_created)``.
        """
        try:
            worker = self.workers[hostname]
            if kwargs:
                worker.update(kwargs)
            return worker, False
        except KeyError:
            worker = self.workers[hostname] = self.Worker(
                hostname, **kwargs)
            return worker, True

    def get_or_create_task(self, uuid):
        """Get or create task by uuid."""
        try:
            return self.tasks[uuid], False
        except KeyError:
            task = self.tasks[uuid] = self.Task(uuid, cluster_state=self)
            return task, True

    def event(self, event):
        with self._mutex:
            return self._event(event)

    def task_event(self, type_, fields):
        """Deprecated, use :meth:`event`."""
        return self._event(dict(fields, type='-'.join(['task', type_])))[0]

    def worker_event(self, type_, fields):
        """Deprecated, use :meth:`event`."""
        return self._event(dict(fields, type='-'.join(['worker', type_])))[0]

    def _create_dispatcher(self):  # noqa: C901
        # pylint: disable=too-many-statements
        # This code is highly optimized, but not for reusability.
        get_handler = self.handlers.__getitem__
        event_callback = self.event_callback
        wfields = itemgetter('hostname', 'timestamp', 'local_received')
        tfields = itemgetter('uuid', 'hostname', 'timestamp',
                             'local_received', 'clock')
        taskheap = self._taskheap
        th_append = taskheap.append
        th_pop = taskheap.pop
        # Removing events from the task heap is an O(n) operation,
        # so it's easier to just account for the common number of events
        # per task (PENDING->RECEIVED->STARTED->final).
        max_events_in_heap = self.max_tasks_in_memory * self.heap_multiplier
        add_type = self._seen_types.add
        on_node_join, on_node_leave = self.on_node_join, self.on_node_leave
        tasks, Task = self.tasks, self.Task
        workers, Worker = self.workers, self.Worker
        # avoid updating LRU entry at getitem
        get_worker, get_task = workers.data.__getitem__, tasks.data.__getitem__

        get_task_by_type_set = self.tasks_by_type.__getitem__
        get_task_by_worker_set = self.tasks_by_worker.__getitem__

        def _event(event,
                   timetuple=timetuple, KeyError=KeyError,
                   insort=bisect.insort, created=True):
            self.event_count += 1
            if event_callback:
                event_callback(self, event)
            group, _, subject = event['type'].partition('-')
            try:
                handler = get_handler(group)
            except KeyError:
                pass
            else:
                return handler(subject, event), subject

            if group == 'worker':
                try:
                    hostname, timestamp, local_received = wfields(event)
                except KeyError:
                    pass
                else:
                    is_offline = subject == 'offline'
                    try:
                        worker, created = get_worker(hostname), False
                    except KeyError:
                        if is_offline:
                            worker, created = Worker(hostname), False
                        else:
                            worker = workers[hostname] = Worker(hostname)
                    worker.event(subject, timestamp, local_received, event)
                    if on_node_join and (created or subject == 'online'):
                        on_node_join(worker)
                    if on_node_leave and is_offline:
                        on_node_leave(worker)
                        workers.pop(hostname, None)
                    return (worker, created), subject
            elif group == 'task':
                (uuid, hostname, timestamp,
                 local_received, clock) = tfields(event)
                # task-sent event is sent by client, not worker
                is_client_event = subject == 'sent'
                try:
                    task, task_created = get_task(uuid), False
                except KeyError:
                    task = tasks[uuid] = Task(uuid, cluster_state=self)
                    task_created = True
                if is_client_event:
                    task.client = hostname
                else:
                    try:
                        worker = get_worker(hostname)
                    except KeyError:
                        worker = workers[hostname] = Worker(hostname)
                    task.worker = worker
                    if worker is not None and local_received:
                        worker.event(None, local_received, timestamp)

                origin = hostname if is_client_event else worker.id

                # remove oldest event if exceeding the limit.
                heaps = len(taskheap)
                if heaps + 1 > max_events_in_heap:
                    th_pop(0)

                # most events will be dated later than the previous.
                timetup = timetuple(clock, timestamp, origin, ref(task))
                if heaps and timetup > taskheap[-1]:
                    th_append(timetup)
                else:
                    insort(taskheap, timetup)

                if subject == 'received':
                    self.task_count += 1
                task.event(subject, timestamp, local_received, event)
                task_name = task.name
                if task_name is not None:
                    add_type(task_name)
                    if task_created:  # add to tasks_by_type index
                        get_task_by_type_set(task_name).add(task)
                        get_task_by_worker_set(hostname).add(task)
                if task.parent_id:
                    try:
                        parent_task = self.tasks[task.parent_id]
                    except KeyError:
                        self._add_pending_task_child(task)
                    else:
                        parent_task.children.add(task)
                try:
                    _children = self._tasks_to_resolve.pop(uuid)
                except KeyError:
                    pass
                else:
                    task.children.update(_children)

                return (task, task_created), subject
        return _event

    def _add_pending_task_child(self, task):
        try:
            ch = self._tasks_to_resolve[task.parent_id]
        except KeyError:
            ch = self._tasks_to_resolve[task.parent_id] = WeakSet()
        ch.add(task)

    def rebuild_taskheap(self, timetuple=timetuple):
        heap = self._taskheap[:] = [
            timetuple(t.clock, t.timestamp, t.origin, ref(t))
            for t in values(self.tasks)
        ]
        heap.sort()

    def itertasks(self, limit=None):
        for index, row in enumerate(items(self.tasks)):
            yield row
            if limit and index + 1 >= limit:
                break

    def tasks_by_time(self, limit=None, reverse=True):
        """Generator yielding tasks ordered by time.

        Yields:
            Tuples of ``(uuid, Task)``.
        """
        _heap = self._taskheap
        if reverse:
            _heap = reversed(_heap)

        seen = set()
        for evtup in islice(_heap, 0, limit):
            task = evtup[3]()
            if task is not None:
                uuid = task.uuid
                if uuid not in seen:
                    yield uuid, task
                    seen.add(uuid)
    tasks_by_timestamp = tasks_by_time

    def _tasks_by_type(self, name, limit=None, reverse=True):
        """Get all tasks by type.

        This is slower than accessing :attr:`tasks_by_type`,
        but will be ordered by time.

        Returns:
            Generator: giving ``(uuid, Task)`` pairs.
        """
        return islice(
            ((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse)
             if task.name == name),
            0, limit,
        )

    def _tasks_by_worker(self, hostname, limit=None, reverse=True):
        """Get all tasks by worker.

        Slower than accessing :attr:`tasks_by_worker`, but ordered by time.
        """
        return islice(
            ((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse)
             if task.worker.hostname == hostname),
            0, limit,
        )

    def task_types(self):
        """Return a list of all seen task types."""
        return sorted(self._seen_types)

    def alive_workers(self):
        """Return a generator of (seemingly) alive workers."""
        return (w for w in values(self.workers) if w.alive)

    def __repr__(self):
        return R_STATE.format(self)

    def __reduce__(self):
        return self.__class__, (
            self.event_callback, self.workers, self.tasks, None,
            self.max_workers_in_memory, self.max_tasks_in_memory,
            self.on_node_join, self.on_node_leave,
            _serialize_Task_WeakSet_Mapping(self.tasks_by_type),
            _serialize_Task_WeakSet_Mapping(self.tasks_by_worker),
        )


def _serialize_Task_WeakSet_Mapping(mapping):
    return {name: [t.id for t in tasks] for name, tasks in items(mapping)}


def _deserialize_Task_WeakSet_Mapping(mapping, tasks):
    return {name: WeakSet(tasks[i] for i in ids if i in tasks)
            for name, ids in items(mapping or {})}

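
# A minimal usage sketch of ``State`` (the event payloads below are
# hypothetical, hand-built dicts; real ones would come from a
# :class:`celery.events.EventReceiver`)::
#
#     >>> state = State()
#     >>> now = time()
#     >>> state.event({'type': 'worker-online', 'hostname': 'celery@a',
#     ...              'timestamp': now, 'local_received': now})
#     >>> state.event({'type': 'task-received', 'uuid': 'id1',
#     ...              'hostname': 'celery@a', 'timestamp': now,
#     ...              'local_received': now, 'clock': 1,
#     ...              'name': 'proj.add'})
#     >>> state.task_count
#     1
#     >>> [uuid for uuid, _ in state.tasks_by_time()]
#     ['id1']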