1import asyncio
2import heapq
3import inspect
4import itertools
5import json
6import logging
7import math
8import operator
9import os
10import random
11import sys
12import uuid
13import warnings
14import weakref
15from collections import defaultdict, deque
16from collections.abc import (
17    Callable,
18    Collection,
19    Hashable,
20    Iterable,
21    Iterator,
22    Mapping,
23    Set,
24)
25from contextlib import suppress
26from datetime import timedelta
27from functools import partial
28from numbers import Number
29from typing import Any, ClassVar, Container
30from typing import cast as pep484_cast
31
32import psutil
33from sortedcontainers import SortedDict, SortedSet
34from tlz import (
35    compose,
36    first,
37    groupby,
38    merge,
39    merge_sorted,
40    merge_with,
41    pluck,
42    second,
43    valmap,
44)
45from tornado.ioloop import IOLoop, PeriodicCallback
46
47import dask
48from dask.highlevelgraph import HighLevelGraph
49from dask.utils import format_bytes, format_time, parse_bytes, parse_timedelta, tmpfile
50from dask.widgets import get_template
51
52from distributed.utils import recursive_to_dict
53
54from . import preloading, profile
55from . import versions as version_module
56from .active_memory_manager import ActiveMemoryManagerExtension
57from .batched import BatchedSend
58from .comm import (
59    Comm,
60    get_address_host,
61    normalize_address,
62    resolve_address,
63    unparse_host_port,
64)
65from .comm.addressing import addresses_from_user_args
66from .core import CommClosedError, Status, clean_exception, rpc, send_recv
67from .diagnostics.memory_sampler import MemorySamplerExtension
68from .diagnostics.plugin import SchedulerPlugin, _get_plugin_name
69from .event import EventExtension
70from .http import get_handlers
71from .lock import LockExtension
72from .metrics import time
73from .multi_lock import MultiLockExtension
74from .node import ServerNode
75from .proctitle import setproctitle
76from .protocol.pickle import loads
77from .publish import PublishExtension
78from .pubsub import PubSubSchedulerExtension
79from .queues import QueueExtension
80from .recreate_tasks import ReplayTaskScheduler
81from .security import Security
82from .semaphore import SemaphoreExtension
83from .stealing import WorkStealing
84from .utils import (
85    All,
86    TimeoutError,
87    empty_context,
88    get_fileno_limit,
89    key_split,
90    key_split_group,
91    log_errors,
92    no_default,
93    validate_key,
94)
95from .utils_comm import gather_from_workers, retry_operation, scatter_to_workers
96from .utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
97from .variable import VariableExtension
98
99try:
100    from cython import compiled
101except ImportError:
102    compiled = False
103
104if compiled:
105    from cython import (
106        Py_hash_t,
107        Py_ssize_t,
108        bint,
109        cast,
110        ccall,
111        cclass,
112        cfunc,
113        declare,
114        double,
115        exceptval,
116        final,
117        inline,
118        nogil,
119    )
120else:
121    from ctypes import c_double as double
122    from ctypes import c_ssize_t as Py_hash_t
123    from ctypes import c_ssize_t as Py_ssize_t
124
125    bint = bool
126
127    def cast(T, v, *a, **k):
128        return v
129
130    def ccall(func):
131        return func
132
133    def cclass(cls):
134        return cls
135
136    def cfunc(func):
137        return func
138
139    def declare(*a, **k):
140        if len(a) == 2:
141            return a[1]
142        else:
143            pass
144
145    def exceptval(*a, **k):
146        def wrapper(func):
147            return func
148
149        return wrapper
150
151    def final(cls):
152        return cls
153
154    def inline(func):
155        return func
156
157    def nogil(func):
158        return func
159
160
161if sys.version_info < (3, 8):
162    try:
163        import pickle5 as pickle
164    except ImportError:
165        import pickle
166else:
167    import pickle
168
169
170logger = logging.getLogger(__name__)
171
172
173LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
174DEFAULT_DATA_SIZE = declare(
175    Py_ssize_t, parse_bytes(dask.config.get("distributed.scheduler.default-data-size"))
176)
177
178DEFAULT_EXTENSIONS = [
179    LockExtension,
180    MultiLockExtension,
181    PublishExtension,
182    ReplayTaskScheduler,
183    QueueExtension,
184    VariableExtension,
185    PubSubSchedulerExtension,
186    SemaphoreExtension,
187    EventExtension,
188    ActiveMemoryManagerExtension,
189    MemorySamplerExtension,
190]
191
192ALL_TASK_STATES = declare(
193    set, {"released", "waiting", "no-worker", "processing", "erred", "memory"}
194)
195globals()["ALL_TASK_STATES"] = ALL_TASK_STATES
196COMPILED = declare(bint, compiled)
197globals()["COMPILED"] = COMPILED
198
199
200@final
201@cclass
202class ClientState:
203    """
204    A simple object holding information about a client.
205
206    .. attribute:: client_key: str
207
208       A unique identifier for this client.  This is generally an opaque
209       string generated by the client itself.
210
211    .. attribute:: wants_what: {TaskState}
212
213       A set of tasks this client wants kept in memory, so that it can
214       download its result when desired.  This is the reverse mapping of
215       :class:`TaskState.who_wants`.
216
217       Tasks are typically removed from this set when the corresponding
218       object in the client's space (for example a ``Future`` or a Dask
219       collection) gets garbage-collected.
220
221    """
222
223    _client_key: str
224    _hash: Py_hash_t
225    _wants_what: set
226    _last_seen: double
227    _versions: dict
228
229    __slots__ = ("_client_key", "_hash", "_wants_what", "_last_seen", "_versions")
230
231    def __init__(self, client: str, versions: dict = None):
232        self._client_key = client
233        self._hash = hash(client)
234        self._wants_what = set()
235        self._last_seen = time()
236        self._versions = versions or {}
237
238    def __hash__(self):
239        return self._hash
240
241    def __eq__(self, other):
242        typ_self: type = type(self)
243        typ_other: type = type(other)
244        if typ_self == typ_other:
245            other_cs: ClientState = other
246            return self._client_key == other_cs._client_key
247        else:
248            return False
249
250    def __repr__(self):
251        return "<Client '%s'>" % self._client_key
252
253    def __str__(self):
254        return self._client_key
255
256    @property
257    def client_key(self):
258        return self._client_key
259
260    @property
261    def wants_what(self):
262        return self._wants_what
263
264    @property
265    def last_seen(self):
266        return self._last_seen
267
268    @property
269    def versions(self):
270        return self._versions
271
272
273@final
274@cclass
275class MemoryState:
276    """Memory readings on a worker or on the whole cluster.
277
278    managed
279        Sum of the output of sizeof() for all dask keys held by the worker, both in
280        memory and spilled to disk
281    managed_in_memory
282        Sum of the output of sizeof() for the dask keys held in RAM
283    managed_spilled
284        Sum of the output of sizeof() for the dask keys spilled to the hard drive.
285        Note that this is the size in memory; serialized size may be different.
286    process
287        Total RSS memory measured by the OS on the worker process.
288        This is always exactly equal to managed_in_memory + unmanaged.
289    unmanaged
290        process - managed_in_memory. This is the sum of
291
292        - Python interpreter and modules
293        - global variables
294        - memory temporarily allocated by the dask tasks that are currently running
295        - memory fragmentation
296        - memory leaks
297        - memory not yet garbage collected
298        - memory not yet free()'d by the Python memory manager to the OS
299
300    unmanaged_old
301        Minimum of the 'unmanaged' measures over the last
302        ``distributed.memory.recent-to-old-time`` seconds
303    unmanaged_recent
304        unmanaged - unmanaged_old; in other words process memory that has been recently
305        allocated but is not accounted for by dask; hopefully it's mostly a temporary
306        spike.
307    optimistic
308        managed_in_memory + unmanaged_old; in other words the memory held long-term by
309        the process under the hopeful assumption that all unmanaged_recent memory is a
310        temporary spike
311    """
312
313    __slots__ = ("_process", "_managed_in_memory", "_managed_spilled", "_unmanaged_old")
314
315    _process: Py_ssize_t
316    _managed_in_memory: Py_ssize_t
317    _managed_spilled: Py_ssize_t
318    _unmanaged_old: Py_ssize_t
319
320    def __init__(
321        self,
322        *,
323        process: Py_ssize_t,
324        unmanaged_old: Py_ssize_t,
325        managed: Py_ssize_t,
326        managed_spilled: Py_ssize_t,
327    ):
328        # Some data arrives with the heartbeat, some other arrives in realtime as the
329        # tasks progress. Also, sizeof() is not guaranteed to return correct results.
330        # This can cause glitches where a partial measure is larger than the whole, so
331        # we need to force all numbers to add up exactly by definition.
332        self._process = process
333        self._managed_spilled = min(managed_spilled, managed)
334        # Subtractions between unsigned ints guaranteed by construction to be >= 0
335        self._managed_in_memory = min(managed - self._managed_spilled, process)
336        self._unmanaged_old = min(unmanaged_old, process - self._managed_in_memory)
337
338    @property
339    def process(self) -> Py_ssize_t:
340        return self._process
341
342    @property
343    def managed_in_memory(self) -> Py_ssize_t:
344        return self._managed_in_memory
345
346    @property
347    def managed_spilled(self) -> Py_ssize_t:
348        return self._managed_spilled
349
350    @property
351    def unmanaged_old(self) -> Py_ssize_t:
352        return self._unmanaged_old
353
354    @classmethod
355    def sum(cls, *infos: "MemoryState") -> "MemoryState":
356        out = MemoryState(process=0, unmanaged_old=0, managed=0, managed_spilled=0)
357        ms: MemoryState
358        for ms in infos:
359            out._process += ms._process
360            out._managed_spilled += ms._managed_spilled
361            out._managed_in_memory += ms._managed_in_memory
362            out._unmanaged_old += ms._unmanaged_old
363        return out
364
365    @property
366    def managed(self) -> Py_ssize_t:
367        return self._managed_in_memory + self._managed_spilled
368
369    @property
370    def unmanaged(self) -> Py_ssize_t:
371        # This is never negative thanks to __init__
372        return self._process - self._managed_in_memory
373
374    @property
375    def unmanaged_recent(self) -> Py_ssize_t:
376        # This is never negative thanks to __init__
377        return self._process - self._managed_in_memory - self._unmanaged_old
378
379    @property
380    def optimistic(self) -> Py_ssize_t:
381        return self._managed_in_memory + self._unmanaged_old
382
383    def __repr__(self) -> str:
384        return (
385            f"Process memory (RSS)  : {format_bytes(self._process)}\n"
386            f"  - managed by Dask   : {format_bytes(self._managed_in_memory)}\n"
387            f"  - unmanaged (old)   : {format_bytes(self._unmanaged_old)}\n"
388            f"  - unmanaged (recent): {format_bytes(self.unmanaged_recent)}\n"
389            f"Spilled to disk       : {format_bytes(self._managed_spilled)}\n"
390        )
391
392
393@final
394@cclass
395class WorkerState:
396    """
397    A simple object holding information about a worker.
398
399    .. attribute:: address: str
400
401       This worker's unique key.  This can be its connected address
402       (such as ``'tcp://127.0.0.1:8891'``) or an alias (such as ``'alice'``).
403
404    .. attribute:: processing: {TaskState: cost}
405
406       A dictionary of tasks that have been submitted to this worker.
407       Each task state is associated with the expected cost in seconds
408       of running that task, summing both the task's expected computation
409       time and the expected communication time of its result.
410
411       If a task is already executing on the worker and the excecution time is
412       twice the learned average TaskGroup duration, this will be set to twice
413       the current executing time. If the task is unknown, the default task
414       duration is used instead of the TaskGroup average.
415
416       Multiple tasks may be submitted to a worker in advance and the worker
417       will run them eventually, depending on its execution resources
418       (but see :doc:`work-stealing`).
419
420       All the tasks here are in the "processing" state.
421
422       This attribute is kept in sync with :attr:`TaskState.processing_on`.
423
424    .. attribute:: executing: {TaskState: duration}
425
426       A dictionary of tasks that are currently being run on this worker.
427       Each task state is asssociated with the duration in seconds which
428       the task has been running.
429
430    .. attribute:: has_what: {TaskState}
431
432       An insertion-sorted set-like of tasks which currently reside on this worker.
433       All the tasks here are in the "memory" state.
434
435       This is the reverse mapping of :class:`TaskState.who_has`.
436
437    .. attribute:: nbytes: int
438
439       The total memory size, in bytes, used by the tasks this worker
440       holds in memory (i.e. the tasks in this worker's :attr:`has_what`).
441
442    .. attribute:: nthreads: int
443
444       The number of CPU threads made available on this worker.
445
446    .. attribute:: resources: {str: Number}
447
448       The available resources on this worker like ``{'gpu': 2}``.
449       These are abstract quantities that constrain certain tasks from
450       running at the same time on this worker.
451
452    .. attribute:: used_resources: {str: Number}
453
454       The sum of each resource used by all tasks allocated to this worker.
455       The numbers in this dictionary can only be less or equal than
456       those in this worker's :attr:`resources`.
457
458    .. attribute:: occupancy: double
459
460       The total expected runtime, in seconds, of all tasks currently
461       processing on this worker.  This is the sum of all the costs in
462       this worker's :attr:`processing` dictionary.
463
464    .. attribute:: status: Status
465
466       Read-only worker status, synced one way from the remote Worker object
467
468    .. attribute:: nanny: str
469
470       Address of the associated Nanny, if present
471
472    .. attribute:: last_seen: Py_ssize_t
473
474       The last time we received a heartbeat from this worker, in local
475       scheduler time.
476
477    .. attribute:: actors: {TaskState}
478
479       A set of all TaskStates on this worker that are actors.  This only
480       includes those actors whose state actually lives on this worker, not
481       actors to which this worker has a reference.
482
483    """
484
485    # XXX need a state field to signal active/removed?
486
487    _actors: set
488    _address: str
489    _bandwidth: double
490    _executing: dict
491    _extra: dict
492    # _has_what is a dict with all values set to None as rebalance() relies on the
493    # property of Python >=3.7 dicts to be insertion-sorted.
494    _has_what: dict
495    _hash: Py_hash_t
496    _last_seen: double
497    _local_directory: str
498    _memory_limit: Py_ssize_t
499    _memory_other_history: "deque[tuple[float, Py_ssize_t]]"
500    _memory_unmanaged_old: Py_ssize_t
501    _metrics: dict
502    _name: object
503    _nanny: str
504    _nbytes: Py_ssize_t
505    _nthreads: Py_ssize_t
506    _occupancy: double
507    _pid: Py_ssize_t
508    _processing: dict
509    _resources: dict
510    _services: dict
511    _status: Status
512    _time_delay: double
513    _used_resources: dict
514    _versions: dict
515
516    __slots__ = (
517        "_actors",
518        "_address",
519        "_bandwidth",
520        "_extra",
521        "_executing",
522        "_has_what",
523        "_hash",
524        "_last_seen",
525        "_local_directory",
526        "_memory_limit",
527        "_memory_other_history",
528        "_memory_unmanaged_old",
529        "_metrics",
530        "_name",
531        "_nanny",
532        "_nbytes",
533        "_nthreads",
534        "_occupancy",
535        "_pid",
536        "_processing",
537        "_resources",
538        "_services",
539        "_status",
540        "_time_delay",
541        "_used_resources",
542        "_versions",
543    )
544
545    def __init__(
546        self,
547        *,
548        address: str,
549        status: Status,
550        pid: Py_ssize_t,
551        name: object,
552        nthreads: Py_ssize_t = 0,
553        memory_limit: Py_ssize_t,
554        local_directory: str,
555        nanny: str,
556        services: "dict | None" = None,
557        versions: "dict | None" = None,
558        extra: "dict | None" = None,
559    ):
560        self._address = address
561        self._pid = pid
562        self._name = name
563        self._nthreads = nthreads
564        self._memory_limit = memory_limit
565        self._local_directory = local_directory
566        self._services = services or {}
567        self._versions = versions or {}
568        self._nanny = nanny
569        self._status = status
570
571        self._hash = hash(address)
572        self._nbytes = 0
573        self._occupancy = 0
574        self._memory_unmanaged_old = 0
575        self._memory_other_history = deque()
576        self._metrics = {}
577        self._last_seen = 0
578        self._time_delay = 0
579        self._bandwidth = float(
580            parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
581        )
582
583        self._actors = set()
584        self._has_what = {}
585        self._processing = {}
586        self._executing = {}
587        self._resources = {}
588        self._used_resources = {}
589
590        self._extra = extra or {}
591
592    def __hash__(self):
593        return self._hash
594
595    def __eq__(self, other):
596        typ_self: type = type(self)
597        typ_other: type = type(other)
598        if typ_self == typ_other:
599            other_ws: WorkerState = other
600            return self._address == other_ws._address
601        else:
602            return False
603
604    @property
605    def actors(self):
606        return self._actors
607
608    @property
609    def address(self) -> str:
610        return self._address
611
612    @property
613    def bandwidth(self):
614        return self._bandwidth
615
616    @property
617    def executing(self):
618        return self._executing
619
620    @property
621    def extra(self):
622        return self._extra
623
624    @property
625    def has_what(self) -> "Set[TaskState]":
626        return self._has_what.keys()
627
628    @property
629    def host(self):
630        return get_address_host(self._address)
631
632    @property
633    def last_seen(self):
634        return self._last_seen
635
636    @property
637    def local_directory(self):
638        return self._local_directory
639
640    @property
641    def memory_limit(self):
642        return self._memory_limit
643
644    @property
645    def metrics(self):
646        return self._metrics
647
648    @property
649    def memory(self) -> MemoryState:
650        return MemoryState(
651            # metrics["memory"] is None if the worker sent a heartbeat before its
652            # SystemMonitor ever had a chance to run
653            process=self._metrics["memory"] or 0,
654            managed=self._nbytes,
655            managed_spilled=self._metrics["spilled_nbytes"],
656            unmanaged_old=self._memory_unmanaged_old,
657        )
658
659    @property
660    def name(self):
661        return self._name
662
663    @property
664    def nanny(self):
665        return self._nanny
666
667    @property
668    def nbytes(self):
669        return self._nbytes
670
671    @nbytes.setter
672    def nbytes(self, v: Py_ssize_t):
673        self._nbytes = v
674
675    @property
676    def nthreads(self):
677        return self._nthreads
678
679    @property
680    def occupancy(self):
681        return self._occupancy
682
683    @occupancy.setter
684    def occupancy(self, v: double):
685        self._occupancy = v
686
687    @property
688    def pid(self):
689        return self._pid
690
691    @property
692    def processing(self):
693        return self._processing
694
695    @property
696    def resources(self):
697        return self._resources
698
699    @property
700    def services(self):
701        return self._services
702
703    @property
704    def status(self):
705        return self._status
706
707    @status.setter
708    def status(self, new_status):
709        if not isinstance(new_status, Status):
710            raise TypeError(f"Expected Status; got {new_status!r}")
711        self._status = new_status
712
713    @property
714    def time_delay(self):
715        return self._time_delay
716
717    @property
718    def used_resources(self):
719        return self._used_resources
720
721    @property
722    def versions(self):
723        return self._versions
724
725    @ccall
726    def clean(self):
727        """Return a version of this object that is appropriate for serialization"""
728        ws: WorkerState = WorkerState(
729            address=self._address,
730            status=self._status,
731            pid=self._pid,
732            name=self._name,
733            nthreads=self._nthreads,
734            memory_limit=self._memory_limit,
735            local_directory=self._local_directory,
736            services=self._services,
737            nanny=self._nanny,
738            extra=self._extra,
739        )
740        ts: TaskState
741        ws._processing = {ts._key: cost for ts, cost in self._processing.items()}
742        ws._executing = {ts._key: duration for ts, duration in self._executing.items()}
743        return ws
744
745    def __repr__(self):
746        return "<WorkerState %r, name: %s, status: %s, memory: %d, processing: %d>" % (
747            self._address,
748            self._name,
749            self._status.name,
750            len(self._has_what),
751            len(self._processing),
752        )
753
754    def _repr_html_(self):
755        return get_template("worker_state.html.j2").render(
756            address=self.address,
757            name=self.name,
758            status=self.status.name,
759            has_what=self._has_what,
760            processing=self.processing,
761        )
762
763    @ccall
764    @exceptval(check=False)
765    def identity(self) -> dict:
766        return {
767            "type": "Worker",
768            "id": self._name,
769            "host": self.host,
770            "resources": self._resources,
771            "local_directory": self._local_directory,
772            "name": self._name,
773            "nthreads": self._nthreads,
774            "memory_limit": self._memory_limit,
775            "last_seen": self._last_seen,
776            "services": self._services,
777            "metrics": self._metrics,
778            "nanny": self._nanny,
779            **self._extra,
780        }
781
782    @property
783    def ncores(self):
784        warnings.warn("WorkerState.ncores has moved to WorkerState.nthreads")
785        return self._nthreads
786
787
788@final
789@cclass
790class Computation:
791    """
792    Collection tracking a single compute or persist call
793
794    See also
795    --------
796    TaskPrefix
797    TaskGroup
798    TaskState
799    """
800
801    _start: double
802    _groups: set
803    _code: object
804    _id: object
805
806    def __init__(self):
807        self._start = time()
808        self._groups = set()
809        self._code = SortedSet()
810        self._id = uuid.uuid4()
811
812    @property
813    def code(self):
814        return self._code
815
816    @property
817    def start(self):
818        return self._start
819
820    @property
821    def stop(self):
822        if self.groups:
823            return max(tg.stop for tg in self.groups)
824        else:
825            return -1
826
827    @property
828    def states(self):
829        tg: TaskGroup
830        return merge_with(sum, [tg._states for tg in self._groups])
831
832    @property
833    def groups(self):
834        return self._groups
835
836    def __repr__(self):
837        return (
838            f"<Computation {self._id}: "
839            + "Tasks: "
840            + ", ".join(
841                "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v
842            )
843            + ">"
844        )
845
846    def _repr_html_(self):
847        return get_template("computation.html.j2").render(
848            id=self._id,
849            start=self.start,
850            stop=self.stop,
851            groups=self.groups,
852            states=self.states,
853            code=self.code,
854        )
855
856
857@final
858@cclass
859class TaskPrefix:
860    """Collection tracking all tasks within a group
861
862    Keys often have a structure like ``("x-123", 0)``
863    A group takes the first section, like ``"x"``
864
865    .. attribute:: name: str
866
867       The name of a group of tasks.
868       For a task like ``("x-123", 0)`` this is the text ``"x"``
869
870    .. attribute:: states: Dict[str, int]
871
872       The number of tasks in each state,
873       like ``{"memory": 10, "processing": 3, "released": 4, ...}``
874
875    .. attribute:: duration_average: float
876
877       An exponentially weighted moving average duration of all tasks with this prefix
878
879    .. attribute:: suspicious: int
880
881       Numbers of times a task was marked as suspicious with this prefix
882
883
884    See Also
885    --------
886    TaskGroup
887    """
888
889    _name: str
890    _all_durations: "defaultdict[str, float]"
891    _duration_average: double
892    _suspicious: Py_ssize_t
893    _groups: list
894
895    def __init__(self, name: str):
896        self._name = name
897        self._groups = []
898
899        # store timings for each prefix-action
900        self._all_durations = defaultdict(float)
901
902        task_durations = dask.config.get("distributed.scheduler.default-task-durations")
903        if self._name in task_durations:
904            self._duration_average = parse_timedelta(task_durations[self._name])
905        else:
906            self._duration_average = -1
907        self._suspicious = 0
908
909    @property
910    def name(self) -> str:
911        return self._name
912
913    @property
914    def all_durations(self) -> "defaultdict[str, float]":
915        return self._all_durations
916
917    @ccall
918    @exceptval(check=False)
919    def add_duration(self, action: str, start: double, stop: double):
920        duration = stop - start
921        self._all_durations[action] += duration
922        if action == "compute":
923            old = self._duration_average
924            if old < 0:
925                self._duration_average = duration
926            else:
927                self._duration_average = 0.5 * duration + 0.5 * old
928
929    @property
930    def duration_average(self) -> double:
931        return self._duration_average
932
933    @property
934    def suspicious(self) -> Py_ssize_t:
935        return self._suspicious
936
937    @property
938    def groups(self):
939        return self._groups
940
941    @property
942    def states(self):
943        tg: TaskGroup
944        return merge_with(sum, [tg._states for tg in self._groups])
945
946    @property
947    def active(self) -> "list[TaskGroup]":
948        tg: TaskGroup
949        return [
950            tg
951            for tg in self._groups
952            if any([v != 0 for k, v in tg._states.items() if k != "forgotten"])
953        ]
954
955    @property
956    def active_states(self):
957        tg: TaskGroup
958        return merge_with(sum, [tg._states for tg in self.active])
959
960    def __repr__(self):
961        return (
962            "<"
963            + self._name
964            + ": "
965            + ", ".join(
966                "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v
967            )
968            + ">"
969        )
970
971    @property
972    def nbytes_total(self):
973        tg: TaskGroup
974        return sum([tg._nbytes_total for tg in self._groups])
975
976    def __len__(self):
977        return sum(map(len, self._groups))
978
979    @property
980    def duration(self):
981        tg: TaskGroup
982        return sum([tg._duration for tg in self._groups])
983
984    @property
985    def types(self):
986        tg: TaskGroup
987        return set().union(*[tg._types for tg in self._groups])
988
989
990@final
991@cclass
992class TaskGroup:
993    """Collection tracking all tasks within a group
994
995    Keys often have a structure like ``("x-123", 0)``
996    A group takes the first section, like ``"x-123"``
997
998    .. attribute:: name: str
999
1000       The name of a group of tasks.
1001       For a task like ``("x-123", 0)`` this is the text ``"x-123"``
1002
1003    .. attribute:: states: Dict[str, int]
1004
1005       The number of tasks in each state,
1006       like ``{"memory": 10, "processing": 3, "released": 4, ...}``
1007
1008    .. attribute:: dependencies: Set[TaskGroup]
1009
1010       The other TaskGroups on which this one depends
1011
1012    .. attribute:: nbytes_total: int
1013
1014       The total number of bytes that this task group has produced
1015
1016    .. attribute:: duration: float
1017
1018       The total amount of time spent on all tasks in this TaskGroup
1019
1020    .. attribute:: types: Set[str]
1021
1022       The result types of this TaskGroup
1023
1024    .. attribute:: last_worker: WorkerState
1025
1026       The worker most recently assigned a task from this group, or None when the group
1027       is not identified to be root-like by `SchedulerState.decide_worker`.
1028
1029    .. attribute:: last_worker_tasks_left: int
1030
1031       If `last_worker` is not None, the number of times that worker should be assigned
1032       subsequent tasks until a new worker is chosen.
1033
1034    See also
1035    --------
1036    TaskPrefix
1037    """
1038
1039    _name: str
1040    _prefix: TaskPrefix  # TaskPrefix | None
1041    _states: dict
1042    _dependencies: set
1043    _nbytes_total: Py_ssize_t
1044    _duration: double
1045    _types: set
1046    _start: double
1047    _stop: double
1048    _all_durations: "defaultdict[str, float]"
1049    _last_worker: WorkerState  # WorkerState | None
1050    _last_worker_tasks_left: Py_ssize_t
1051
1052    def __init__(self, name: str):
1053        self._name = name
1054        self._prefix = None  # type: ignore
1055        self._states = {state: 0 for state in ALL_TASK_STATES}
1056        self._states["forgotten"] = 0
1057        self._dependencies = set()
1058        self._nbytes_total = 0
1059        self._duration = 0
1060        self._types = set()
1061        self._start = 0.0
1062        self._stop = 0.0
1063        self._all_durations = defaultdict(float)
1064        self._last_worker = None  # type: ignore
1065        self._last_worker_tasks_left = 0
1066
1067    @property
1068    def name(self) -> str:
1069        return self._name
1070
1071    @property
1072    def prefix(self) -> "TaskPrefix | None":
1073        return self._prefix
1074
1075    @property
1076    def states(self) -> dict:
1077        return self._states
1078
1079    @property
1080    def dependencies(self) -> set:
1081        return self._dependencies
1082
1083    @property
1084    def nbytes_total(self):
1085        return self._nbytes_total
1086
1087    @property
1088    def duration(self) -> double:
1089        return self._duration
1090
1091    @ccall
1092    @exceptval(check=False)
1093    def add_duration(self, action: str, start: double, stop: double):
1094        duration = stop - start
1095        self._all_durations[action] += duration
1096        if action == "compute":
1097            if self._stop < stop:
1098                self._stop = stop
1099            self._start = self._start or start
1100        self._duration += duration
1101        self._prefix.add_duration(action, start, stop)
1102
1103    @property
1104    def types(self) -> set:
1105        return self._types
1106
1107    @property
1108    def all_durations(self) -> "defaultdict[str, float]":
1109        return self._all_durations
1110
1111    @property
1112    def start(self) -> double:
1113        return self._start
1114
1115    @property
1116    def stop(self) -> double:
1117        return self._stop
1118
1119    @property
1120    def last_worker(self) -> "WorkerState | None":
1121        return self._last_worker
1122
1123    @property
1124    def last_worker_tasks_left(self) -> int:
1125        return self._last_worker_tasks_left
1126
1127    @ccall
1128    def add(self, other: "TaskState"):
1129        self._states[other._state] += 1
1130        other._group = self
1131
1132    def __repr__(self):
1133        return (
1134            "<"
1135            + (self._name or "no-group")
1136            + ": "
1137            + ", ".join(
1138                "%s: %d" % (k, v) for (k, v) in sorted(self._states.items()) if v
1139            )
1140            + ">"
1141        )
1142
1143    def __len__(self):
1144        return sum(self._states.values())
1145
1146
1147@final
1148@cclass
1149class TaskState:
1150    """
1151    A simple object holding information about a task.
1152
1153    .. attribute:: key: str
1154
1155       The key is the unique identifier of a task, generally formed
1156       from the name of the function, followed by a hash of the function
1157       and arguments, like ``'inc-ab31c010444977004d656610d2d421ec'``.
1158
1159    .. attribute:: prefix: TaskPrefix
1160
1161       The broad class of tasks to which this task belongs like "inc" or
1162       "read_csv"
1163
1164    .. attribute:: run_spec: object
1165
1166       A specification of how to run the task.  The type and meaning of this
1167       value is opaque to the scheduler, as it is only interpreted by the
1168       worker to which the task is sent for executing.
1169
1170       As a special case, this attribute may also be ``None``, in which case
1171       the task is "pure data" (such as, for example, a piece of data loaded
1172       in the scheduler using :meth:`Client.scatter`).  A "pure data" task
1173       cannot be computed again if its value is lost.
1174
1175    .. attribute:: priority: tuple
1176
1177       The priority provides each task with a relative ranking which is used
1178       to break ties when many tasks are being considered for execution.
1179
1180       This ranking is generally a 2-item tuple.  The first (and dominant)
1181       item corresponds to when it was submitted.  Generally, earlier tasks
1182       take precedence.  The second item is determined by the client, and is
1183       a way to prioritize tasks within a large graph that may be important,
1184       such as if they are on the critical path, or good to run in order to
1185       release many dependencies.  This is explained further in
1186       :doc:`Scheduling Policy <scheduling-policies>`.
1187
1188    .. attribute:: state: str
1189
1190       This task's current state.  Valid states include ``released``,
1191       ``waiting``, ``no-worker``, ``processing``, ``memory``, ``erred``
1192       and ``forgotten``.  If it is ``forgotten``, the task isn't stored
1193       in the ``tasks`` dictionary anymore and will probably disappear
1194       soon from memory.
1195
1196    .. attribute:: dependencies: {TaskState}
1197
1198       The set of tasks this task depends on for proper execution.  Only
1199       tasks still alive are listed in this set.  If, for whatever reason,
1200       this task also depends on a forgotten task, the
1201       :attr:`has_lost_dependencies` flag is set.
1202
1203       A task can only be executed once all its dependencies have already
1204       been successfully executed and have their result stored on at least
1205       one worker.  This is tracked by progressively draining the
1206       :attr:`waiting_on` set.
1207
1208    .. attribute:: dependents: {TaskState}
1209
1210       The set of tasks which depend on this task.  Only tasks still alive
1211       are listed in this set.
1212
1213       This is the reverse mapping of :attr:`dependencies`.
1214
1215    .. attribute:: has_lost_dependencies: bool
1216
1217       Whether any of the dependencies of this task has been forgotten.
1218       For memory consumption reasons, forgotten tasks are not kept in
1219       memory even though they may have dependent tasks.  When a task is
1220       forgotten, therefore, each of its dependents has their
1221       :attr:`has_lost_dependencies` attribute set to ``True``.
1222
1223       If :attr:`has_lost_dependencies` is true, this task cannot go
1224       into the "processing" state anymore.
1225
1226    .. attribute:: waiting_on: {TaskState}
1227
1228       The set of tasks this task is waiting on *before* it can be executed.
1229       This is always a subset of :attr:`dependencies`.  Each time one of the
1230       dependencies has finished processing, it is removed from the
1231       :attr:`waiting_on` set.
1232
1233       Once :attr:`waiting_on` becomes empty, this task can move from the
1234       "waiting" state to the "processing" state (unless one of the
1235       dependencies errored out, in which case this task is instead
1236       marked "erred").
1237
1238    .. attribute:: waiters: {TaskState}
1239
1240       The set of tasks which need this task to remain alive.  This is always
1241       a subset of :attr:`dependents`.  Each time one of the dependents
1242       has finished processing, it is removed from the :attr:`waiters`
1243       set.
1244
1245       Once both :attr:`waiters` and :attr:`who_wants` become empty, this
1246       task can be released (if it has a non-empty :attr:`run_spec`) or
1247       forgotten (otherwise) by the scheduler, and by any workers
1248       in :attr:`who_has`.
1249
1250       .. note:: Counter-intuitively, :attr:`waiting_on` and
1251          :attr:`waiters` are not reverse mappings of each other.
1252
1253    .. attribute:: who_wants: {ClientState}
1254
1255       The set of clients who want this task's result to remain alive.
1256       This is the reverse mapping of :attr:`ClientState.wants_what`.
1257
1258       When a client submits a graph to the scheduler it also specifies
1259       which output tasks it desires, such that their results are not released
1260       from memory.
1261
1262       Once a task has finished executing (i.e. moves into the "memory"
1263       or "erred" state), the clients in :attr:`who_wants` are notified.
1264
1265       Once both :attr:`waiters` and :attr:`who_wants` become empty, this
1266       task can be released (if it has a non-empty :attr:`run_spec`) or
1267       forgotten (otherwise) by the scheduler, and by any workers
1268       in :attr:`who_has`.
1269
1270    .. attribute:: who_has: {WorkerState}
1271
1272       The set of workers who have this task's result in memory.
1273       It is non-empty iff the task is in the "memory" state.  There can be
1274       more than one worker in this set if, for example, :meth:`Client.scatter`
1275       or :meth:`Client.replicate` was used.
1276
1277       This is the reverse mapping of :attr:`WorkerState.has_what`.
1278
1279    .. attribute:: processing_on: WorkerState (or None)
1280
1281       If this task is in the "processing" state, which worker is currently
1282       processing it.  Otherwise this is ``None``.
1283
1284       This attribute is kept in sync with :attr:`WorkerState.processing`.
1285
1286    .. attribute:: retries: int
1287
1288       The number of times this task can automatically be retried in case
1289       of failure.  If a task fails executing (the worker returns with
1290       an error), its :attr:`retries` attribute is checked.  If it is
1291       equal to 0, the task is marked "erred".  If it is greater than 0,
1292       the :attr:`retries` attribute is decremented and execution is
1293       attempted again.
1294
1295    .. attribute:: nbytes: int (or None)
1296
1297       The number of bytes, as determined by ``sizeof``, of the result
1298       of a finished task.  This number is used for diagnostics and to
1299       help prioritize work.
1300
1301    .. attribute:: type: str
1302
1303       The type of the object as a string.  Only present for tasks that have
1304       been computed.
1305
1306    .. attribute:: exception: object
1307
1308       If this task failed executing, the exception object is stored here.
1309       Otherwise this is ``None``.
1310
1311    .. attribute:: traceback: object
1312
1313       If this task failed executing, the traceback object is stored here.
1314       Otherwise this is ``None``.
1315
1316    .. attribute:: exception_blame: TaskState (or None)
1317
1318       If this task or one of its dependencies failed executing, the
1319       failed task is stored here (possibly itself).  Otherwise this
1320       is ``None``.
1321
1322    .. attribute:: erred_on: set(str)
1323
1324        Worker addresses on which errors appeared causing this task to be in an error state.
1325
1326    .. attribute:: suspicious: int
1327
1328       The number of times this task has been involved in a worker death.
1329
1330       Some tasks may cause workers to die (such as calling ``os._exit(0)``).
1331       When a worker dies, all of the tasks on that worker are reassigned
1332       to others.  This combination of behaviors can cause a bad task to
1333       catastrophically destroy all workers on the cluster, one after
1334       another.  Whenever a worker dies, we mark each task currently
1335       processing on that worker (as recorded by
1336       :attr:`WorkerState.processing`) as suspicious.
1337
1338       If a task is involved in three deaths (or some other fixed constant)
1339       then we mark the task as ``erred``.
1340
1341    .. attribute:: host_restrictions: {hostnames}
1342
1343       A set of hostnames where this task can be run (or ``None`` if empty).
1344       Usually this is empty unless the task has been specifically restricted
1345       to only run on certain hosts.  A hostname may correspond to one or
1346       several connected workers.
1347
1348    .. attribute:: worker_restrictions: {worker addresses}
1349
1350       A set of complete worker addresses where this can be run (or ``None``
1351       if empty).  Usually this is empty unless the task has been specifically
1352       restricted to only run on certain workers.
1353
1354       Note this is tracking worker addresses, not worker states, since
1355       the specific workers may not be connected at this time.
1356
1357    .. attribute:: resource_restrictions: {resource: quantity}
1358
1359       Resources required by this task, such as ``{'gpu': 1}`` or
1360       ``{'memory': 1e9}`` (or ``None`` if empty).  These are user-defined
1361       names and are matched against the contents of each
1362       :attr:`WorkerState.resources` dictionary.
1363
1364    .. attribute:: loose_restrictions: bool
1365
1366       If ``False``, each of :attr:`host_restrictions`,
1367       :attr:`worker_restrictions` and :attr:`resource_restrictions` is
1368       a hard constraint: if no worker is available satisfying those
1369       restrictions, the task cannot go into the "processing" state and
1370       will instead go into the "no-worker" state.
1371
1372       If ``True``, the above restrictions are mere preferences: if no worker
1373       is available satisfying those restrictions, the task can still go
1374       into the "processing" state and be sent for execution to another
1375       connected worker.
1376
1377    .. attribute:: metadata: dict
1378
1379       Metadata related to task.
1380
1381    .. attribute:: actor: bool
1382
1383       Whether or not this task is an Actor.
1384
1385    .. attribute:: group: TaskGroup
1386
1387        The group of tasks to which this one belongs.
1388
1389    .. attribute:: annotations: dict
1390
1391        Task annotations
1392    """
1393
1394    _key: str
1395    _hash: Py_hash_t
1396    _prefix: TaskPrefix
1397    _run_spec: object
1398    _priority: tuple  # tuple | None
1399    _state: str  # str | None
1400    _dependencies: set  # set[TaskState]
1401    _dependents: set  # set[TaskState]
1402    _has_lost_dependencies: bint
1403    _waiting_on: set  # set[TaskState]
1404    _waiters: set  # set[TaskState]
1405    _who_wants: set  # set[ClientState]
1406    _who_has: set  # set[WorkerState]
1407    _processing_on: WorkerState  # WorkerState | None
1408    _retries: Py_ssize_t
1409    _nbytes: Py_ssize_t
1410    _type: str  # str | None
1411    _exception: object
1412    _exception_text: str
1413    _traceback: object
1414    _traceback_text: str
1415    _exception_blame: "TaskState"  # TaskState | None"
1416    _erred_on: set
1417    _suspicious: Py_ssize_t
1418    _host_restrictions: set  # set[str] | None
1419    _worker_restrictions: set  # set[str] | None
1420    _resource_restrictions: dict  # dict | None
1421    _loose_restrictions: bint
1422    _metadata: dict
1423    _annotations: dict
1424    _actor: bint
1425    _group: TaskGroup  # TaskGroup | None
1426    _group_key: str
1427
1428    __slots__ = (
1429        # === General description ===
1430        "_actor",
1431        # Key name
1432        "_key",
1433        # Hash of the key name
1434        "_hash",
1435        # Key prefix (see key_split())
1436        "_prefix",
1437        # How to run the task (None if pure data)
1438        "_run_spec",
1439        # Alive dependents and dependencies
1440        "_dependencies",
1441        "_dependents",
1442        # Compute priority
1443        "_priority",
1444        # Restrictions
1445        "_host_restrictions",
1446        "_worker_restrictions",  # not WorkerStates but addresses
1447        "_resource_restrictions",
1448        "_loose_restrictions",
1449        # === Task state ===
1450        "_state",
1451        # Whether some dependencies were forgotten
1452        "_has_lost_dependencies",
1453        # If in 'waiting' state, which tasks need to complete
1454        # before we can run
1455        "_waiting_on",
1456        # If in 'waiting' or 'processing' state, which tasks needs us
1457        # to complete before they can run
1458        "_waiters",
1459        # In in 'processing' state, which worker we are processing on
1460        "_processing_on",
1461        # If in 'memory' state, Which workers have us
1462        "_who_has",
1463        # Which clients want us
1464        "_who_wants",
1465        "_exception",
1466        "_exception_text",
1467        "_traceback",
1468        "_traceback_text",
1469        "_erred_on",
1470        "_exception_blame",
1471        "_suspicious",
1472        "_retries",
1473        "_nbytes",
1474        "_type",
1475        "_group_key",
1476        "_group",
1477        "_metadata",
1478        "_annotations",
1479    )
1480
1481    def __init__(self, key: str, run_spec: object):
1482        self._key = key
1483        self._hash = hash(key)
1484        self._run_spec = run_spec
1485        self._state = None  # type: ignore
1486        self._exception = None
1487        self._exception_blame = None  # type: ignore
1488        self._traceback = None
1489        self._exception_text = ""
1490        self._traceback_text = ""
1491        self._suspicious = 0
1492        self._retries = 0
1493        self._nbytes = -1
1494        self._priority = None  # type: ignore
1495        self._who_wants = set()
1496        self._dependencies = set()
1497        self._dependents = set()
1498        self._waiting_on = set()
1499        self._waiters = set()
1500        self._who_has = set()
1501        self._processing_on = None  # type: ignore
1502        self._has_lost_dependencies = False
1503        self._host_restrictions = None  # type: ignore
1504        self._worker_restrictions = None  # type: ignore
1505        self._resource_restrictions = None  # type: ignore
1506        self._loose_restrictions = False
1507        self._actor = False
1508        self._type = None  # type: ignore
1509        self._group_key = key_split_group(key)
1510        self._group = None  # type: ignore
1511        self._metadata = {}
1512        self._annotations = {}
1513        self._erred_on = set()
1514
1515    def __hash__(self):
1516        return self._hash
1517
1518    def __eq__(self, other):
1519        typ_self: type = type(self)
1520        typ_other: type = type(other)
1521        if typ_self == typ_other:
1522            other_ts: TaskState = other
1523            return self._key == other_ts._key
1524        else:
1525            return False
1526
1527    @property
1528    def key(self):
1529        return self._key
1530
1531    @property
1532    def prefix(self):
1533        return self._prefix
1534
1535    @property
1536    def run_spec(self):
1537        return self._run_spec
1538
1539    @property
1540    def priority(self) -> "tuple | None":
1541        return self._priority
1542
1543    @property
1544    def state(self) -> "str | None":
1545        return self._state
1546
1547    @state.setter
1548    def state(self, value: str):
1549        self._group._states[self._state] -= 1
1550        self._group._states[value] += 1
1551        self._state = value
1552
1553    @property
1554    def dependencies(self) -> "set[TaskState]":
1555        return self._dependencies
1556
1557    @property
1558    def dependents(self) -> "set[TaskState]":
1559        return self._dependents
1560
1561    @property
1562    def has_lost_dependencies(self):
1563        return self._has_lost_dependencies
1564
1565    @property
1566    def waiting_on(self) -> "set[TaskState]":
1567        return self._waiting_on
1568
1569    @property
1570    def waiters(self) -> "set[TaskState]":
1571        return self._waiters
1572
1573    @property
1574    def who_wants(self) -> "set[ClientState]":
1575        return self._who_wants
1576
1577    @property
1578    def who_has(self) -> "set[WorkerState]":
1579        return self._who_has
1580
1581    @property
1582    def processing_on(self) -> "WorkerState | None":
1583        return self._processing_on
1584
1585    @processing_on.setter
1586    def processing_on(self, v: WorkerState) -> None:
1587        self._processing_on = v
1588
1589    @property
1590    def retries(self):
1591        return self._retries
1592
1593    @property
1594    def nbytes(self):
1595        return self._nbytes
1596
1597    @nbytes.setter
1598    def nbytes(self, v: Py_ssize_t):
1599        self._nbytes = v
1600
1601    @property
1602    def type(self) -> "str | None":
1603        return self._type
1604
1605    @property
1606    def exception(self):
1607        return self._exception
1608
1609    @property
1610    def exception_text(self):
1611        return self._exception_text
1612
1613    @property
1614    def traceback(self):
1615        return self._traceback
1616
1617    @property
1618    def traceback_text(self):
1619        return self._traceback_text
1620
1621    @property
1622    def exception_blame(self) -> "TaskState | None":
1623        return self._exception_blame
1624
1625    @property
1626    def suspicious(self):
1627        return self._suspicious
1628
1629    @property
1630    def host_restrictions(self) -> "set[str] | None":
1631        return self._host_restrictions
1632
1633    @property
1634    def worker_restrictions(self) -> "set[str] | None":
1635        return self._worker_restrictions
1636
1637    @property
1638    def resource_restrictions(self) -> "dict | None":
1639        return self._resource_restrictions
1640
1641    @property
1642    def loose_restrictions(self):
1643        return self._loose_restrictions
1644
1645    @property
1646    def metadata(self):
1647        return self._metadata
1648
1649    @property
1650    def annotations(self):
1651        return self._annotations
1652
1653    @property
1654    def actor(self):
1655        return self._actor
1656
1657    @property
1658    def group(self) -> "TaskGroup | None":
1659        return self._group
1660
1661    @property
1662    def group_key(self) -> str:
1663        return self._group_key
1664
1665    @property
1666    def prefix_key(self):
1667        return self._prefix._name
1668
1669    @property
1670    def erred_on(self):
1671        return self._erred_on
1672
1673    @ccall
1674    def add_dependency(self, other: "TaskState"):
1675        """Add another task as a dependency of this task"""
1676        self._dependencies.add(other)
1677        self._group._dependencies.add(other._group)
1678        other._dependents.add(self)
1679
1680    @ccall
1681    @inline
1682    @nogil
1683    def get_nbytes(self) -> Py_ssize_t:
1684        return self._nbytes if self._nbytes >= 0 else DEFAULT_DATA_SIZE
1685
1686    @ccall
1687    def set_nbytes(self, nbytes: Py_ssize_t):
1688        diff: Py_ssize_t = nbytes
1689        old_nbytes: Py_ssize_t = self._nbytes
1690        if old_nbytes >= 0:
1691            diff -= old_nbytes
1692        self._group._nbytes_total += diff
1693        ws: WorkerState
1694        for ws in self._who_has:
1695            ws._nbytes += diff
1696        self._nbytes = nbytes
1697
1698    def __repr__(self):
1699        return f"<TaskState {self._key!r} {self._state}>"
1700
1701    def _repr_html_(self):
1702        return get_template("task_state.html.j2").render(
1703            state=self._state,
1704            nbytes=self._nbytes,
1705            key=self._key,
1706        )
1707
1708    @ccall
1709    def validate(self):
1710        try:
1711            for cs in self._who_wants:
1712                assert isinstance(cs, ClientState), (repr(cs), self._who_wants)
1713            for ws in self._who_has:
1714                assert isinstance(ws, WorkerState), (repr(ws), self._who_has)
1715            for ts in self._dependencies:
1716                assert isinstance(ts, TaskState), (repr(ts), self._dependencies)
1717            for ts in self._dependents:
1718                assert isinstance(ts, TaskState), (repr(ts), self._dependents)
1719            validate_task_state(self)
1720        except Exception as e:
1721            logger.exception(e)
1722            if LOG_PDB:
1723                import pdb
1724
1725                pdb.set_trace()
1726
1727    def get_nbytes_deps(self):
1728        nbytes: Py_ssize_t = 0
1729        ts: TaskState
1730        for ts in self._dependencies:
1731            nbytes += ts.get_nbytes()
1732        return nbytes
1733
1734    @ccall
1735    def _to_dict(self, *, exclude: Container[str] = None):
1736        """
1737        A very verbose dictionary representation for debugging purposes.
1738        Not type stable and not inteded for roundtrips.
1739
1740        Parameters
1741        ----------
1742        exclude:
1743            A list of attributes which must not be present in the output.
1744
1745        See also
1746        --------
1747        Client.dump_cluster_state
1748        """
1749
1750        if not exclude:
1751            exclude = set()
1752        members = inspect.getmembers(self)
1753        return recursive_to_dict(
1754            {k: v for k, v in members if k not in exclude and not callable(v)},
1755            exclude=exclude,
1756        )
1757
1758
1759class _StateLegacyMapping(Mapping):
1760    """
1761    A mapping interface mimicking the former Scheduler state dictionaries.
1762    """
1763
1764    def __init__(self, states, accessor):
1765        self._states = states
1766        self._accessor = accessor
1767
1768    def __iter__(self):
1769        return iter(self._states)
1770
1771    def __len__(self):
1772        return len(self._states)
1773
1774    def __getitem__(self, key):
1775        return self._accessor(self._states[key])
1776
1777    def __repr__(self):
1778        return f"{self.__class__}({dict(self)})"
1779
1780
1781class _OptionalStateLegacyMapping(_StateLegacyMapping):
1782    """
1783    Similar to _StateLegacyMapping, but a false-y value is interpreted
1784    as a missing key.
1785    """
1786
1787    # For tasks etc.
1788
1789    def __iter__(self):
1790        accessor = self._accessor
1791        for k, v in self._states.items():
1792            if accessor(v):
1793                yield k
1794
1795    def __len__(self):
1796        accessor = self._accessor
1797        return sum(bool(accessor(v)) for v in self._states.values())
1798
1799    def __getitem__(self, key):
1800        v = self._accessor(self._states[key])
1801        if v:
1802            return v
1803        else:
1804            raise KeyError
1805
1806
1807class _StateLegacySet(Set):
1808    """
    Similar to _StateLegacyMapping, but exposes a set containing
    all keys whose accessor value is truthy.
1811    """
1812
1813    # For loose_restrictions
1814
1815    def __init__(self, states, accessor):
1816        self._states = states
1817        self._accessor = accessor
1818
1819    def __iter__(self):
1820        return (k for k, v in self._states.items() if self._accessor(v))
1821
1822    def __len__(self):
1823        return sum(map(bool, map(self._accessor, self._states.values())))
1824
1825    def __contains__(self, k):
1826        st = self._states.get(k)
1827        return st is not None and bool(self._accessor(st))
1828
1829    def __repr__(self):
1830        return f"{self.__class__}({set(self)})"
1831
1832
1833def _legacy_task_key_set(tasks):
1834    """
1835    Transform a set of task states into a set of task keys.
1836    """
1837    ts: TaskState
1838    return {ts._key for ts in tasks}
1839
1840
1841def _legacy_client_key_set(clients):
1842    """
1843    Transform a set of client states into a set of client keys.
1844    """
1845    cs: ClientState
1846    return {cs._client_key for cs in clients}
1847
1848
1849def _legacy_worker_key_set(workers):
1850    """
1851    Transform a set of worker states into a set of worker keys.
1852    """
1853    ws: WorkerState
1854    return {ws._address for ws in workers}
1855
1856
1857def _legacy_task_key_dict(task_dict: dict):
1858    """
1859    Transform a dict of {task state: value} into a dict of {task key: value}.
1860    """
1861    ts: TaskState
1862    return {ts._key: value for ts, value in task_dict.items()}
1863
1864
1865def _task_key_or_none(task: TaskState):
1866    return task._key if task is not None else None
1867
1868
1869@cclass
1870class SchedulerState:
1871    """Underlying task state of dynamic scheduler
1872
1873    Tracks the current state of workers, data, and computations.
1874
1875    Handles transitions between different task states. Notifies the
    Scheduler of changes by message passing through Queues, which the
    Scheduler listens to and responds to accordingly.
1878
1879    All events are handled quickly, in linear time with respect to their
1880    input (which is often of constant size) and generally within a
    millisecond. Additionally, when Cythonized, this can be faster still.
1882    To accomplish this the scheduler tracks a lot of state.  Every
1883    operation maintains the consistency of this state.
1884
    Users typically do not interact with this class directly. Instead
    they interact with the ``Client``, which in turn engages the
    ``Scheduler``, triggering the transitions defined here under the hood.
    In the background ``Worker``s also engage with the ``Scheduler`` and
    affect these state transitions as well.
1890
1891    **State**
1892
    The ``SchedulerState`` object contains the following state variables.
1894    Each variable is listed along with what it stores and a brief
1895    description.
1896
1897    * **tasks:** ``{task key: TaskState}``
1898        Tasks currently known to the scheduler
1899    * **unrunnable:** ``{TaskState}``
1900        Tasks in the "no-worker" state
1901
1902    * **workers:** ``{worker key: WorkerState}``
1903        Workers currently connected to the scheduler
1904    * **idle:** ``{WorkerState}``:
1905        Set of workers that are not fully utilized
1906    * **saturated:** ``{WorkerState}``:
        Set of workers that are over-utilized
1908    * **running:** ``{WorkerState}``:
1909        Set of workers that are currently in running state
1910
1911    * **clients:** ``{client key: ClientState}``
1912        Clients currently connected to the scheduler
1913
1914    * **task_duration:** ``{key-prefix: time}``
1915        Time we expect certain functions to take, e.g. ``{'sum': 0.25}``
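
    **Example**

    A minimal, illustrative sketch of inspecting this state (assumes a
    populated instance ``state``; the task key is hypothetical)::

        ts = state.tasks["some-key"]              # TaskState for that key
        ws = next(iter(state.workers.values()))   # any connected WorkerState
        print(ts.state, len(state.unrunnable), state.total_occupancy)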
1916    """
1917
1918    _aliases: dict
1919    _bandwidth: double
1920    _clients: dict  # dict[str, ClientState]
1921    _computations: object
1922    _extensions: dict
1923    _host_info: dict
1924    _idle: "SortedDict[str, WorkerState]"
1925    _idle_dv: dict  # dict[str, WorkerState]
1926    _n_tasks: Py_ssize_t
1927    _resources: dict
1928    _saturated: set  # set[WorkerState]
1929    _running: set  # set[WorkerState]
1930    _tasks: dict
1931    _task_groups: dict
1932    _task_prefixes: dict
1933    _task_metadata: dict
1934    _replicated_tasks: set
1935    _total_nthreads: Py_ssize_t
1936    _total_occupancy: double
1937    _transitions_table: dict
1938    _unknown_durations: dict
1939    _unrunnable: set
1940    _validate: bint
1941    _workers: "SortedDict[str, WorkerState]"
1942    _workers_dv: dict  # dict[str, WorkerState]
1943    _transition_counter: Py_ssize_t
1944    _plugins: dict  # dict[str, SchedulerPlugin]
1945
1946    # Variables from dask.config, cached by __init__ for performance
1947    UNKNOWN_TASK_DURATION: double
1948    MEMORY_RECENT_TO_OLD_TIME: double
1949    MEMORY_REBALANCE_MEASURE: str
1950    MEMORY_REBALANCE_SENDER_MIN: double
1951    MEMORY_REBALANCE_RECIPIENT_MAX: double
1952    MEMORY_REBALANCE_HALF_GAP: double
1953
1954    def __init__(
1955        self,
1956        aliases: dict,
1957        clients: "dict[str, ClientState]",
1958        workers: "SortedDict[str, WorkerState]",
1959        host_info: dict,
1960        resources: dict,
1961        tasks: dict,
1962        unrunnable: set,
1963        validate: bint,
1964        plugins: "Iterable[SchedulerPlugin]" = (),
1965        **kwargs,  # Passed verbatim to Server.__init__()
1966    ):
1967        self._aliases = aliases
1968        self._bandwidth = parse_bytes(
1969            dask.config.get("distributed.scheduler.bandwidth")
1970        )
1971        self._clients = clients
1972        self._clients["fire-and-forget"] = ClientState("fire-and-forget")
1973        self._extensions = {}
1974        self._host_info = host_info
1975        self._idle = SortedDict()
1976        # Note: cython.cast, not typing.cast!
1977        self._idle_dv = cast(dict, self._idle)
1978        self._n_tasks = 0
1979        self._resources = resources
1980        self._saturated = set()
1981        self._tasks = tasks
1982        self._replicated_tasks = {
1983            ts for ts in self._tasks.values() if len(ts._who_has) > 1
1984        }
1985        self._computations = deque(
1986            maxlen=dask.config.get("distributed.diagnostics.computations.max-history")
1987        )
1988        self._task_groups = {}
1989        self._task_prefixes = {}
1990        self._task_metadata = {}
1991        self._total_nthreads = 0
1992        self._total_occupancy = 0
1993        self._transitions_table = {
1994            ("released", "waiting"): self.transition_released_waiting,
1995            ("waiting", "released"): self.transition_waiting_released,
1996            ("waiting", "processing"): self.transition_waiting_processing,
1997            ("waiting", "memory"): self.transition_waiting_memory,
1998            ("processing", "released"): self.transition_processing_released,
1999            ("processing", "memory"): self.transition_processing_memory,
2000            ("processing", "erred"): self.transition_processing_erred,
2001            ("no-worker", "released"): self.transition_no_worker_released,
2002            ("no-worker", "waiting"): self.transition_no_worker_waiting,
2003            ("no-worker", "memory"): self.transition_no_worker_memory,
2004            ("released", "forgotten"): self.transition_released_forgotten,
2005            ("memory", "forgotten"): self.transition_memory_forgotten,
2006            ("erred", "released"): self.transition_erred_released,
2007            ("memory", "released"): self.transition_memory_released,
2008            ("released", "erred"): self.transition_released_erred,
2009        }
2010        self._unknown_durations = {}
2011        self._unrunnable = unrunnable
2012        self._validate = validate
2013        self._workers = workers
2014        # Note: cython.cast, not typing.cast!
2015        self._workers_dv = cast(dict, self._workers)
2016        self._running = {
2017            ws for ws in self._workers.values() if ws.status == Status.running
2018        }
2019        self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}
2020
2021        # Variables from dask.config, cached by __init__ for performance
2022        self.UNKNOWN_TASK_DURATION = parse_timedelta(
2023            dask.config.get("distributed.scheduler.unknown-task-duration")
2024        )
2025        self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta(
2026            dask.config.get("distributed.worker.memory.recent-to-old-time")
2027        )
2028        self.MEMORY_REBALANCE_MEASURE = dask.config.get(
2029            "distributed.worker.memory.rebalance.measure"
2030        )
2031        self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get(
2032            "distributed.worker.memory.rebalance.sender-min"
2033        )
2034        self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get(
2035            "distributed.worker.memory.rebalance.recipient-max"
2036        )
2037        self.MEMORY_REBALANCE_HALF_GAP = (
2038            dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap")
2039            / 2.0
2040        )
2041        self._transition_counter = 0
2042
2043        # Call Server.__init__()
2044        super().__init__(**kwargs)  # type: ignore
2045
2046    @property
2047    def aliases(self):
2048        return self._aliases
2049
2050    @property
2051    def bandwidth(self):
2052        return self._bandwidth
2053
2054    @property
2055    def clients(self):
2056        return self._clients
2057
2058    @property
2059    def computations(self):
2060        return self._computations
2061
2062    @property
2063    def extensions(self):
2064        return self._extensions
2065
2066    @property
2067    def host_info(self):
2068        return self._host_info
2069
2070    @property
2071    def idle(self):
2072        return self._idle
2073
2074    @property
2075    def n_tasks(self):
2076        return self._n_tasks
2077
2078    @property
2079    def resources(self):
2080        return self._resources
2081
2082    @property
2083    def saturated(self) -> "set[WorkerState]":
2084        return self._saturated
2085
2086    @property
2087    def running(self) -> "set[WorkerState]":
2088        return self._running
2089
2090    @property
2091    def tasks(self):
2092        return self._tasks
2093
2094    @property
2095    def task_groups(self):
2096        return self._task_groups
2097
2098    @property
2099    def task_prefixes(self):
2100        return self._task_prefixes
2101
2102    @property
2103    def task_metadata(self):
2104        return self._task_metadata
2105
2106    @property
2107    def replicated_tasks(self):
2108        return self._replicated_tasks
2109
2110    @property
2111    def total_nthreads(self):
2112        return self._total_nthreads
2113
2114    @property
2115    def total_occupancy(self):
2116        return self._total_occupancy
2117
2118    @total_occupancy.setter
2119    def total_occupancy(self, v: double):
2120        self._total_occupancy = v
2121
2122    @property
2123    def transition_counter(self):
2124        return self._transition_counter
2125
2126    @property
2127    def unknown_durations(self):
2128        return self._unknown_durations
2129
2130    @property
2131    def unrunnable(self):
2132        return self._unrunnable
2133
2134    @property
2135    def validate(self):
2136        return self._validate
2137
2138    @validate.setter
2139    def validate(self, v: bint):
2140        self._validate = v
2141
2142    @property
2143    def workers(self):
2144        return self._workers
2145
2146    @property
2147    def plugins(self) -> "dict[str, SchedulerPlugin]":
2148        return self._plugins
2149
2150    @property
2151    def memory(self) -> MemoryState:
2152        return MemoryState.sum(*(w.memory for w in self.workers.values()))
2153
2154    @property
2155    def __pdict__(self):
2156        return {
2157            "bandwidth": self._bandwidth,
2158            "resources": self._resources,
2159            "saturated": self._saturated,
2160            "unrunnable": self._unrunnable,
2161            "n_tasks": self._n_tasks,
2162            "unknown_durations": self._unknown_durations,
2163            "validate": self._validate,
2164            "tasks": self._tasks,
2165            "task_groups": self._task_groups,
2166            "task_prefixes": self._task_prefixes,
2167            "total_nthreads": self._total_nthreads,
2168            "total_occupancy": self._total_occupancy,
2169            "extensions": self._extensions,
2170            "clients": self._clients,
2171            "workers": self._workers,
2172            "idle": self._idle,
2173            "host_info": self._host_info,
2174        }
2175
2176    @ccall
2177    @exceptval(check=False)
2178    def new_task(
2179        self, key: str, spec: object, state: str, computation: Computation = None
2180    ) -> TaskState:
2181        """Create a new task, and associated states"""
2182        ts: TaskState = TaskState(key, spec)
2183        ts._state = state
2184
2185        tp: TaskPrefix
2186        prefix_key = key_split(key)
2187        tp = self._task_prefixes.get(prefix_key)  # type: ignore
2188        if tp is None:
2189            self._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key)
2190        ts._prefix = tp
2191
2192        group_key = ts._group_key
2193        tg: TaskGroup = self._task_groups.get(group_key)  # type: ignore
2194        if tg is None:
2195            self._task_groups[group_key] = tg = TaskGroup(group_key)
2196            if computation:
2197                computation.groups.add(tg)
2198            tg._prefix = tp
2199            tp._groups.append(tg)
2200        tg.add(ts)
2201
2202        self._tasks[key] = ts
2203
2204        return ts
2205
2206    #####################
2207    # State Transitions #
2208    #####################
2209
2210    def _transition(self, key, finish: str, *args, **kwargs):
2211        """Transition a key from its current state to the finish state
2212
2213        Examples
2214        --------
2215        >>> self._transition('x', 'waiting')
        ({'x': 'processing'}, {}, {})

        Returns
        -------
        Tuple of recommendations, client messages, and worker messages,
        where recommendations is a dictionary of recommendations for
        future transitions
2221
2222        See Also
2223        --------
2224        Scheduler.transitions : transitive version of this function
2225        """
2226        parent: SchedulerState = cast(SchedulerState, self)
2227        ts: TaskState
2228        start: str
2229        start_finish: tuple
2230        finish2: str
2231        recommendations: dict
2232        worker_msgs: dict
2233        client_msgs: dict
2234        msgs: list
2235        new_msgs: list
2236        dependents: set
2237        dependencies: set
2238        try:
2239            recommendations = {}
2240            worker_msgs = {}
2241            client_msgs = {}
2242
2243            ts = parent._tasks.get(key)  # type: ignore
2244            if ts is None:
2245                return recommendations, client_msgs, worker_msgs
2246            start = ts._state
2247            if start == finish:
2248                return recommendations, client_msgs, worker_msgs
2249
2250            if self.plugins:
2251                dependents = set(ts._dependents)
2252                dependencies = set(ts._dependencies)
2253
2254            start_finish = (start, finish)
2255            func = self._transitions_table.get(start_finish)
2256            if func is not None:
2257                recommendations, client_msgs, worker_msgs = func(key, *args, **kwargs)
2258                self._transition_counter += 1
2259            elif "released" not in start_finish:
2260                assert not args and not kwargs, (args, kwargs, start_finish)
2261                a_recs: dict
2262                a_cmsgs: dict
2263                a_wmsgs: dict
2264                a: tuple = self._transition(key, "released")
2265                a_recs, a_cmsgs, a_wmsgs = a
2266
2267                v = a_recs.get(key, finish)
2268                func = self._transitions_table["released", v]
2269                b_recs: dict
2270                b_cmsgs: dict
2271                b_wmsgs: dict
2272                b: tuple = func(key)
2273                b_recs, b_cmsgs, b_wmsgs = b
2274
2275                recommendations.update(a_recs)
2276                for c, new_msgs in a_cmsgs.items():
2277                    msgs = client_msgs.get(c)  # type: ignore
2278                    if msgs is not None:
2279                        msgs.extend(new_msgs)
2280                    else:
2281                        client_msgs[c] = new_msgs
2282                for w, new_msgs in a_wmsgs.items():
2283                    msgs = worker_msgs.get(w)  # type: ignore
2284                    if msgs is not None:
2285                        msgs.extend(new_msgs)
2286                    else:
2287                        worker_msgs[w] = new_msgs
2288
2289                recommendations.update(b_recs)
2290                for c, new_msgs in b_cmsgs.items():
2291                    msgs = client_msgs.get(c)  # type: ignore
2292                    if msgs is not None:
2293                        msgs.extend(new_msgs)
2294                    else:
2295                        client_msgs[c] = new_msgs
2296                for w, new_msgs in b_wmsgs.items():
2297                    msgs = worker_msgs.get(w)  # type: ignore
2298                    if msgs is not None:
2299                        msgs.extend(new_msgs)
2300                    else:
2301                        worker_msgs[w] = new_msgs
2302
2303                start = "released"
2304            else:
2305                raise RuntimeError("Impossible transition from %r to %r" % start_finish)
2306
2307            finish2 = ts._state
2308            # FIXME downcast antipattern
2309            scheduler = pep484_cast(Scheduler, self)
2310            scheduler.transition_log.append(
2311                (key, start, finish2, recommendations, time())
2312            )
2313            if parent._validate:
2314                logger.debug(
2315                    "Transitioned %r %s->%s (actual: %s).  Consequence: %s",
2316                    key,
2317                    start,
2318                    finish2,
2319                    ts._state,
2320                    dict(recommendations),
2321                )
2322            if self.plugins:
2323                # Temporarily put back forgotten key for plugin to retrieve it
2324                if ts._state == "forgotten":
2325                    ts._dependents = dependents
2326                    ts._dependencies = dependencies
2327                    parent._tasks[ts._key] = ts
2328                for plugin in list(self.plugins.values()):
2329                    try:
2330                        plugin.transition(key, start, finish2, *args, **kwargs)
2331                    except Exception:
2332                        logger.info("Plugin failed with exception", exc_info=True)
2333                if ts._state == "forgotten":
2334                    del parent._tasks[ts._key]
2335
2336            tg: TaskGroup = ts._group
2337            if ts._state == "forgotten" and tg._name in parent._task_groups:
2338                # Remove TaskGroup if all tasks are in the forgotten state
2339                all_forgotten: bint = True
2340                for s in ALL_TASK_STATES:
2341                    if tg._states.get(s):
2342                        all_forgotten = False
2343                        break
2344                if all_forgotten:
2345                    ts._prefix._groups.remove(tg)
2346                    del parent._task_groups[tg._name]
2347
2348            return recommendations, client_msgs, worker_msgs
2349        except Exception:
2350            logger.exception("Error transitioning %r from %r to %r", key, start, finish)
2351            if LOG_PDB:
2352                import pdb
2353
2354                pdb.set_trace()
2355            raise
2356
2357    def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict):
2358        """Process transitions until none are left
2359
2360        This includes feedback from previous transitions and continues until we
2361        reach a steady state
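
        A minimal, illustrative call (the task key is hypothetical); the
        message dictionaries are filled in place::

            client_msgs: dict = {}
            worker_msgs: dict = {}
            self._transitions({"x": "forgotten"}, client_msgs, worker_msgs)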
2362        """
2363        keys: set = set()
2364        recommendations = recommendations.copy()
2365        msgs: list
2366        new_msgs: list
2367        new: tuple
2368        new_recs: dict
2369        new_cmsgs: dict
2370        new_wmsgs: dict
2371        while recommendations:
2372            key, finish = recommendations.popitem()
2373            keys.add(key)
2374
2375            new = self._transition(key, finish)
2376            new_recs, new_cmsgs, new_wmsgs = new
2377
2378            recommendations.update(new_recs)
2379            for c, new_msgs in new_cmsgs.items():
2380                msgs = client_msgs.get(c)  # type: ignore
2381                if msgs is not None:
2382                    msgs.extend(new_msgs)
2383                else:
2384                    client_msgs[c] = new_msgs
2385            for w, new_msgs in new_wmsgs.items():
2386                msgs = worker_msgs.get(w)  # type: ignore
2387                if msgs is not None:
2388                    msgs.extend(new_msgs)
2389                else:
2390                    worker_msgs[w] = new_msgs
2391
2392        if self._validate:
2393            # FIXME downcast antipattern
2394            scheduler = pep484_cast(Scheduler, self)
2395            for key in keys:
2396                scheduler.validate_key(key)
2397
2398    def transition_released_waiting(self, key):
2399        try:
2400            ts: TaskState = self._tasks[key]
2401            dts: TaskState
2402            recommendations: dict = {}
2403            client_msgs: dict = {}
2404            worker_msgs: dict = {}
2405
2406            if self._validate:
2407                assert ts._run_spec
2408                assert not ts._waiting_on
2409                assert not ts._who_has
2410                assert not ts._processing_on
2411                assert not any([dts._state == "forgotten" for dts in ts._dependencies])
2412
2413            if ts._has_lost_dependencies:
2414                recommendations[key] = "forgotten"
2415                return recommendations, client_msgs, worker_msgs
2416
2417            ts.state = "waiting"
2418
2419            dts: TaskState
2420            for dts in ts._dependencies:
2421                if dts._exception_blame:
2422                    ts._exception_blame = dts._exception_blame
2423                    recommendations[key] = "erred"
2424                    return recommendations, client_msgs, worker_msgs
2425
2426            for dts in ts._dependencies:
2427                dep = dts._key
2428                if not dts._who_has:
2429                    ts._waiting_on.add(dts)
2430                if dts._state == "released":
2431                    recommendations[dep] = "waiting"
2432                else:
2433                    dts._waiters.add(ts)
2434
2435            ts._waiters = {dts for dts in ts._dependents if dts._state == "waiting"}
2436
2437            if not ts._waiting_on:
2438                if self._workers_dv:
2439                    recommendations[key] = "processing"
2440                else:
2441                    self._unrunnable.add(ts)
2442                    ts.state = "no-worker"
2443
2444            return recommendations, client_msgs, worker_msgs
2445        except Exception as e:
2446            logger.exception(e)
2447            if LOG_PDB:
2448                import pdb
2449
2450                pdb.set_trace()
2451            raise
2452
2453    def transition_no_worker_waiting(self, key):
2454        try:
2455            ts: TaskState = self._tasks[key]
2456            dts: TaskState
2457            recommendations: dict = {}
2458            client_msgs: dict = {}
2459            worker_msgs: dict = {}
2460
2461            if self._validate:
2462                assert ts in self._unrunnable
2463                assert not ts._waiting_on
2464                assert not ts._who_has
2465                assert not ts._processing_on
2466
2467            self._unrunnable.remove(ts)
2468
2469            if ts._has_lost_dependencies:
2470                recommendations[key] = "forgotten"
2471                return recommendations, client_msgs, worker_msgs
2472
2473            for dts in ts._dependencies:
2474                dep = dts._key
2475                if not dts._who_has:
2476                    ts._waiting_on.add(dts)
2477                if dts._state == "released":
2478                    recommendations[dep] = "waiting"
2479                else:
2480                    dts._waiters.add(ts)
2481
2482            ts.state = "waiting"
2483
2484            if not ts._waiting_on:
2485                if self._workers_dv:
2486                    recommendations[key] = "processing"
2487                else:
2488                    self._unrunnable.add(ts)
2489                    ts.state = "no-worker"
2490
2491            return recommendations, client_msgs, worker_msgs
2492        except Exception as e:
2493            logger.exception(e)
2494            if LOG_PDB:
2495                import pdb
2496
2497                pdb.set_trace()
2498            raise
2499
2500    def transition_no_worker_memory(
2501        self, key, nbytes=None, type=None, typename: str = None, worker=None
2502    ):
2503        try:
2504            ws: WorkerState = self._workers_dv[worker]
2505            ts: TaskState = self._tasks[key]
2506            recommendations: dict = {}
2507            client_msgs: dict = {}
2508            worker_msgs: dict = {}
2509
2510            if self._validate:
2511                assert not ts._processing_on
2512                assert not ts._waiting_on
2513                assert ts._state == "no-worker"
2514
2515            self._unrunnable.remove(ts)
2516
2517            if nbytes is not None:
2518                ts.set_nbytes(nbytes)
2519
2520            self.check_idle_saturated(ws)
2521
2522            _add_to_memory(
2523                self, ts, ws, recommendations, client_msgs, type=type, typename=typename
2524            )
2525            ts.state = "memory"
2526
2527            return recommendations, client_msgs, worker_msgs
2528        except Exception as e:
2529            logger.exception(e)
2530            if LOG_PDB:
2531                import pdb
2532
2533                pdb.set_trace()
2534            raise
2535
2536    @ccall
2537    @exceptval(check=False)
2538    def decide_worker(self, ts: TaskState) -> WorkerState:  # -> WorkerState | None
2539        """
        Decide on a worker for task *ts*. Return a WorkerState, or None if
        no suitable worker is currently available.
2541
2542        If it's a root or root-like task, we place it with its relatives to
        reduce future data transfer.
2544
2545        If it has dependencies or restrictions, we use
2546        `decide_worker_from_deps_and_restrictions`.
2547
2548        Otherwise, we pick the least occupied worker, or pick from all workers
2549        in a round-robin fashion.
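
        Illustrative sketch of how callers treat the result (``ts`` is a
        TaskState already known to this scheduler)::

            ws = self.decide_worker(ts)
            if ws is None:
                # no workers are connected, or restrictions cannot be met yet
                ...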
2550        """
2551        if not self._workers_dv:
2552            return None  # type: ignore
2553
2554        ws: WorkerState
2555        tg: TaskGroup = ts._group
2556        valid_workers: set = self.valid_workers(ts)
2557
2558        if (
2559            valid_workers is not None
2560            and not valid_workers
2561            and not ts._loose_restrictions
2562        ):
2563            self._unrunnable.add(ts)
2564            ts.state = "no-worker"
2565            return None  # type: ignore
2566
2567        # Group is larger than cluster with few dependencies?
2568        # Minimize future data transfers.
2569        if (
2570            valid_workers is None
2571            and len(tg) > self._total_nthreads * 2
2572            and len(tg._dependencies) < 5
2573            and sum(map(len, tg._dependencies)) < 5
2574        ):
2575            ws = tg._last_worker
2576
2577            if not (
2578                ws and tg._last_worker_tasks_left and ws._address in self._workers_dv
2579            ):
2580                # Last-used worker is full or unknown; pick a new worker for the next few tasks
2581                ws = min(
2582                    (self._idle_dv or self._workers_dv).values(),
2583                    key=partial(self.worker_objective, ts),
2584                )
2585                tg._last_worker_tasks_left = math.floor(
2586                    (len(tg) / self._total_nthreads) * ws._nthreads
2587                )
2588
2589            # Record `last_worker`, or clear it on the final task
2590            tg._last_worker = (
2591                ws if tg.states["released"] + tg.states["waiting"] > 1 else None
2592            )
2593            tg._last_worker_tasks_left -= 1
2594            return ws
2595
2596        if ts._dependencies or valid_workers is not None:
2597            ws = decide_worker(
2598                ts,
2599                self._workers_dv.values(),
2600                valid_workers,
2601                partial(self.worker_objective, ts),
2602            )
2603        else:
2604            # Fastpath when there are no related tasks or restrictions
2605            worker_pool = self._idle or self._workers
2606            # Note: cython.cast, not typing.cast!
2607            worker_pool_dv = cast(dict, worker_pool)
2608            wp_vals = worker_pool.values()
2609            n_workers: Py_ssize_t = len(worker_pool_dv)
2610            if n_workers < 20:  # smart but linear in small case
2611                ws = min(wp_vals, key=operator.attrgetter("occupancy"))
2612                if ws._occupancy == 0:
2613                    # special case to use round-robin; linear search
2614                    # for next worker with zero occupancy (or just
2615                    # land back where we started).
2616                    wp_i: WorkerState
2617                    start: Py_ssize_t = self._n_tasks % n_workers
2618                    i: Py_ssize_t
2619                    for i in range(n_workers):
2620                        wp_i = wp_vals[(i + start) % n_workers]
2621                        if wp_i._occupancy == 0:
2622                            ws = wp_i
2623                            break
2624            else:  # dumb but fast in large case
2625                ws = wp_vals[self._n_tasks % n_workers]
2626
2627        if self._validate:
2628            assert ws is None or isinstance(ws, WorkerState), (
2629                type(ws),
2630                ws,
2631            )
            assert ws is None or ws._address in self._workers_dv
2633
2634        return ws
2635
2636    @ccall
2637    def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double:
2638        """Estimate task duration using worker state and task state.
2639
        If a task has already been executing for longer than twice the current
        average duration, we estimate its duration to be 2x its current
        runtime; otherwise we use the average duration plus the estimated
        communication cost.
2643
2644        See also ``_remove_from_processing``
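
        Worked example with illustrative numbers only::

            duration = 1.0   # average duration for this prefix, in seconds
            exec_time = 3.0  # time the task has already been executing
            # exec_time > 2 * duration, so the estimate doubles the runtime
            total_duration = 2 * exec_time   # 6.0 s; otherwise it would be
                                             # duration + communication cost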
2645        """
2646        exec_time: double = ws._executing.get(ts, 0)
2647        duration: double = self.get_task_duration(ts)
2648        total_duration: double
2649        if exec_time > 2 * duration:
2650            total_duration = 2 * exec_time
2651        else:
2652            comm: double = self.get_comm_cost(ts, ws)
2653            total_duration = duration + comm
2654        old = ws._processing.get(ts, 0)
2655        ws._processing[ts] = total_duration
2656        self._total_occupancy += total_duration - old
2657        ws._occupancy += total_duration - old
2658
2659        return total_duration
2660
2661    def transition_waiting_processing(self, key):
2662        try:
2663            ts: TaskState = self._tasks[key]
2664            dts: TaskState
2665            recommendations: dict = {}
2666            client_msgs: dict = {}
2667            worker_msgs: dict = {}
2668
2669            if self._validate:
2670                assert not ts._waiting_on
2671                assert not ts._who_has
2672                assert not ts._exception_blame
2673                assert not ts._processing_on
2674                assert not ts._has_lost_dependencies
2675                assert ts not in self._unrunnable
2676                assert all([dts._who_has for dts in ts._dependencies])
2677
2678            ws: WorkerState = self.decide_worker(ts)
2679            if ws is None:
2680                return recommendations, client_msgs, worker_msgs
2681            worker = ws._address
2682
2683            self.set_duration_estimate(ts, ws)
2684            ts._processing_on = ws
2685            ts.state = "processing"
2686            self.consume_resources(ts, ws)
2687            self.check_idle_saturated(ws)
2688            self._n_tasks += 1
2689
2690            if ts._actor:
2691                ws._actors.add(ts)
2692
2693            # logger.debug("Send job to worker: %s, %s", worker, key)
2694
2695            worker_msgs[worker] = [_task_to_msg(self, ts)]
2696
2697            return recommendations, client_msgs, worker_msgs
2698        except Exception as e:
2699            logger.exception(e)
2700            if LOG_PDB:
2701                import pdb
2702
2703                pdb.set_trace()
2704            raise
2705
2706    def transition_waiting_memory(
2707        self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs
2708    ):
2709        try:
2710            ws: WorkerState = self._workers_dv[worker]
2711            ts: TaskState = self._tasks[key]
2712            recommendations: dict = {}
2713            client_msgs: dict = {}
2714            worker_msgs: dict = {}
2715
2716            if self._validate:
2717                assert not ts._processing_on
2718                assert ts._waiting_on
2719                assert ts._state == "waiting"
2720
2721            ts._waiting_on.clear()
2722
2723            if nbytes is not None:
2724                ts.set_nbytes(nbytes)
2725
2726            self.check_idle_saturated(ws)
2727
2728            _add_to_memory(
2729                self, ts, ws, recommendations, client_msgs, type=type, typename=typename
2730            )
2731
2732            if self._validate:
2733                assert not ts._processing_on
2734                assert not ts._waiting_on
2735                assert ts._who_has
2736
2737            return recommendations, client_msgs, worker_msgs
2738        except Exception as e:
2739            logger.exception(e)
2740            if LOG_PDB:
2741                import pdb
2742
2743                pdb.set_trace()
2744            raise
2745
2746    def transition_processing_memory(
2747        self,
2748        key,
2749        nbytes=None,
2750        type=None,
2751        typename: str = None,
2752        worker=None,
2753        startstops=None,
2754        **kwargs,
2755    ):
2756        ws: WorkerState
2757        wws: WorkerState
2758        recommendations: dict = {}
2759        client_msgs: dict = {}
2760        worker_msgs: dict = {}
2761        try:
2762            ts: TaskState = self._tasks[key]
2763
2764            assert worker
2765            assert isinstance(worker, str)
2766
2767            if self._validate:
2768                assert ts._processing_on
2769                ws = ts._processing_on
2770                assert ts in ws._processing
2771                assert not ts._waiting_on
2772                assert not ts._who_has, (ts, ts._who_has)
2773                assert not ts._exception_blame
2774                assert ts._state == "processing"
2775
2776            ws = self._workers_dv.get(worker)  # type: ignore
2777            if ws is None:
2778                recommendations[key] = "released"
2779                return recommendations, client_msgs, worker_msgs
2780
2781            if ws != ts._processing_on:  # someone else has this task
2782                logger.info(
2783                    "Unexpected worker completed task. Expected: %s, Got: %s, Key: %s",
2784                    ts._processing_on,
2785                    ws,
2786                    key,
2787                )
2788                worker_msgs[ts._processing_on.address] = [
2789                    {
2790                        "op": "cancel-compute",
2791                        "key": key,
2792                        "reason": "Finished on different worker",
2793                    }
2794                ]
2795
2796            #############################
2797            # Update Timing Information #
2798            #############################
2799            if startstops:
2800                startstop: dict
2801                for startstop in startstops:
2802                    ts._group.add_duration(
2803                        stop=startstop["stop"],
2804                        start=startstop["start"],
2805                        action=startstop["action"],
2806                    )
2807
2808            s: set = self._unknown_durations.pop(ts._prefix._name, set())
2809            tts: TaskState
2810            steal = self.extensions.get("stealing")
2811            for tts in s:
2812                if tts._processing_on:
2813                    self.set_duration_estimate(tts, tts._processing_on)
2814                    if steal:
2815                        steal.put_key_in_stealable(tts)
2816
2817            ############################
2818            # Update State Information #
2819            ############################
2820            if nbytes is not None:
2821                ts.set_nbytes(nbytes)
2822
2823            _remove_from_processing(self, ts)
2824
2825            _add_to_memory(
2826                self, ts, ws, recommendations, client_msgs, type=type, typename=typename
2827            )
2828
2829            if self._validate:
2830                assert not ts._processing_on
2831                assert not ts._waiting_on
2832
2833            return recommendations, client_msgs, worker_msgs
2834        except Exception as e:
2835            logger.exception(e)
2836            if LOG_PDB:
2837                import pdb
2838
2839                pdb.set_trace()
2840            raise
2841
2842    def transition_memory_released(self, key, safe: bint = False):
2843        ws: WorkerState
2844        try:
2845            ts: TaskState = self._tasks[key]
2846            dts: TaskState
2847            recommendations: dict = {}
2848            client_msgs: dict = {}
2849            worker_msgs: dict = {}
2850
2851            if self._validate:
2852                assert not ts._waiting_on
2853                assert not ts._processing_on
2854                if safe:
2855                    assert not ts._waiters
2856
2857            if ts._actor:
2858                for ws in ts._who_has:
2859                    ws._actors.discard(ts)
2860                if ts._who_wants:
2861                    ts._exception_blame = ts
2862                    ts._exception = "Worker holding Actor was lost"
2863                    recommendations[ts._key] = "erred"
2864                    return (
2865                        recommendations,
2866                        client_msgs,
2867                        worker_msgs,
2868                    )  # don't try to recreate
2869
2870            for dts in ts._waiters:
2871                if dts._state in ("no-worker", "processing"):
2872                    recommendations[dts._key] = "waiting"
2873                elif dts._state == "waiting":
2874                    dts._waiting_on.add(ts)
2875
2876            # XXX factor this out?
2877            worker_msg = {
2878                "op": "free-keys",
2879                "keys": [key],
2880                "stimulus_id": f"memory-released-{time()}",
2881            }
2882            for ws in ts._who_has:
2883                worker_msgs[ws._address] = [worker_msg]
2884            self.remove_all_replicas(ts)
2885
2886            ts.state = "released"
2887
2888            report_msg = {"op": "lost-data", "key": key}
2889            cs: ClientState
2890            for cs in ts._who_wants:
2891                client_msgs[cs._client_key] = [report_msg]
2892
2893            if not ts._run_spec:  # pure data
2894                recommendations[key] = "forgotten"
2895            elif ts._has_lost_dependencies:
2896                recommendations[key] = "forgotten"
2897            elif ts._who_wants or ts._waiters:
2898                recommendations[key] = "waiting"
2899
2900            if self._validate:
2901                assert not ts._waiting_on
2902
2903            return recommendations, client_msgs, worker_msgs
2904        except Exception as e:
2905            logger.exception(e)
2906            if LOG_PDB:
2907                import pdb
2908
2909                pdb.set_trace()
2910            raise
2911
2912    def transition_released_erred(self, key):
2913        try:
2914            ts: TaskState = self._tasks[key]
2915            dts: TaskState
2916            failing_ts: TaskState
2917            recommendations: dict = {}
2918            client_msgs: dict = {}
2919            worker_msgs: dict = {}
2920
2921            if self._validate:
2922                with log_errors(pdb=LOG_PDB):
2923                    assert ts._exception_blame
2924                    assert not ts._who_has
2925                    assert not ts._waiting_on
2926                    assert not ts._waiters
2927
2928            failing_ts = ts._exception_blame
2929
2930            for dts in ts._dependents:
2931                dts._exception_blame = failing_ts
2932                if not dts._who_has:
2933                    recommendations[dts._key] = "erred"
2934
2935            report_msg = {
2936                "op": "task-erred",
2937                "key": key,
2938                "exception": failing_ts._exception,
2939                "traceback": failing_ts._traceback,
2940            }
2941            cs: ClientState
2942            for cs in ts._who_wants:
2943                client_msgs[cs._client_key] = [report_msg]
2944
2945            ts.state = "erred"
2946
2947            # TODO: waiting data?
2948            return recommendations, client_msgs, worker_msgs
2949        except Exception as e:
2950            logger.exception(e)
2951            if LOG_PDB:
2952                import pdb
2953
2954                pdb.set_trace()
2955            raise
2956
2957    def transition_erred_released(self, key):
2958        try:
2959            ts: TaskState = self._tasks[key]
2960            dts: TaskState
2961            recommendations: dict = {}
2962            client_msgs: dict = {}
2963            worker_msgs: dict = {}
2964
2965            if self._validate:
2966                with log_errors(pdb=LOG_PDB):
2967                    assert ts._exception_blame
2968                    assert not ts._who_has
2969                    assert not ts._waiting_on
2970                    assert not ts._waiters
2971
2972            ts._exception = None
2973            ts._exception_blame = None
2974            ts._traceback = None
2975
2976            for dts in ts._dependents:
2977                if dts._state == "erred":
2978                    recommendations[dts._key] = "waiting"
2979
2980            w_msg = {
2981                "op": "free-keys",
2982                "keys": [key],
2983                "stimulus_id": f"erred-released-{time()}",
2984            }
2985            for ws_addr in ts._erred_on:
2986                worker_msgs[ws_addr] = [w_msg]
2987            ts._erred_on.clear()
2988
2989            report_msg = {"op": "task-retried", "key": key}
2990            cs: ClientState
2991            for cs in ts._who_wants:
2992                client_msgs[cs._client_key] = [report_msg]
2993
2994            ts.state = "released"
2995
2996            return recommendations, client_msgs, worker_msgs
2997        except Exception as e:
2998            logger.exception(e)
2999            if LOG_PDB:
3000                import pdb
3001
3002                pdb.set_trace()
3003            raise
3004
3005    def transition_waiting_released(self, key):
3006        try:
3007            ts: TaskState = self._tasks[key]
3008            recommendations: dict = {}
3009            client_msgs: dict = {}
3010            worker_msgs: dict = {}
3011
3012            if self._validate:
3013                assert not ts._who_has
3014                assert not ts._processing_on
3015
3016            dts: TaskState
3017            for dts in ts._dependencies:
3018                if ts in dts._waiters:
3019                    dts._waiters.discard(ts)
3020                    if not dts._waiters and not dts._who_wants:
3021                        recommendations[dts._key] = "released"
3022            ts._waiting_on.clear()
3023
3024            ts.state = "released"
3025
3026            if ts._has_lost_dependencies:
3027                recommendations[key] = "forgotten"
3028            elif not ts._exception_blame and (ts._who_wants or ts._waiters):
3029                recommendations[key] = "waiting"
3030            else:
3031                ts._waiters.clear()
3032
3033            return recommendations, client_msgs, worker_msgs
3034        except Exception as e:
3035            logger.exception(e)
3036            if LOG_PDB:
3037                import pdb
3038
3039                pdb.set_trace()
3040            raise
3041
3042    def transition_processing_released(self, key):
3043        try:
3044            ts: TaskState = self._tasks[key]
3045            dts: TaskState
3046            recommendations: dict = {}
3047            client_msgs: dict = {}
3048            worker_msgs: dict = {}
3049
3050            if self._validate:
3051                assert ts._processing_on
3052                assert not ts._who_has
3053                assert not ts._waiting_on
3054                assert self._tasks[key].state == "processing"
3055
3056            w: str = _remove_from_processing(self, ts)
3057            if w:
3058                worker_msgs[w] = [
3059                    {
3060                        "op": "free-keys",
3061                        "keys": [key],
3062                        "stimulus_id": f"processing-released-{time()}",
3063                    }
3064                ]
3065
3066            ts.state = "released"
3067
3068            if ts._has_lost_dependencies:
3069                recommendations[key] = "forgotten"
3070            elif ts._waiters or ts._who_wants:
3071                recommendations[key] = "waiting"
3072
3073            if recommendations.get(key) != "waiting":
3074                for dts in ts._dependencies:
3075                    if dts._state != "released":
3076                        dts._waiters.discard(ts)
3077                        if not dts._waiters and not dts._who_wants:
3078                            recommendations[dts._key] = "released"
3079                ts._waiters.clear()
3080
3081            if self._validate:
3082                assert not ts._processing_on
3083
3084            return recommendations, client_msgs, worker_msgs
3085        except Exception as e:
3086            logger.exception(e)
3087            if LOG_PDB:
3088                import pdb
3089
3090                pdb.set_trace()
3091            raise
3092
3093    def transition_processing_erred(
3094        self,
3095        key: str,
3096        cause: str = None,
3097        exception=None,
3098        traceback=None,
3099        exception_text: str = None,
3100        traceback_text: str = None,
3101        worker: str = None,
3102        **kwargs,
3103    ):
3104        ws: WorkerState
3105        try:
3106            ts: TaskState = self._tasks[key]
3107            dts: TaskState
3108            failing_ts: TaskState
3109            recommendations: dict = {}
3110            client_msgs: dict = {}
3111            worker_msgs: dict = {}
3112
3113            if self._validate:
3114                assert cause or ts._exception_blame
3115                assert ts._processing_on
3116                assert not ts._who_has
3117                assert not ts._waiting_on
3118
3119            if ts._actor:
3120                ws = ts._processing_on
3121                ws._actors.remove(ts)
3122
3123            w = _remove_from_processing(self, ts)
3124
3125            ts._erred_on.add(w or worker)
3126            if exception is not None:
3127                ts._exception = exception
3128                ts._exception_text = exception_text  # type: ignore
3129            if traceback is not None:
3130                ts._traceback = traceback
3131                ts._traceback_text = traceback_text  # type: ignore
3132            if cause is not None:
3133                failing_ts = self._tasks[cause]
3134                ts._exception_blame = failing_ts
3135            else:
3136                failing_ts = ts._exception_blame  # type: ignore
3137
3138            for dts in ts._dependents:
3139                dts._exception_blame = failing_ts
3140                recommendations[dts._key] = "erred"
3141
3142            for dts in ts._dependencies:
3143                dts._waiters.discard(ts)
3144                if not dts._waiters and not dts._who_wants:
3145                    recommendations[dts._key] = "released"
3146
3147            ts._waiters.clear()  # do anything with this?
3148
3149            ts.state = "erred"
3150
3151            report_msg = {
3152                "op": "task-erred",
3153                "key": key,
3154                "exception": failing_ts._exception,
3155                "traceback": failing_ts._traceback,
3156            }
3157            cs: ClientState
3158            for cs in ts._who_wants:
3159                client_msgs[cs._client_key] = [report_msg]
3160
3161            cs = self._clients["fire-and-forget"]
3162            if ts in cs._wants_what:
3163                _client_releases_keys(
3164                    self,
3165                    cs=cs,
3166                    keys=[key],
3167                    recommendations=recommendations,
3168                )
3169
3170            if self._validate:
3171                assert not ts._processing_on
3172
3173            return recommendations, client_msgs, worker_msgs
3174        except Exception as e:
3175            logger.exception(e)
3176            if LOG_PDB:
3177                import pdb
3178
3179                pdb.set_trace()
3180            raise
3181
3182    def transition_no_worker_released(self, key):
3183        try:
3184            ts: TaskState = self._tasks[key]
3185            dts: TaskState
3186            recommendations: dict = {}
3187            client_msgs: dict = {}
3188            worker_msgs: dict = {}
3189
3190            if self._validate:
3191                assert self._tasks[key].state == "no-worker"
3192                assert not ts._who_has
3193                assert not ts._waiting_on
3194
3195            self._unrunnable.remove(ts)
3196            ts.state = "released"
3197
3198            for dts in ts._dependencies:
3199                dts._waiters.discard(ts)
3200
3201            ts._waiters.clear()
3202
3203            return recommendations, client_msgs, worker_msgs
3204        except Exception as e:
3205            logger.exception(e)
3206            if LOG_PDB:
3207                import pdb
3208
3209                pdb.set_trace()
3210            raise
3211
3212    @ccall
3213    def remove_key(self, key):
3214        ts: TaskState = self._tasks.pop(key)
3215        assert ts._state == "forgotten"
3216        self._unrunnable.discard(ts)
3217        cs: ClientState
3218        for cs in ts._who_wants:
3219            cs._wants_what.remove(ts)
3220        ts._who_wants.clear()
3221        ts._processing_on = None
3222        ts._exception_blame = ts._exception = ts._traceback = None
3223        self._task_metadata.pop(key, None)
3224
3225    def transition_memory_forgotten(self, key):
3226        ws: WorkerState
3227        try:
3228            ts: TaskState = self._tasks[key]
3229            recommendations: dict = {}
3230            client_msgs: dict = {}
3231            worker_msgs: dict = {}
3232
3233            if self._validate:
3234                assert ts._state == "memory"
3235                assert not ts._processing_on
3236                assert not ts._waiting_on
3237                if not ts._run_spec:
3238                    # It's ok to forget a pure data task
3239                    pass
3240                elif ts._has_lost_dependencies:
3241                    # It's ok to forget a task with forgotten dependencies
3242                    pass
3243                elif not ts._who_wants and not ts._waiters and not ts._dependents:
3244                    # It's ok to forget a task that nobody needs
3245                    pass
3246                else:
3247                    assert 0, (ts,)
3248
3249            if ts._actor:
3250                for ws in ts._who_has:
3251                    ws._actors.discard(ts)
3252
3253            _propagate_forgotten(self, ts, recommendations, worker_msgs)
3254
3255            client_msgs = _task_to_client_msgs(self, ts)
3256            self.remove_key(key)
3257
3258            return recommendations, client_msgs, worker_msgs
3259        except Exception as e:
3260            logger.exception(e)
3261            if LOG_PDB:
3262                import pdb
3263
3264                pdb.set_trace()
3265            raise
3266
3267    def transition_released_forgotten(self, key):
3268        try:
3269            ts: TaskState = self._tasks[key]
3270            recommendations: dict = {}
3271            client_msgs: dict = {}
3272            worker_msgs: dict = {}
3273
3274            if self._validate:
3275                assert ts._state in ("released", "erred")
3276                assert not ts._who_has
3277                assert not ts._processing_on
3278                assert not ts._waiting_on, (ts, ts._waiting_on)
3279                if not ts._run_spec:
3280                    # It's ok to forget a pure data task
3281                    pass
3282                elif ts._has_lost_dependencies:
3283                    # It's ok to forget a task with forgotten dependencies
3284                    pass
3285                elif not ts._who_wants and not ts._waiters and not ts._dependents:
3286                    # It's ok to forget a task that nobody needs
3287                    pass
3288                else:
3289                    assert 0, (ts,)
3290
3291            _propagate_forgotten(self, ts, recommendations, worker_msgs)
3292
3293            client_msgs = _task_to_client_msgs(self, ts)
3294            self.remove_key(key)
3295
3296            return recommendations, client_msgs, worker_msgs
3297        except Exception as e:
3298            logger.exception(e)
3299            if LOG_PDB:
3300                import pdb
3301
3302                pdb.set_trace()
3303            raise
3304
3305    ##############################
3306    # Assigning Tasks to Workers #
3307    ##############################
3308
3309    @ccall
3310    @exceptval(check=False)
3311    def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0):
3312        """Update the status of the idle and saturated state
3313
        The scheduler keeps track of workers that are:
3315
3316        -  Saturated: have enough work to stay busy
3317        -  Idle: do not have enough work to stay busy
3318
3319        They are considered saturated if they both have enough tasks to occupy
3320        all of their threads, and if the expected runtime of those tasks is
3321        large enough.
3322
3323        This is useful for load balancing and adaptivity.
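
        Worked example with illustrative numbers only: a worker with
        ``nc = 4`` threads and ``p = 10`` processing tasks worth
        ``occ = 20`` seconds of work, on a cluster whose average occupancy
        per thread is ``avg = 1.0`` second::

            pending = occ * (p - nc) / (p * nc)  # 20 * 6 / 40 == 3.0
            # pending > 0.4 and pending > 1.9 * avg, so the worker is saturated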
3324        """
3325        if self._total_nthreads == 0 or ws.status == Status.closed:
3326            return
3327        if occ < 0:
3328            occ = ws._occupancy
3329
3330        nc: Py_ssize_t = ws._nthreads
3331        p: Py_ssize_t = len(ws._processing)
3332        avg: double = self._total_occupancy / self._total_nthreads
3333
3334        idle = self._idle
3335        saturated: set = self._saturated
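        # A worker counts as idle if it has fewer tasks in processing than
        # threads, or if its occupancy is below half of its "fair share"
        # (threads * cluster-average occupancy per thread).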
3336        if p < nc or occ < nc * avg / 2:
3337            idle[ws._address] = ws
3338            saturated.discard(ws)
3339        else:
3340            idle.pop(ws._address, None)
3341
3342            if p > nc:
3343                pending: double = occ * (p - nc) / (p * nc)
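                # ``pending`` estimates the per-thread runtime of the work queued
                # beyond this worker's thread count.  The chained comparison below
                # is equivalent to ``pending > 0.4 and pending > 1.9 * avg``.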
3344                if 0.4 < pending > 1.9 * avg:
3345                    saturated.add(ws)
3346                    return
3347
3348            saturated.discard(ws)
3349
3350    @ccall
3351    def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double:
3352        """
3353        Get the estimated communication cost (in s.) to compute the task
3354        on the given worker.
3355        """
3356        dts: TaskState
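        # Only dependencies not already resident on ``ws`` need to be transferred.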
3357        deps: set = ts._dependencies.difference(ws._has_what)
3358        nbytes: Py_ssize_t = 0
3359        for dts in deps:
3360            nbytes += dts._nbytes
3361        return nbytes / self._bandwidth
3362
3363    @ccall
3364    def get_task_duration(self, ts: TaskState) -> double:
3365        """Get the estimated computation cost of the given task (not including
3366        any communication cost).
3367
        If no data has been observed, the value of
        `distributed.scheduler.default-task-durations` is used. If none is set
        for this task, `distributed.scheduler.unknown-task-duration` is used
        instead.
3372        """
3373        duration: double = ts._prefix._duration_average
3374        if duration >= 0:
3375            return duration
3376
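        # No average recorded for this prefix yet: remember the task as having
        # an unknown duration and fall back to the configured default estimate.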
3377        s: set = self._unknown_durations.get(ts._prefix._name)  # type: ignore
3378        if s is None:
3379            self._unknown_durations[ts._prefix._name] = s = set()
3380        s.add(ts)
3381        return self.UNKNOWN_TASK_DURATION
3382
3383    @ccall
3384    @exceptval(check=False)
3385    def valid_workers(self, ts: TaskState) -> set:  # set[WorkerState] | None
3386        """Return set of currently valid workers for key
3387
3388        If all workers are valid then this returns ``None``.
        This takes the following restrictions into account:
3390
3391        *  worker_restrictions
3392        *  host_restrictions
3393        *  resource_restrictions
3394        """
3395        s: set = None  # type: ignore
3396
3397        if ts._worker_restrictions:
3398            s = {addr for addr in ts._worker_restrictions if addr in self._workers_dv}
3399
3400        if ts._host_restrictions:
            # Resolve the aliases here rather than earlier, because the worker
            # may not yet be connected when host_restrictions is populated
3403            hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions]
3404            # XXX need HostState?
3405            sl: list = []
3406            for h in hr:
3407                dh: dict = self._host_info.get(h)  # type: ignore
3408                if dh is not None:
3409                    sl.append(dh["addresses"])
3410
3411            ss: set = set.union(*sl) if sl else set()
3412            if s is None:
3413                s = ss
3414            else:
3415                s |= ss
3416
3417        if ts._resource_restrictions:
3418            dw: dict = {}
3419            for resource, required in ts._resource_restrictions.items():
3420                dr: dict = self._resources.get(resource)  # type: ignore
3421                if dr is None:
3422                    self._resources[resource] = dr = {}
3423
3424                sw: set = set()
3425                for addr, supplied in dr.items():
3426                    if supplied >= required:
3427                        sw.add(addr)
3428
3429                dw[resource] = sw
3430
3431            ww: set = set.intersection(*dw.values())
3432            if s is None:
3433                s = ww
3434            else:
3435                s &= ww
3436
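        # ``s`` is still None here when the task has no restrictions at all; we
        # only materialize a concrete set when some workers are not running, so
        # callers can treat ``None`` as "any running worker".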
3437        if s is None:
3438            if len(self._running) < len(self._workers_dv):
3439                return self._running.copy()
3440        else:
3441            s = {self._workers_dv[addr] for addr in s}
3442            if len(self._running) < len(self._workers_dv):
3443                s &= self._running
3444
3445        return s
3446
3447    @ccall
3448    def consume_resources(self, ts: TaskState, ws: WorkerState):
3449        if ts._resource_restrictions:
3450            for r, required in ts._resource_restrictions.items():
3451                ws._used_resources[r] += required
3452
3453    @ccall
3454    def release_resources(self, ts: TaskState, ws: WorkerState):
3455        if ts._resource_restrictions:
3456            for r, required in ts._resource_restrictions.items():
3457                ws._used_resources[r] -= required
3458
3459    @ccall
3460    def coerce_hostname(self, host):
3461        """
        Resolve a worker name or alias to its hostname; return the input
        unchanged if it is not a known alias.
3463        """
3464        alias = self._aliases.get(host)
3465        if alias is not None:
3466            ws: WorkerState = self._workers_dv[alias]
3467            return ws.host
3468        else:
3469            return host
3470
3471    @ccall
3472    @exceptval(check=False)
3473    def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
3474        """
3475        Objective function to determine which worker should get the task
3476
3477        Minimize expected start time.  If a tie then break with data storage.
3478        """
3479        dts: TaskState
3480        nbytes: Py_ssize_t
3481        comm_bytes: Py_ssize_t = 0
3482        for dts in ts._dependencies:
3483            if ws not in dts._who_has:
3484                nbytes = dts.get_nbytes()
3485                comm_bytes += nbytes
3486
3487        stack_time: double = ws._occupancy / ws._nthreads
3488        start_time: double = stack_time + comm_bytes / self._bandwidth
3489
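        # Actors are pinned to a single worker, so spread them by actor count
        # first; otherwise pick the earliest expected start time, breaking ties
        # by preferring workers that store fewer bytes.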
3490        if ts._actor:
3491            return (len(ws._actors), start_time, ws._nbytes)
3492        else:
3493            return (start_time, ws._nbytes)
3494
3495    @ccall
3496    def add_replica(self, ts: TaskState, ws: WorkerState):
3497        """Note that a worker holds a replica of a task with state='memory'"""
3498        if self._validate:
3499            assert ws not in ts._who_has
3500            assert ts not in ws._has_what
3501
3502        ws._nbytes += ts.get_nbytes()
3503        ws._has_what[ts] = None
3504        ts._who_has.add(ws)
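        # A second copy makes the task "replicated"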
3505        if len(ts._who_has) == 2:
3506            self._replicated_tasks.add(ts)
3507
3508    @ccall
3509    def remove_replica(self, ts: TaskState, ws: WorkerState):
3510        """Note that a worker no longer holds a replica of a task"""
3511        ws._nbytes -= ts.get_nbytes()
3512        del ws._has_what[ts]
3513        ts._who_has.remove(ws)
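        # Down to a single copy: the task is no longer considered replicated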
3514        if len(ts._who_has) == 1:
3515            self._replicated_tasks.remove(ts)
3516
3517    @ccall
3518    def remove_all_replicas(self, ts: TaskState):
3519        """Remove all replicas of a task from all workers"""
3520        ws: WorkerState
3521        nbytes: Py_ssize_t = ts.get_nbytes()
3522        for ws in ts._who_has:
3523            ws._nbytes -= nbytes
3524            del ws._has_what[ts]
3525        if len(ts._who_has) > 1:
3526            self._replicated_tasks.remove(ts)
3527        ts._who_has.clear()
3528
3529
3530class Scheduler(SchedulerState, ServerNode):
3531    """Dynamic distributed task scheduler
3532
3533    The scheduler tracks the current state of workers, data, and computations.
3534    The scheduler listens for events and responds by controlling workers
3535    appropriately.  It continuously tries to use the workers to execute an ever
3536    growing dask graph.
3537
3538    All events are handled quickly, in linear time with respect to their input
3539    (which is often of constant size) and generally within a millisecond.  To
3540    accomplish this the scheduler tracks a lot of state.  Every operation
3541    maintains the consistency of this state.
3542
3543    The scheduler communicates with the outside world through Comm objects.
3544    It maintains a consistent and valid view of the world even when listening
3545    to several clients at once.
3546
3547    A Scheduler is typically started either with the ``dask-scheduler``
3548    executable::
3549
3550         $ dask-scheduler
3551         Scheduler started at 127.0.0.1:8786
3552
    Or it is started implicitly when a ``Client`` is created without connection
    information, in which case a ``LocalCluster`` is launched::
3555
3556        >>> c = Client()  # doctest: +SKIP
3557        >>> c.cluster.scheduler  # doctest: +SKIP
3558        Scheduler(...)
3559
    Users typically do not interact with the scheduler directly but rather
    through the ``Client`` object.
3562
3563    **State**
3564
3565    The scheduler contains the following state variables.  Each variable is
3566    listed along with what it stores and a brief description.
3567
3568    * **tasks:** ``{task key: TaskState}``
3569        Tasks currently known to the scheduler
3570    * **unrunnable:** ``{TaskState}``
3571        Tasks in the "no-worker" state
3572
3573    * **workers:** ``{worker key: WorkerState}``
3574        Workers currently connected to the scheduler
3575    * **idle:** ``{WorkerState}``:
3576        Set of workers that are not fully utilized
3577    * **saturated:** ``{WorkerState}``:
        Set of workers that are over-utilized
3579
3580    * **host_info:** ``{hostname: dict}``:
3581        Information about each worker host
3582
3583    * **clients:** ``{client key: ClientState}``
3584        Clients currently connected to the scheduler
3585
3586    * **services:** ``{str: port}``:
3587        Other services running on this scheduler, like Bokeh
3588    * **loop:** ``IOLoop``:
3589        The running Tornado IOLoop
3590    * **client_comms:** ``{client key: Comm}``
3591        For each client, a Comm object used to receive task requests and
3592        report task status updates.
3593    * **stream_comms:** ``{worker key: Comm}``
3594        For each worker, a Comm object from which we both accept stimuli and
3595        report results
3596    * **task_duration:** ``{key-prefix: time}``
3597        Time we expect certain functions to take, e.g. ``{'sum': 0.25}``
3598    """
3599
3600    default_port = 8786
3601    _instances: "ClassVar[weakref.WeakSet[Scheduler]]" = weakref.WeakSet()
3602
3603    def __init__(
3604        self,
3605        loop=None,
3606        delete_interval="500ms",
3607        synchronize_worker_interval="60s",
3608        services=None,
3609        service_kwargs=None,
3610        allowed_failures=None,
3611        extensions=None,
3612        validate=None,
3613        scheduler_file=None,
3614        security=None,
3615        worker_ttl=None,
3616        idle_timeout=None,
3617        interface=None,
3618        host=None,
3619        port=0,
3620        protocol=None,
3621        dashboard_address=None,
3622        dashboard=None,
3623        http_prefix="/",
3624        preload=None,
3625        preload_argv=(),
3626        plugins=(),
3627        **kwargs,
3628    ):
3629        self._setup_logging(logger)
3630
3631        # Attributes
3632        if allowed_failures is None:
3633            allowed_failures = dask.config.get("distributed.scheduler.allowed-failures")
3634        self.allowed_failures = allowed_failures
3635        if validate is None:
3636            validate = dask.config.get("distributed.scheduler.validate")
3637        self.proc = psutil.Process()
3638        self.delete_interval = parse_timedelta(delete_interval, default="ms")
3639        self.synchronize_worker_interval = parse_timedelta(
3640            synchronize_worker_interval, default="ms"
3641        )
3642        self.digests = None
3643        self.service_specs = services or {}
3644        self.service_kwargs = service_kwargs or {}
3645        self.services = {}
3646        self.scheduler_file = scheduler_file
3647        worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl")
3648        self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None
3649        idle_timeout = idle_timeout or dask.config.get(
3650            "distributed.scheduler.idle-timeout"
3651        )
3652        if idle_timeout:
3653            self.idle_timeout = parse_timedelta(idle_timeout)
3654        else:
3655            self.idle_timeout = None
3656        self.idle_since = time()
3657        self.time_started = self.idle_since  # compatibility for dask-gateway
3658        self._lock = asyncio.Lock()
3659        self.bandwidth_workers = defaultdict(float)
3660        self.bandwidth_types = defaultdict(float)
3661
3662        if not preload:
3663            preload = dask.config.get("distributed.scheduler.preload")
3664        if not preload_argv:
3665            preload_argv = dask.config.get("distributed.scheduler.preload-argv")
3666        self.preloads = preloading.process_preloads(self, preload, preload_argv)
3667
3668        if isinstance(security, dict):
3669            security = Security(**security)
3670        self.security = security or Security()
3671        assert isinstance(self.security, Security)
3672        self.connection_args = self.security.get_connection_args("scheduler")
3673        self.connection_args["handshake_overrides"] = {  # common denominator
3674            "pickle-protocol": 4
3675        }
3676
3677        self._start_address = addresses_from_user_args(
3678            host=host,
3679            port=port,
3680            interface=interface,
3681            protocol=protocol,
3682            security=security,
3683            default_port=self.default_port,
3684        )
3685
3686        http_server_modules = dask.config.get("distributed.scheduler.http.routes")
3687        show_dashboard = dashboard or (dashboard is None and dashboard_address)
        # Install a fallback route if the dashboard was requested but bokeh is missing
3689        if show_dashboard:
3690            try:
3691                import distributed.dashboard.scheduler
3692            except ImportError:
3693                show_dashboard = False
3694                http_server_modules.append("distributed.http.scheduler.missing_bokeh")
3695        routes = get_handlers(
3696            server=self, modules=http_server_modules, prefix=http_prefix
3697        )
3698        self.start_http_server(routes, dashboard_address, default_port=8787)
3699        if show_dashboard:
3700            distributed.dashboard.scheduler.connect(
3701                self.http_application, self.http_server, self, prefix=http_prefix
3702            )
3703
3704        # Communication state
3705        self.loop = loop or IOLoop.current()
3706        self.client_comms = {}
3707        self.stream_comms = {}
3708        self._worker_coroutines = []
3709        self._ipython_kernel = None
3710
3711        # Task state
3712        tasks = {}
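        # Legacy views: expose the old per-key dicts (e.g. ``self.dependencies``)
        # as mappings derived from the TaskState objects, for backwards
        # compatibility.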
3713        for old_attr, new_attr, wrap in [
3714            ("priority", "priority", None),
3715            ("dependencies", "dependencies", _legacy_task_key_set),
3716            ("dependents", "dependents", _legacy_task_key_set),
3717            ("retries", "retries", None),
3718        ]:
3719            func = operator.attrgetter(new_attr)
3720            if wrap is not None:
3721                func = compose(wrap, func)
3722            setattr(self, old_attr, _StateLegacyMapping(tasks, func))
3723
3724        for old_attr, new_attr, wrap in [
3725            ("nbytes", "nbytes", None),
3726            ("who_wants", "who_wants", _legacy_client_key_set),
3727            ("who_has", "who_has", _legacy_worker_key_set),
3728            ("waiting", "waiting_on", _legacy_task_key_set),
3729            ("waiting_data", "waiters", _legacy_task_key_set),
3730            ("rprocessing", "processing_on", None),
3731            ("host_restrictions", "host_restrictions", None),
3732            ("worker_restrictions", "worker_restrictions", None),
3733            ("resource_restrictions", "resource_restrictions", None),
3734            ("suspicious_tasks", "suspicious", None),
3735            ("exceptions", "exception", None),
3736            ("tracebacks", "traceback", None),
3737            ("exceptions_blame", "exception_blame", _task_key_or_none),
3738        ]:
3739            func = operator.attrgetter(new_attr)
3740            if wrap is not None:
3741                func = compose(wrap, func)
3742            setattr(self, old_attr, _OptionalStateLegacyMapping(tasks, func))
3743
3744        for old_attr, new_attr, wrap in [
3745            ("loose_restrictions", "loose_restrictions", None)
3746        ]:
3747            func = operator.attrgetter(new_attr)
3748            if wrap is not None:
3749                func = compose(wrap, func)
3750            setattr(self, old_attr, _StateLegacySet(tasks, func))
3751
3752        self.generation = 0
3753        self._last_client = None
3754        self._last_time = 0
3755        unrunnable = set()
3756
3757        self.datasets = {}
3758
3759        # Prefix-keyed containers
3760
3761        # Client state
3762        clients = {}
3763        for old_attr, new_attr, wrap in [
3764            ("wants_what", "wants_what", _legacy_task_key_set)
3765        ]:
3766            func = operator.attrgetter(new_attr)
3767            if wrap is not None:
3768                func = compose(wrap, func)
3769            setattr(self, old_attr, _StateLegacyMapping(clients, func))
3770
3771        # Worker state
3772        workers = SortedDict()
3773        for old_attr, new_attr, wrap in [
3774            ("nthreads", "nthreads", None),
3775            ("worker_bytes", "nbytes", None),
3776            ("worker_resources", "resources", None),
3777            ("used_resources", "used_resources", None),
3778            ("occupancy", "occupancy", None),
3779            ("worker_info", "metrics", None),
3780            ("processing", "processing", _legacy_task_key_dict),
3781            ("has_what", "has_what", _legacy_task_key_set),
3782        ]:
3783            func = operator.attrgetter(new_attr)
3784            if wrap is not None:
3785                func = compose(wrap, func)
3786            setattr(self, old_attr, _StateLegacyMapping(workers, func))
3787
3788        host_info = {}
3789        resources = {}
3790        aliases = {}
3791
3792        self._task_state_collections = [unrunnable]
3793
3794        self._worker_collections = [
3795            workers,
3796            host_info,
3797            resources,
3798            aliases,
3799        ]
3800
3801        self.transition_log = deque(
3802            maxlen=dask.config.get("distributed.scheduler.transition-log-length")
3803        )
3804        self.log = deque(
3805            maxlen=dask.config.get("distributed.scheduler.transition-log-length")
3806        )
3807        self.events = defaultdict(
3808            partial(
3809                deque, maxlen=dask.config.get("distributed.scheduler.events-log-length")
3810            )
3811        )
3812        self.event_counts = defaultdict(int)
3813        self.event_subscriber = defaultdict(set)
3814        self.worker_plugins = {}
3815        self.nanny_plugins = {}
3816
3817        worker_handlers = {
3818            "task-finished": self.handle_task_finished,
3819            "task-erred": self.handle_task_erred,
3820            "release-worker-data": self.release_worker_data,
3821            "add-keys": self.add_keys,
3822            "missing-data": self.handle_missing_data,
3823            "long-running": self.handle_long_running,
3824            "reschedule": self.reschedule,
3825            "keep-alive": lambda *args, **kwargs: None,
3826            "log-event": self.log_worker_event,
3827            "worker-status-change": self.handle_worker_status_change,
3828        }
3829
3830        client_handlers = {
3831            "update-graph": self.update_graph,
3832            "update-graph-hlg": self.update_graph_hlg,
3833            "client-desires-keys": self.client_desires_keys,
3834            "update-data": self.update_data,
3835            "report-key": self.report_on_key,
3836            "client-releases-keys": self.client_releases_keys,
3837            "heartbeat-client": self.client_heartbeat,
3838            "close-client": self.remove_client,
3839            "restart": self.restart,
3840            "subscribe-topic": self.subscribe_topic,
3841            "unsubscribe-topic": self.unsubscribe_topic,
3842        }
3843
3844        self.handlers = {
3845            "register-client": self.add_client,
3846            "scatter": self.scatter,
3847            "register-worker": self.add_worker,
3848            "register_nanny": self.add_nanny,
3849            "unregister": self.remove_worker,
3850            "gather": self.gather,
3851            "cancel": self.stimulus_cancel,
3852            "retry": self.stimulus_retry,
3853            "feed": self.feed,
3854            "terminate": self.close,
3855            "broadcast": self.broadcast,
3856            "proxy": self.proxy,
3857            "ncores": self.get_ncores,
3858            "ncores_running": self.get_ncores_running,
3859            "has_what": self.get_has_what,
3860            "who_has": self.get_who_has,
3861            "processing": self.get_processing,
3862            "call_stack": self.get_call_stack,
3863            "profile": self.get_profile,
3864            "performance_report": self.performance_report,
3865            "get_logs": self.get_logs,
3866            "logs": self.get_logs,
3867            "worker_logs": self.get_worker_logs,
3868            "log_event": self.log_worker_event,
3869            "events": self.get_events,
3870            "nbytes": self.get_nbytes,
3871            "versions": self.versions,
3872            "add_keys": self.add_keys,
3873            "rebalance": self.rebalance,
3874            "replicate": self.replicate,
3875            "start_ipython": self.start_ipython,
3876            "run_function": self.run_function,
3877            "update_data": self.update_data,
3878            "set_resources": self.add_resources,
3879            "retire_workers": self.retire_workers,
3880            "get_metadata": self.get_metadata,
3881            "set_metadata": self.set_metadata,
3882            "set_restrictions": self.set_restrictions,
3883            "heartbeat_worker": self.heartbeat_worker,
3884            "get_task_status": self.get_task_status,
3885            "get_task_stream": self.get_task_stream,
3886            "register_scheduler_plugin": self.register_scheduler_plugin,
3887            "register_worker_plugin": self.register_worker_plugin,
3888            "unregister_worker_plugin": self.unregister_worker_plugin,
3889            "register_nanny_plugin": self.register_nanny_plugin,
3890            "unregister_nanny_plugin": self.unregister_nanny_plugin,
3891            "adaptive_target": self.adaptive_target,
3892            "workers_to_close": self.workers_to_close,
3893            "subscribe_worker_status": self.subscribe_worker_status,
3894            "start_task_metadata": self.start_task_metadata,
3895            "stop_task_metadata": self.stop_task_metadata,
3896        }
3897
3898        connection_limit = get_fileno_limit() / 2
3899
3900        super().__init__(
3901            # Arguments to SchedulerState
3902            aliases=aliases,
3903            clients=clients,
3904            workers=workers,
3905            host_info=host_info,
3906            resources=resources,
3907            tasks=tasks,
3908            unrunnable=unrunnable,
3909            validate=validate,
3910            plugins=plugins,
3911            # Arguments to ServerNode
3912            handlers=self.handlers,
3913            stream_handlers=merge(worker_handlers, client_handlers),
3914            io_loop=self.loop,
3915            connection_limit=connection_limit,
3916            deserialize=False,
3917            connection_args=self.connection_args,
3918            **kwargs,
3919        )
3920
3921        if self.worker_ttl:
3922            pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl)
3923            self.periodic_callbacks["worker-ttl"] = pc
3924
3925        if self.idle_timeout:
3926            pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4)
3927            self.periodic_callbacks["idle-timeout"] = pc
3928
3929        if extensions is None:
3930            extensions = list(DEFAULT_EXTENSIONS)
3931            if dask.config.get("distributed.scheduler.work-stealing"):
3932                extensions.append(WorkStealing)
3933        for ext in extensions:
3934            ext(self)
3935
3936        setproctitle("dask-scheduler [not started]")
3937        Scheduler._instances.add(self)
3938        self.rpc.allow_offload = False
3939        self.status = Status.undefined
3940
3941    ##################
3942    # Administration #
3943    ##################
3944
3945    def __repr__(self):
3946        parent: SchedulerState = cast(SchedulerState, self)
        return '<Scheduler: "%s" workers: %d, cores: %d, tasks: %d>' % (
3948            self.address,
3949            len(parent._workers_dv),
3950            parent._total_nthreads,
3951            len(parent._tasks),
3952        )
3953
3954    def _repr_html_(self):
3955        parent: SchedulerState = cast(SchedulerState, self)
3956        return get_template("scheduler.html.j2").render(
3957            address=self.address,
3958            workers=parent._workers_dv,
3959            threads=parent._total_nthreads,
3960            tasks=parent._tasks,
3961        )
3962
3963    def identity(self, comm=None):
3964        """Basic information about ourselves and our cluster"""
3965        parent: SchedulerState = cast(SchedulerState, self)
3966        d = {
3967            "type": type(self).__name__,
3968            "id": str(self.id),
3969            "address": self.address,
3970            "services": {key: v.port for (key, v) in self.services.items()},
3971            "started": self.time_started,
3972            "workers": {
3973                worker.address: worker.identity()
3974                for worker in parent._workers_dv.values()
3975            },
3976        }
3977        return d
3978
3979    def _to_dict(
3980        self, comm: Comm = None, *, exclude: Container[str] = None
3981    ) -> "dict[str, Any]":
3982        """
3983        A very verbose dictionary representation for debugging purposes.
        Not type stable and not intended for round trips.
3985
3986        Parameters
3987        ----------
        comm:
            The communication object from the RPC call (not used by this method).
3989        exclude:
3990            A list of attributes which must not be present in the output.
3991
3992        See also
3993        --------
3994        Server.identity
3995        Client.dump_cluster_state
3996        """
3997
3998        info = super()._to_dict(exclude=exclude)
3999        extra = {
4000            "transition_log": self.transition_log,
4001            "log": self.log,
4002            "tasks": self.tasks,
4003            "events": self.events,
4004        }
4005        info.update(extra)
        # Extensions that know how to dump their own state
        extensions = {}
        for name, ex in self.extensions.items():
            if hasattr(ex, "_to_dict"):
                extensions[name] = ex._to_dict()
        info["extensions"] = extensions
        return recursive_to_dict(info, exclude=exclude)
4011
4012    def get_worker_service_addr(self, worker, service_name, protocol=False):
4013        """
4014        Get the (host, port) address of the named service on the *worker*.
4015        Returns None if the service doesn't exist.
4016
4017        Parameters
4018        ----------
4019        worker : address
4020        service_name : str
4021            Common services include 'bokeh' and 'nanny'
4022        protocol : boolean
4023            Whether or not to include a full address with protocol (True)
4024            or just a (host, port) pair
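
        Examples
        --------
        Illustrative only; ``scheduler`` is a running instance, and actual
        addresses, service names and ports depend on the cluster:

        >>> scheduler.get_worker_service_addr(
        ...     "tcp://127.0.0.1:36789", "nanny"
        ... )  # doctest: +SKIP
        ('127.0.0.1', 44321)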
4025        """
4026        parent: SchedulerState = cast(SchedulerState, self)
4027        ws: WorkerState = parent._workers_dv[worker]
4028        port = ws._services.get(service_name)
4029        if port is None:
4030            return None
4031        elif protocol:
4032            return "%(protocol)s://%(host)s:%(port)d" % {
4033                "protocol": ws._address.split("://")[0],
4034                "host": ws.host,
4035                "port": port,
4036            }
4037        else:
4038            return ws.host, port
4039
4040    async def start(self):
4041        """Clear out old state and restart all running coroutines"""
4042        await super().start()
4043        assert self.status != Status.running
4044
4045        enable_gc_diagnosis()
4046
4047        self.clear_task_state()
4048
4049        with suppress(AttributeError):
4050            for c in self._worker_coroutines:
4051                c.cancel()
4052
4053        for addr in self._start_address:
4054            await self.listen(
4055                addr,
4056                allow_offload=False,
4057                handshake_overrides={"pickle-protocol": 4, "compression": None},
4058                **self.security.get_listen_args("scheduler"),
4059            )
4060            self.ip = get_address_host(self.listen_address)
4061            listen_ip = self.ip
4062
4063            if listen_ip == "0.0.0.0":
4064                listen_ip = ""
4065
4066        if self.address.startswith("inproc://"):
4067            listen_ip = "localhost"
4068
4069        # Services listen on all addresses
4070        self.start_services(listen_ip)
4071
4072        for listener in self.listeners:
4073            logger.info("  Scheduler at: %25s", listener.contact_address)
4074        for k, v in self.services.items():
4075            logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port))
4076
4077        self.loop.add_callback(self.reevaluate_occupancy)
4078
4079        if self.scheduler_file:
4080            with open(self.scheduler_file, "w") as f:
4081                json.dump(self.identity(), f, indent=2)
4082
4083            fn = self.scheduler_file  # remove file when we close the process
4084
4085            def del_scheduler_file():
4086                if os.path.exists(fn):
4087                    os.remove(fn)
4088
4089            weakref.finalize(self, del_scheduler_file)
4090
4091        for preload in self.preloads:
4092            await preload.start()
4093
4094        await asyncio.gather(
4095            *[plugin.start(self) for plugin in list(self.plugins.values())]
4096        )
4097
4098        self.start_periodic_callbacks()
4099
4100        setproctitle(f"dask-scheduler [{self.address}]")
4101        return self
4102
4103    async def close(self, comm=None, fast=False, close_workers=False):
4104        """Send cleanup signal to all coroutines then wait until finished
4105
4106        See Also
4107        --------
4108        Scheduler.cleanup
4109        """
4110        parent: SchedulerState = cast(SchedulerState, self)
4111        if self.status in (Status.closing, Status.closed):
4112            await self.finished()
4113            return
4114        self.status = Status.closing
4115
4116        logger.info("Scheduler closing...")
4117        setproctitle("dask-scheduler [closing]")
4118
4119        for preload in self.preloads:
4120            await preload.teardown()
4121
4122        if close_workers:
4123            await self.broadcast(msg={"op": "close_gracefully"}, nanny=True)
4124            for worker in parent._workers_dv:
                # Reporting would require the worker to unregister with this
                # (already closing) scheduler, which is unnecessary and might
                # delay worker shutdown.
4128                self.worker_send(worker, {"op": "close", "report": False})
4129            for i in range(20):  # wait a second for send signals to clear
4130                if parent._workers_dv:
4131                    await asyncio.sleep(0.05)
4132                else:
4133                    break
4134
4135        await asyncio.gather(
4136            *[plugin.close() for plugin in list(self.plugins.values())]
4137        )
4138
4139        for pc in self.periodic_callbacks.values():
4140            pc.stop()
4141        self.periodic_callbacks.clear()
4142
4143        self.stop_services()
4144
4145        for ext in parent._extensions.values():
4146            with suppress(AttributeError):
4147                ext.teardown()
4148        logger.info("Scheduler closing all comms")
4149
4150        futures = []
4151        for w, comm in list(self.stream_comms.items()):
4152            if not comm.closed():
4153                comm.send({"op": "close", "report": False})
4154                comm.send({"op": "close-stream"})
4155            with suppress(AttributeError):
4156                futures.append(comm.close())
4157
4158        for future in futures:  # TODO: do all at once
4159            await future
4160
4161        for comm in self.client_comms.values():
4162            comm.abort()
4163
4164        await self.rpc.close()
4165
4166        self.status = Status.closed
4167        self.stop()
4168        await super().close()
4169
4170        setproctitle("dask-scheduler [closed]")
4171        disable_gc_diagnosis()
4172
4173    async def close_worker(self, comm=None, worker=None, safe=None):
4174        """Remove a worker from the cluster
4175
4176        This both removes the worker from our local state and also sends a
4177        signal to the worker to shut down.  This works regardless of whether or
        not the worker has a nanny process restarting it.
4179        """
4180        logger.info("Closing worker %s", worker)
4181        with log_errors():
4182            self.log_event(worker, {"action": "close-worker"})
4183            # FIXME: This does not handle nannies
4184            self.worker_send(worker, {"op": "close", "report": False})
4185            await self.remove_worker(address=worker, safe=safe)
4186
4187    ###########
4188    # Stimuli #
4189    ###########
4190
4191    def heartbeat_worker(
4192        self,
4193        comm=None,
4194        *,
4195        address,
4196        resolve_address: bool = True,
4197        now: float = None,
4198        resources: dict = None,
4199        host_info: dict = None,
4200        metrics: dict,
4201        executing: dict = None,
4202    ):
4203        parent: SchedulerState = cast(SchedulerState, self)
4204        address = self.coerce_address(address, resolve_address)
4205        address = normalize_address(address)
4206        ws: WorkerState = parent._workers_dv.get(address)  # type: ignore
4207        if ws is None:
4208            return {"status": "missing"}
4209
4210        host = get_address_host(address)
4211        local_now = time()
4212        host_info = host_info or {}
4213
4214        dh: dict = parent._host_info.setdefault(host, {})
4215        dh["last-seen"] = local_now
4216
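        # Fold the reported bandwidths into exponentially weighted moving
        # averages; each heartbeat contributes roughly a 1/n_workers share, so
        # estimates move more gradually on larger clusters.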
4217        frac = 1 / len(parent._workers_dv)
4218        parent._bandwidth = (
4219            parent._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac
4220        )
4221        for other, (bw, count) in metrics["bandwidth"]["workers"].items():
4222            if (address, other) not in self.bandwidth_workers:
4223                self.bandwidth_workers[address, other] = bw / count
4224            else:
4225                alpha = (1 - frac) ** count
4226                self.bandwidth_workers[address, other] = self.bandwidth_workers[
4227                    address, other
4228                ] * alpha + bw * (1 - alpha)
4229        for typ, (bw, count) in metrics["bandwidth"]["types"].items():
4230            if typ not in self.bandwidth_types:
4231                self.bandwidth_types[typ] = bw / count
4232            else:
4233                alpha = (1 - frac) ** count
4234                self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * (
4235                    1 - alpha
4236                )
4237
4238        ws._last_seen = local_now
4239        if executing is not None:
4240            ws._executing = {
4241                parent._tasks[key]: duration
4242                for key, duration in executing.items()
4243                if key in parent._tasks
4244            }
4245
4246        ws._metrics = metrics
4247
4248        # Calculate RSS - dask keys, separating "old" and "new" usage
4249        # See MemoryState for details
4250        max_memory_unmanaged_old_hist_age = local_now - parent.MEMORY_RECENT_TO_OLD_TIME
4251        memory_unmanaged_old = ws._memory_unmanaged_old
4252        while ws._memory_other_history:
4253            timestamp, size = ws._memory_other_history[0]
4254            if timestamp >= max_memory_unmanaged_old_hist_age:
4255                break
4256            ws._memory_other_history.popleft()
4257            if size == memory_unmanaged_old:
4258                memory_unmanaged_old = 0  # recalculate min()
4259
4260        # metrics["memory"] is None if the worker sent a heartbeat before its
4261        # SystemMonitor ever had a chance to run.
4262        # ws._nbytes is updated at a different time and sizeof() may not be accurate,
4263        # so size may be (temporarily) negative; floor it to zero.
4264        size = max(0, (metrics["memory"] or 0) - ws._nbytes + metrics["spilled_nbytes"])
4265
4266        ws._memory_other_history.append((local_now, size))
4267        if not memory_unmanaged_old:
4268            # The worker has just been started or the previous minimum has been expunged
4269            # because too old.
4270            # Note: this algorithm is capped to 200 * MEMORY_RECENT_TO_OLD_TIME elements
4271            # cluster-wide by heartbeat_interval(), regardless of the number of workers
4272            ws._memory_unmanaged_old = min(map(second, ws._memory_other_history))
4273        elif size < memory_unmanaged_old:
4274            ws._memory_unmanaged_old = size
4275
4276        if host_info:
4277            dh = parent._host_info.setdefault(host, {})
4278            dh.update(host_info)
4279
4280        if now:
4281            ws._time_delay = local_now - now
4282
4283        if resources:
4284            self.add_resources(worker=address, resources=resources)
4285
4286        self.log_event(address, merge({"action": "heartbeat"}, metrics))
4287
4288        return {
4289            "status": "OK",
4290            "time": local_now,
4291            "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)),
4292        }
4293
4294    async def add_worker(
4295        self,
4296        comm=None,
4297        *,
4298        address: str,
4299        status: str,
4300        keys=(),
4301        nthreads=None,
4302        name=None,
4303        resolve_address=True,
4304        nbytes=None,
4305        types=None,
4306        now=None,
4307        resources=None,
4308        host_info=None,
4309        memory_limit=None,
4310        metrics=None,
4311        pid=0,
4312        services=None,
4313        local_directory=None,
4314        versions=None,
4315        nanny=None,
4316        extra=None,
4317    ):
4318        """Add a new worker to the cluster"""
4319        parent: SchedulerState = cast(SchedulerState, self)
4320        with log_errors():
4321            address = self.coerce_address(address, resolve_address)
4322            address = normalize_address(address)
4323            host = get_address_host(address)
4324
4325            if address in parent._workers_dv:
4326                raise ValueError("Worker already exists %s" % address)
4327
4328            if name in parent._aliases:
4329                logger.warning(
4330                    "Worker tried to connect with a duplicate name: %s", name
4331                )
4332                msg = {
4333                    "status": "error",
4334                    "message": "name taken, %s" % name,
4335                    "time": time(),
4336                }
4337                if comm:
4338                    await comm.write(msg)
4339                return
4340
4341            ws: WorkerState
4342            parent._workers[address] = ws = WorkerState(
4343                address=address,
4344                status=Status.lookup[status],  # type: ignore
4345                pid=pid,
4346                nthreads=nthreads,
4347                memory_limit=memory_limit or 0,
4348                name=name,
4349                local_directory=local_directory,
4350                services=services,
4351                versions=versions,
4352                nanny=nanny,
4353                extra=extra,
4354            )
4355            if ws._status == Status.running:
4356                parent._running.add(ws)
4357
4358            dh: dict = parent._host_info.get(host)  # type: ignore
4359            if dh is None:
4360                parent._host_info[host] = dh = {}
4361
4362            dh_addresses: set = dh.get("addresses")  # type: ignore
4363            if dh_addresses is None:
4364                dh["addresses"] = dh_addresses = set()
4365                dh["nthreads"] = 0
4366
4367            dh_addresses.add(address)
4368            dh["nthreads"] += nthreads
4369
4370            parent._total_nthreads += nthreads
4371            parent._aliases[name] = address
4372
4373            self.heartbeat_worker(
4374                address=address,
4375                resolve_address=resolve_address,
4376                now=now,
4377                resources=resources,
4378                host_info=host_info,
4379                metrics=metrics,
4380            )
4381
4382            # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot
4383            # exist before this.
4384            self.check_idle_saturated(ws)
4385
4386            # for key in keys:  # TODO
4387            #     self.mark_key_in_memory(key, [address])
4388
4389            self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop)
4390
4391            if ws._nthreads > len(ws._processing):
4392                parent._idle[ws._address] = ws
4393
4394            for plugin in list(self.plugins.values()):
4395                try:
4396                    result = plugin.add_worker(scheduler=self, worker=address)
4397                    if inspect.isawaitable(result):
4398                        await result
4399                except Exception as e:
4400                    logger.exception(e)
4401
4402            recommendations: dict = {}
4403            client_msgs: dict = {}
4404            worker_msgs: dict = {}
4405            if nbytes:
4406                assert isinstance(nbytes, dict)
4407                already_released_keys = []
4408                for key in nbytes:
4409                    ts: TaskState = parent._tasks.get(key)  # type: ignore
4410                    if ts is not None and ts.state != "released":
4411                        if ts.state == "memory":
4412                            self.add_keys(worker=address, keys=[key])
4413                        else:
4414                            t: tuple = parent._transition(
4415                                key,
4416                                "memory",
4417                                worker=address,
4418                                nbytes=nbytes[key],
4419                                typename=types[key],
4420                            )
4421                            recommendations, client_msgs, worker_msgs = t
4422                            parent._transitions(
4423                                recommendations, client_msgs, worker_msgs
4424                            )
4425                            recommendations = {}
4426                    else:
4427                        already_released_keys.append(key)
4428                if already_released_keys:
4429                    if address not in worker_msgs:
4430                        worker_msgs[address] = []
4431                    worker_msgs[address].append(
4432                        {
4433                            "op": "remove-replicas",
4434                            "keys": already_released_keys,
4435                            "stimulus_id": f"reconnect-already-released-{time()}",
4436                        }
4437                    )
4438
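            # A newly running worker may satisfy the restrictions of tasks that
            # were previously unrunnable, so reconsider them for scheduling.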
4439            if ws._status == Status.running:
4440                for ts in parent._unrunnable:
4441                    valid: set = self.valid_workers(ts)
4442                    if valid is None or ws in valid:
4443                        recommendations[ts._key] = "waiting"
4444
4445            if recommendations:
4446                parent._transitions(recommendations, client_msgs, worker_msgs)
4447
4448            self.send_all(client_msgs, worker_msgs)
4449
4450            self.log_event(address, {"action": "add-worker"})
4451            self.log_event("all", {"action": "add-worker", "worker": address})
4452            logger.info("Register worker %s", ws)
4453
4454            msg = {
4455                "status": "OK",
4456                "time": time(),
4457                "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)),
4458                "worker-plugins": self.worker_plugins,
4459            }
4460
4461            cs: ClientState
4462            version_warning = version_module.error_message(
4463                version_module.get_versions(),
4464                merge(
4465                    {w: ws._versions for w, ws in parent._workers_dv.items()},
4466                    {
4467                        c: cs._versions
4468                        for c, cs in parent._clients.items()
4469                        if cs._versions
4470                    },
4471                ),
4472                versions,
4473                client_name="This Worker",
4474            )
4475            msg.update(version_warning)
4476
4477            if comm:
4478                await comm.write(msg)
4479
4480            await self.handle_worker(comm=comm, worker=address)
4481
4482    async def add_nanny(self, comm):
4483        msg = {
4484            "status": "OK",
4485            "nanny-plugins": self.nanny_plugins,
4486        }
4487        return msg
4488
4489    def update_graph_hlg(
4490        self,
4491        client=None,
4492        hlg=None,
4493        keys=None,
4494        dependencies=None,
4495        restrictions=None,
4496        priority=None,
4497        loose_restrictions=None,
4498        resources=None,
4499        submitting_task=None,
4500        retries=None,
4501        user_priority=0,
4502        actors=None,
4503        fifo_timeout=0,
4504        code=None,
4505    ):
4506        unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg)
4507        dsk = unpacked_graph["dsk"]
4508        dependencies = unpacked_graph["deps"]
4509        annotations = unpacked_graph["annotations"]
4510
4511        # Remove any self-dependencies (happens on test_publish_bag() and others)
4512        for k, v in dependencies.items():
4513            deps = set(v)
4514            if k in deps:
4515                deps.remove(k)
4516            dependencies[k] = deps
4517
4518        if priority is None:
4519            # Removing all non-local keys before calling order()
4520            dsk_keys = set(dsk)  # intersection() of sets is much faster than dict_keys
4521            stripped_deps = {
4522                k: v.intersection(dsk_keys)
4523                for k, v in dependencies.items()
4524                if k in dsk_keys
4525            }
4526            priority = dask.order.order(dsk, dependencies=stripped_deps)
4527
4528        return self.update_graph(
4529            client,
4530            dsk,
4531            keys,
4532            dependencies,
4533            restrictions,
4534            priority,
4535            loose_restrictions,
4536            resources,
4537            submitting_task,
4538            retries,
4539            user_priority,
4540            actors,
4541            fifo_timeout,
4542            annotations,
4543            code=code,
4544        )
4545
4546    def update_graph(
4547        self,
4548        client=None,
4549        tasks=None,
4550        keys=None,
4551        dependencies=None,
4552        restrictions=None,
4553        priority=None,
4554        loose_restrictions=None,
4555        resources=None,
4556        submitting_task=None,
4557        retries=None,
4558        user_priority=0,
4559        actors=None,
4560        fifo_timeout=0,
4561        annotations=None,
4562        code=None,
4563    ):
4564        """
4565        Add new computations to the internal dask graph
4566
4567        This happens whenever the Client calls submit, map, get, or compute.
4568        """
4569        parent: SchedulerState = cast(SchedulerState, self)
4570        start = time()
4571        fifo_timeout = parse_timedelta(fifo_timeout)
4572        keys = set(keys)
4573        if len(tasks) > 1:
4574            self.log_event(
4575                ["all", client], {"action": "update_graph", "count": len(tasks)}
4576            )
4577
4578        # Remove aliases
4579        for k in list(tasks):
4580            if tasks[k] is k:
4581                del tasks[k]
4582
4583        dependencies = dependencies or {}
4584
4585        if parent._total_occupancy > 1e-9 and parent._computations:
4586            # Still working on something. Assign new tasks to same computation
4587            computation = cast(Computation, parent._computations[-1])
4588        else:
4589            computation = Computation()
4590            parent._computations.append(computation)
4591
4592        if code and code not in computation._code:  # add new code blocks
4593            computation._code.add(code)
4594
4595        n = 0
4596        while len(tasks) != n:  # walk through new tasks, cancel any bad deps
4597            n = len(tasks)
4598            for k, deps in list(dependencies.items()):
4599                if any(
4600                    dep not in parent._tasks and dep not in tasks for dep in deps
4601                ):  # bad key
4602                    logger.info("User asked for computation on lost data, %s", k)
4603                    del tasks[k]
4604                    del dependencies[k]
4605                    if k in keys:
4606                        keys.remove(k)
4607                    self.report({"op": "cancelled-key", "key": k}, client=client)
4608                    self.client_releases_keys(keys=[k], client=client)
4609
4610        # Avoid computation that is already finished
4611        ts: TaskState
4612        already_in_memory = set()  # tasks that are already done
4613        for k, v in dependencies.items():
4614            if v and k in parent._tasks:
4615                ts = parent._tasks[k]
4616                if ts._state in ("memory", "erred"):
4617                    already_in_memory.add(k)
4618
4619        dts: TaskState
4620        if already_in_memory:
4621            dependents = dask.core.reverse_dict(dependencies)
4622            stack = list(already_in_memory)
4623            done = set(already_in_memory)
4624            while stack:  # remove unnecessary dependencies
4625                key = stack.pop()
4626                ts = parent._tasks[key]
4627                try:
4628                    deps = dependencies[key]
4629                except KeyError:
4630                    deps = self.dependencies[key]
4631                for dep in deps:
4632                    if dep in dependents:
4633                        child_deps = dependents[dep]
4634                    else:
4635                        child_deps = self.dependencies[dep]
4636                    if all(d in done for d in child_deps):
4637                        if dep in parent._tasks and dep not in done:
4638                            done.add(dep)
4639                            stack.append(dep)
4640
4641            for d in done:
4642                tasks.pop(d, None)
4643                dependencies.pop(d, None)
4644
4645        # Get or create task states
4646        stack = list(keys)
4647        touched_keys = set()
4648        touched_tasks = []
4649        while stack:
4650            k = stack.pop()
4651            if k in touched_keys:
4652                continue
4653            # XXX Have a method get_task_state(self, k) ?
4654            ts = parent._tasks.get(k)
4655            if ts is None:
4656                ts = parent.new_task(
4657                    k, tasks.get(k), "released", computation=computation
4658                )
4659            elif not ts._run_spec:
4660                ts._run_spec = tasks.get(k)
4661
4662            touched_keys.add(k)
4663            touched_tasks.append(ts)
4664            stack.extend(dependencies.get(k, ()))
4665
4666        self.client_desires_keys(keys=keys, client=client)
4667
4668        # Add dependencies
4669        for key, deps in dependencies.items():
4670            ts = parent._tasks.get(key)
4671            if ts is None or ts._dependencies:
4672                continue
4673            for dep in deps:
4674                dts = parent._tasks[dep]
4675                ts.add_dependency(dts)
4676
4677        # Compute priorities
4678        if isinstance(user_priority, Number):
4679            user_priority = {k: user_priority for k in tasks}
4680
4681        annotations = annotations or {}
4682        restrictions = restrictions or {}
4683        loose_restrictions = loose_restrictions or []
4684        resources = resources or {}
4685        retries = retries or {}
4686
        # Override existing taxonomy with per-task annotations
4688        if annotations:
4689            if "priority" in annotations:
4690                user_priority.update(annotations["priority"])
4691
4692            if "workers" in annotations:
4693                restrictions.update(annotations["workers"])
4694
4695            if "allow_other_workers" in annotations:
4696                loose_restrictions.extend(
4697                    k for k, v in annotations["allow_other_workers"].items() if v
4698                )
4699
4700            if "retries" in annotations:
4701                retries.update(annotations["retries"])
4702
4703            if "resources" in annotations:
4704                resources.update(annotations["resources"])
4705
4706            for a, kv in annotations.items():
4707                for k, v in kv.items():
4708                    # Tasks might have been culled, in which case
4709                    # we have nothing to annotate.
4710                    ts = parent._tasks.get(k)
4711                    if ts is not None:
4712                        ts._annotations[a] = v
4713
4714        # Add actors
4715        if actors is True:
4716            actors = list(keys)
4717        for actor in actors or []:
4718            ts = parent._tasks[actor]
4719            ts._actor = True
4720
4721        priority = priority or dask.order.order(
4722            tasks
4723        )  # TODO: define order wrt old graph
4724
4725        if submitting_task:  # sub-tasks get better priority than parent tasks
4726            ts = parent._tasks.get(submitting_task)
4727            if ts is not None:
4728                generation = ts._priority[0] - 0.01
4729            else:  # super-task already cleaned up
4730                generation = self.generation
4731        elif self._last_time + fifo_timeout < start:
4732            self.generation += 1  # older graph generations take precedence
4733            generation = self.generation
4734            self._last_time = start
4735        else:
4736            generation = self.generation
4737
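        # Priority tuples sort ascending (smaller runs sooner): negated user
        # priority first, then submission generation, then dask.order's
        # within-graph ordering.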
4738        for key in set(priority) & touched_keys:
4739            ts = parent._tasks[key]
4740            if ts._priority is None:
4741                ts._priority = (-(user_priority.get(key, 0)), generation, priority[key])
4742
4743        # Ensure all runnables have a priority
4744        runnables = [ts for ts in touched_tasks if ts._run_spec]
4745        for ts in runnables:
4746            if ts._priority is None and ts._run_spec:
4747                ts._priority = (self.generation, 0)
4748
4749        if restrictions:
4750            # *restrictions* is a dict keying task ids to lists of
4751            # restriction specifications (either worker names or addresses)
4752            for k, v in restrictions.items():
4753                if v is None:
4754                    continue
4755                ts = parent._tasks.get(k)
4756                if ts is None:
4757                    continue
4758                ts._host_restrictions = set()
4759                ts._worker_restrictions = set()
4760                # Make sure `v` is a collection and not a single worker name / address
4761                if not isinstance(v, (list, tuple, set)):
4762                    v = [v]
4763                for w in v:
4764                    try:
4765                        w = self.coerce_address(w)
4766                    except ValueError:
4767                        # Not a valid address, but perhaps it's a hostname
4768                        ts._host_restrictions.add(w)
4769                    else:
4770                        ts._worker_restrictions.add(w)
4771
4772            if loose_restrictions:
4773                for k in loose_restrictions:
4774                    ts = parent._tasks[k]
4775                    ts._loose_restrictions = True
4776
4777        if resources:
4778            for k, v in resources.items():
4779                if v is None:
4780                    continue
4781                assert isinstance(v, dict)
4782                ts = parent._tasks.get(k)
4783                if ts is None:
4784                    continue
4785                ts._resource_restrictions = v
4786
4787        if retries:
4788            for k, v in retries.items():
4789                assert isinstance(v, int)
4790                ts = parent._tasks.get(k)
4791                if ts is None:
4792                    continue
4793                ts._retries = v
4794
4795        # Compute recommendations
4796        recommendations: dict = {}
4797
4798        for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True):
4799            if ts._state == "released" and ts._run_spec:
4800                recommendations[ts._key] = "waiting"
4801
4802        for ts in touched_tasks:
4803            for dts in ts._dependencies:
4804                if dts._exception_blame:
4805                    ts._exception_blame = dts._exception_blame
4806                    recommendations[ts._key] = "erred"
4807                    break
4808
4809        for plugin in list(self.plugins.values()):
4810            try:
4811                plugin.update_graph(
4812                    self,
4813                    client=client,
4814                    tasks=tasks,
4815                    keys=keys,
4816                    restrictions=restrictions or {},
4817                    dependencies=dependencies,
4818                    priority=priority,
4819                    loose_restrictions=loose_restrictions,
4820                    resources=resources,
4821                    annotations=annotations,
4822                )
4823            except Exception as e:
4824                logger.exception(e)
4825
4826        self.transitions(recommendations)
4827
4828        for ts in touched_tasks:
4829            if ts._state in ("memory", "erred"):
4830                self.report_on_key(ts=ts, client=client)
4831
4832        end = time()
4833        if self.digests is not None:
4834            self.digests["update-graph-duration"].add(end - start)
4835
4836        # TODO: balance workers
4837
4838    def stimulus_task_finished(self, key=None, worker=None, **kwargs):
4839        """Mark that a task has finished execution on a particular worker"""
4840        parent: SchedulerState = cast(SchedulerState, self)
4841        logger.debug("Stimulus task finished %s, %s", key, worker)
4842
4843        recommendations: dict = {}
4844        client_msgs: dict = {}
4845        worker_msgs: dict = {}
4846
4847        ws: WorkerState = parent._workers_dv[worker]
4848        ts: TaskState = parent._tasks.get(key)
4849        if ts is None or ts._state == "released":
4850            logger.debug(
4851                "Received already computed task, worker: %s, state: %s"
4852                ", key: %s, who_has: %s",
4853                worker,
4854                ts._state if ts else "forgotten",
4855                key,
4856                ts._who_has if ts else {},
4857            )
4858            worker_msgs[worker] = [
4859                {
4860                    "op": "free-keys",
4861                    "keys": [key],
4862                    "stimulus_id": f"already-released-or-forgotten-{time()}",
4863                }
4864            ]
4865        elif ts._state == "memory":
4866            self.add_keys(worker=worker, keys=[key])
4867        else:
4868            ts._metadata.update(kwargs["metadata"])
4869            r: tuple = parent._transition(key, "memory", worker=worker, **kwargs)
4870            recommendations, client_msgs, worker_msgs = r
4871
4872            if ts._state == "memory":
4873                assert ws in ts._who_has
4874        return recommendations, client_msgs, worker_msgs
4875
4876    def stimulus_task_erred(
4877        self, key=None, worker=None, exception=None, traceback=None, **kwargs
4878    ):
4879        """Mark that a task has erred on a particular worker"""
4880        parent: SchedulerState = cast(SchedulerState, self)
4881        logger.debug("Stimulus task erred %s, %s", key, worker)
4882
4883        ts: TaskState = parent._tasks.get(key)
4884        if ts is None or ts._state != "processing":
4885            return {}, {}, {}
4886
4887        if ts._retries > 0:
4888            ts._retries -= 1
4889            return parent._transition(key, "waiting")
4890        else:
4891            return parent._transition(
4892                key,
4893                "erred",
4894                cause=key,
4895                exception=exception,
4896                traceback=traceback,
4897                worker=worker,
4898                **kwargs,
4899            )
4900
4901    def stimulus_retry(self, comm=None, keys=None, client=None):
4902        parent: SchedulerState = cast(SchedulerState, self)
4903        logger.info("Client %s requests to retry %d keys", client, len(keys))
4904        if client:
4905            self.log_event(client, {"action": "retry", "count": len(keys)})
4906
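        # Walk each requested key's erred dependencies down to the tasks that
        # originally failed (the "roots"), then recommend those back to
        # waiting so the whole chain is recomputed.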
4907        stack = list(keys)
4908        seen = set()
4909        roots = []
4910        ts: TaskState
4911        dts: TaskState
4912        while stack:
4913            key = stack.pop()
4914            seen.add(key)
4915            ts = parent._tasks[key]
4916            erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"]
4917            if erred_deps:
4918                stack.extend(erred_deps)
4919            else:
4920                roots.append(key)
4921
4922        recommendations: dict = {key: "waiting" for key in roots}
4923        self.transitions(recommendations)
4924
4925        if parent._validate:
4926            for key in seen:
4927                assert not parent._tasks[key].exception_blame
4928
4929        return tuple(seen)
4930
4931    async def remove_worker(self, comm=None, address=None, safe=False, close=True):
4932        """
4933        Remove worker from cluster
4934
4935        We do this when a worker reports that it plans to leave or when it
4936        appears to be unresponsive.  This may send its tasks back to a released
4937        state.
4938        """
4939        parent: SchedulerState = cast(SchedulerState, self)
4940        with log_errors():
4941            if self.status == Status.closed:
4942                return
4943
4944            address = self.coerce_address(address)
4945
4946            if address not in parent._workers_dv:
4947                return "already-removed"
4948
4949            host = get_address_host(address)
4950
4951            ws: WorkerState = parent._workers_dv[address]
4952
4953            self.log_event(
4954                ["all", address],
4955                {
4956                    "action": "remove-worker",
4957                    "processing-tasks": dict(ws._processing),
4958                },
4959            )
4960            logger.info("Remove worker %s", ws)
4961            if close:
4962                with suppress(AttributeError, CommClosedError):
4963                    self.stream_comms[address].send({"op": "close", "report": False})
4964
4965            self.remove_resources(address)
4966
4967            dh: dict = parent._host_info.get(host)
4968            if dh is None:
4969                parent._host_info[host] = dh = {}
4970
4971            dh_addresses: set = dh["addresses"]
4972            dh_addresses.remove(address)
4973            dh["nthreads"] -= ws._nthreads
4974            parent._total_nthreads -= ws._nthreads
4975
4976            if not dh_addresses:
4977                dh = None
4978                dh_addresses = None
4979                del parent._host_info[host]
4980
4981            self.rpc.remove(address)
4982            del self.stream_comms[address]
4983            del parent._aliases[ws._name]
4984            parent._idle.pop(ws._address, None)
4985            parent._saturated.discard(ws)
4986            del parent._workers[address]
4987            ws.status = Status.closed
4988            parent._running.discard(ws)
4989            parent._total_occupancy -= ws._occupancy
4990
4991            recommendations: dict = {}
4992
4993            ts: TaskState
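            # Tasks this worker was processing are recommended back to
            # released, unless a task has now failed on more than
            # ``allowed_failures`` workers, in which case it is transitioned
            # to erred with a KilledWorker exception.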
4994            for ts in list(ws._processing):
4995                k = ts._key
4996                recommendations[k] = "released"
4997                if not safe:
4998                    ts._suspicious += 1
4999                    ts._prefix._suspicious += 1
5000                    if ts._suspicious > self.allowed_failures:
5001                        del recommendations[k]
5002                        e = pickle.dumps(
5003                            KilledWorker(task=k, last_worker=ws.clean()), protocol=4
5004                        )
5005                        r = self.transition(k, "erred", exception=e, cause=k)
5006                        recommendations.update(r)
5007                        logger.info(
5008                            "Task %s marked as failed because %d workers died"
5009                            " while trying to run it",
5010                            ts._key,
5011                            self.allowed_failures,
5012                        )
5013
5014            for ts in list(ws._has_what):
5015                parent.remove_replica(ts, ws)
5016                if not ts._who_has:
5017                    if ts._run_spec:
5018                        recommendations[ts._key] = "released"
5019                    else:  # pure data
5020                        recommendations[ts._key] = "forgotten"
5021
5022            self.transitions(recommendations)
5023
5024            for plugin in list(self.plugins.values()):
5025                try:
5026                    result = plugin.remove_worker(scheduler=self, worker=address)
5027                    if inspect.isawaitable(result):
5028                        await result
5029                except Exception as e:
5030                    logger.exception(e)
5031
5032            if not parent._workers_dv:
5033                logger.info("Lost all workers")
5034
5035            for w in parent._workers_dv:
5036                self.bandwidth_workers.pop((address, w), None)
5037                self.bandwidth_workers.pop((w, address), None)
5038
5039            def remove_worker_from_events():
5040                # If the worker isn't registered anymore after the delay, remove from events
5041                if address not in parent._workers_dv and address in self.events:
5042                    del self.events[address]
5043
5044            cleanup_delay = parse_timedelta(
5045                dask.config.get("distributed.scheduler.events-cleanup-delay")
5046            )
5047            self.loop.call_later(cleanup_delay, remove_worker_from_events)
5048            logger.debug("Removed worker %s", ws)
5049
5050        return "OK"
5051
5052    def stimulus_cancel(self, comm, keys=None, client=None, force=False):
5053        """Stop execution on a list of keys"""
5054        logger.info("Client %s requests to cancel %d keys", client, len(keys))
5055        if client:
5056            self.log_event(
5057                client, {"action": "cancel", "count": len(keys), "force": force}
5058            )
5059        for key in keys:
5060            self.cancel_key(key, client, force=force)
5061
5062    def cancel_key(self, key, client, retries=5, force=False):
5063        """Cancel a particular key and all dependents"""
5064        # TODO: this should be converted to use the transition mechanism
5065        parent: SchedulerState = cast(SchedulerState, self)
5066        ts: TaskState = parent._tasks.get(key)
5067        dts: TaskState
5068        try:
5069            cs: ClientState = parent._clients[client]
5070        except KeyError:
5071            return
        if ts is None or not ts._who_wants:  # no key yet, let's try again in a moment
5073            if retries:
5074                self.loop.call_later(
5075                    0.2, lambda: self.cancel_key(key, client, retries - 1)
5076                )
5077            return
5078        if force or ts._who_wants == {cs}:  # no one else wants this key
5079            for dts in list(ts._dependents):
5080                self.cancel_key(dts._key, client, force=force)
5081        logger.info("Scheduler cancels key %s.  Force=%s", key, force)
5082        self.report({"op": "cancelled-key", "key": key})
5083        clients = list(ts._who_wants) if force else [cs]
5084        for cs in clients:
5085            self.client_releases_keys(keys=[key], client=cs._client_key)
5086
5087    def client_desires_keys(self, keys=None, client=None):
5088        parent: SchedulerState = cast(SchedulerState, self)
5089        cs: ClientState = parent._clients.get(client)
5090        if cs is None:
5091            # For publish, queues etc.
5092            parent._clients[client] = cs = ClientState(client)
5093        ts: TaskState
5094        for k in keys:
5095            ts = parent._tasks.get(k)
5096            if ts is None:
5097                # For publish, queues etc.
5098                ts = parent.new_task(k, None, "released")
5099            ts._who_wants.add(cs)
5100            cs._wants_what.add(ts)
5101
5102            if ts._state in ("memory", "erred"):
5103                self.report_on_key(ts=ts, client=client)
5104
5105    def client_releases_keys(self, keys=None, client=None):
5106        """Remove keys from client desired list"""
5107
5108        parent: SchedulerState = cast(SchedulerState, self)
5109        if not isinstance(keys, list):
5110            keys = list(keys)
5111        cs: ClientState = parent._clients[client]
5112        recommendations: dict = {}
5113
5114        _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations)
5115        self.transitions(recommendations)
5116
5117    def client_heartbeat(self, client=None):
5118        """Handle heartbeats from Client"""
5119        parent: SchedulerState = cast(SchedulerState, self)
5120        cs: ClientState = parent._clients[client]
5121        cs._last_seen = time()
5122
5123    ###################
5124    # Task Validation #
5125    ###################
5126
5127    def validate_released(self, key):
5128        parent: SchedulerState = cast(SchedulerState, self)
5129        ts: TaskState = parent._tasks[key]
5130        dts: TaskState
5131        assert ts._state == "released"
5132        assert not ts._waiters
5133        assert not ts._waiting_on
5134        assert not ts._who_has
5135        assert not ts._processing_on
        assert not any(ts in dts._waiters for dts in ts._dependencies)
5137        assert ts not in parent._unrunnable
5138
5139    def validate_waiting(self, key):
5140        parent: SchedulerState = cast(SchedulerState, self)
5141        ts: TaskState = parent._tasks[key]
5142        dts: TaskState
5143        assert ts._waiting_on
5144        assert not ts._who_has
5145        assert not ts._processing_on
5146        assert ts not in parent._unrunnable
5147        for dts in ts._dependencies:
5148            # We are waiting on a dependency iff it's not stored
5149            assert bool(dts._who_has) != (dts in ts._waiting_on)
5150            assert ts in dts._waiters  # XXX even if dts._who_has?
5151
5152    def validate_processing(self, key):
5153        parent: SchedulerState = cast(SchedulerState, self)
5154        ts: TaskState = parent._tasks[key]
5155        dts: TaskState
5156        assert not ts._waiting_on
5157        ws: WorkerState = ts._processing_on
5158        assert ws
5159        assert ts in ws._processing
5160        assert not ts._who_has
5161        for dts in ts._dependencies:
5162            assert dts._who_has
5163            assert ts in dts._waiters
5164
5165    def validate_memory(self, key):
5166        parent: SchedulerState = cast(SchedulerState, self)
5167        ts: TaskState = parent._tasks[key]
5168        dts: TaskState
5169        assert ts._who_has
5170        assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1)
5171        assert not ts._processing_on
5172        assert not ts._waiting_on
5173        assert ts not in parent._unrunnable
5174        for dts in ts._dependents:
5175            assert (dts in ts._waiters) == (dts._state in ("waiting", "processing"))
5176            assert ts not in dts._waiting_on
5177
5178    def validate_no_worker(self, key):
5179        parent: SchedulerState = cast(SchedulerState, self)
5180        ts: TaskState = parent._tasks[key]
5181        dts: TaskState
        assert ts in parent._unrunnable
        assert not ts._waiting_on
5185        assert not ts._processing_on
5186        assert not ts._who_has
5187        for dts in ts._dependencies:
5188            assert dts._who_has
5189
5190    def validate_erred(self, key):
5191        parent: SchedulerState = cast(SchedulerState, self)
5192        ts: TaskState = parent._tasks[key]
5193        assert ts._exception_blame
5194        assert not ts._who_has
5195
5196    def validate_key(self, key, ts: TaskState = None):
5197        parent: SchedulerState = cast(SchedulerState, self)
5198        try:
5199            if ts is None:
5200                ts = parent._tasks.get(key)
5201            if ts is None:
5202                logger.debug("Key lost: %s", key)
5203            else:
5204                ts.validate()
5205                try:
5206                    func = getattr(self, "validate_" + ts._state.replace("-", "_"))
5207                except AttributeError:
5208                    logger.error(
5209                        "self.validate_%s not found", ts._state.replace("-", "_")
5210                    )
5211                else:
5212                    func(key)
5213        except Exception as e:
5214            logger.exception(e)
5215            if LOG_PDB:
5216                import pdb
5217
5218                pdb.set_trace()
5219            raise
5220
5221    def validate_state(self, allow_overlap=False):
5222        parent: SchedulerState = cast(SchedulerState, self)
5223        validate_state(parent._tasks, parent._workers, parent._clients)
5224
5225        if not (set(parent._workers_dv) == set(self.stream_comms)):
5226            raise ValueError("Workers not the same in all collections")
5227
5228        ws: WorkerState
5229        for w, ws in parent._workers_dv.items():
5230            assert isinstance(w, str), (type(w), w)
5231            assert isinstance(ws, WorkerState), (type(ws), ws)
5232            assert ws._address == w
5233            if not ws._processing:
5234                assert not ws._occupancy
5235                assert ws._address in parent._idle_dv
5236            assert (ws._status == Status.running) == (ws in parent._running)
5237
5238        for ws in parent._running:
5239            assert ws._status == Status.running
5240            assert ws._address in parent._workers_dv
5241
5242        ts: TaskState
5243        for k, ts in parent._tasks.items():
5244            assert isinstance(ts, TaskState), (type(ts), ts)
5245            assert ts._key == k
5246            assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1)
5247            self.validate_key(k, ts)
5248
5249        for ts in parent._replicated_tasks:
5250            assert ts._state == "memory"
5251            assert ts._key in parent._tasks
5252
5253        c: str
5254        cs: ClientState
5255        for c, cs in parent._clients.items():
5256            # client=None is often used in tests...
5257            assert c is None or type(c) == str, (type(c), c)
5258            assert type(cs) == ClientState, (type(cs), cs)
5259            assert cs._client_key == c
5260
5261        a = {w: ws._nbytes for w, ws in parent._workers_dv.items()}
5262        b = {
5263            w: sum(ts.get_nbytes() for ts in ws._has_what)
5264            for w, ws in parent._workers_dv.items()
5265        }
5266        assert a == b, (a, b)
5267
5268        actual_total_occupancy = 0
5269        for worker, ws in parent._workers_dv.items():
5270            assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8
5271            actual_total_occupancy += ws._occupancy
5272
5273        assert abs(actual_total_occupancy - parent._total_occupancy) < 1e-8, (
5274            actual_total_occupancy,
5275            parent._total_occupancy,
5276        )
5277
5278    ###################
5279    # Manage Messages #
5280    ###################
5281
5282    def report(self, msg: dict, ts: TaskState = None, client: str = None):
5283        """
5284        Publish updates to all listening Queues and Comms
5285
5286        If the message contains a key then we only send the message to those
5287        comms that care about the key.
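
        Examples
        --------
        An illustrative sketch, matching the ``cancelled-key`` message that
        ``cancel_key`` sends through this method::

            scheduler.report({"op": "cancelled-key", "key": "x"})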
5288        """
5289        parent: SchedulerState = cast(SchedulerState, self)
5290        if ts is None:
5291            msg_key = msg.get("key")
5292            if msg_key is not None:
5293                tasks: dict = parent._tasks
5294                ts = tasks.get(msg_key)
5295
5296        cs: ClientState
5297        client_comms: dict = self.client_comms
5298        client_keys: list
5299        if ts is None:
5300            # Notify all clients
5301            client_keys = list(client_comms)
5302        elif client is None:
5303            # Notify clients interested in key
5304            client_keys = [cs._client_key for cs in ts._who_wants]
5305        else:
5306            # Notify clients interested in key (including `client`)
5307            client_keys = [
5308                cs._client_key for cs in ts._who_wants if cs._client_key != client
5309            ]
5310            client_keys.append(client)
5311
5312        k: str
5313        for k in client_keys:
5314            c = client_comms.get(k)
5315            if c is None:
5316                continue
5317            try:
5318                c.send(msg)
5319                # logger.debug("Scheduler sends message to client %s", msg)
5320            except CommClosedError:
5321                if self.status == Status.running:
5322                    logger.critical(
5323                        "Closed comm %r while trying to write %s", c, msg, exc_info=True
5324                    )
5325
5326    async def add_client(self, comm, client=None, versions=None):
5327        """Add client to network
5328
5329        We listen to all future messages from this Comm.
5330        """
5331        parent: SchedulerState = cast(SchedulerState, self)
5332        assert client is not None
5333        comm.name = "Scheduler->Client"
5334        logger.info("Receive client connection: %s", client)
5335        self.log_event(["all", client], {"action": "add-client", "client": client})
5336        parent._clients[client] = ClientState(client, versions=versions)
5337
5338        for plugin in list(self.plugins.values()):
5339            try:
5340                plugin.add_client(scheduler=self, client=client)
5341            except Exception as e:
5342                logger.exception(e)
5343
5344        try:
5345            bcomm = BatchedSend(interval="2ms", loop=self.loop)
5346            bcomm.start(comm)
5347            self.client_comms[client] = bcomm
5348            msg = {"op": "stream-start"}
5349            ws: WorkerState
5350            version_warning = version_module.error_message(
5351                version_module.get_versions(),
5352                {w: ws._versions for w, ws in parent._workers_dv.items()},
5353                versions,
5354            )
5355            msg.update(version_warning)
5356            bcomm.send(msg)
5357
5358            try:
5359                await self.handle_stream(comm=comm, extra={"client": client})
5360            finally:
5361                self.remove_client(client=client)
5362                logger.debug("Finished handling client %s", client)
5363        finally:
5364            if not comm.closed():
5365                self.client_comms[client].send({"op": "stream-closed"})
5366            try:
5367                if not sys.is_finalizing():
5368                    await self.client_comms[client].close()
5369                    del self.client_comms[client]
5370                    if self.status == Status.running:
5371                        logger.info("Close client connection: %s", client)
5372            except TypeError:  # comm becomes None during GC
5373                pass
5374
5375    def remove_client(self, client=None):
5376        """Remove client from network"""
5377        parent: SchedulerState = cast(SchedulerState, self)
5378        if self.status == Status.running:
5379            logger.info("Remove client %s", client)
5380        self.log_event(["all", client], {"action": "remove-client", "client": client})
5381        try:
5382            cs: ClientState = parent._clients[client]
5383        except KeyError:
5384            # XXX is this a legitimate condition?
5385            pass
5386        else:
5387            ts: TaskState
5388            self.client_releases_keys(
5389                keys=[ts._key for ts in cs._wants_what], client=cs._client_key
5390            )
5391            del parent._clients[client]
5392
5393            for plugin in list(self.plugins.values()):
5394                try:
5395                    plugin.remove_client(scheduler=self, client=client)
5396                except Exception as e:
5397                    logger.exception(e)
5398
5399        def remove_client_from_events():
5400            # If the client isn't registered anymore after the delay, remove from events
5401            if client not in parent._clients and client in self.events:
5402                del self.events[client]
5403
5404        cleanup_delay = parse_timedelta(
5405            dask.config.get("distributed.scheduler.events-cleanup-delay")
5406        )
5407        self.loop.call_later(cleanup_delay, remove_client_from_events)
5408
5409    def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1):
5410        """Send a single computational task to a worker"""
5411        parent: SchedulerState = cast(SchedulerState, self)
5412        try:
5413            msg: dict = _task_to_msg(parent, ts, duration)
5414            self.worker_send(worker, msg)
5415        except Exception as e:
5416            logger.exception(e)
5417            if LOG_PDB:
5418                import pdb
5419
5420                pdb.set_trace()
5421            raise
5422
5423    def handle_uncaught_error(self, **msg):
5424        logger.exception(clean_exception(**msg)[1])
5425
5426    def handle_task_finished(self, key=None, worker=None, **msg):
5427        parent: SchedulerState = cast(SchedulerState, self)
5428        if worker not in parent._workers_dv:
5429            return
5430        validate_key(key)
5431
5432        recommendations: dict
5433        client_msgs: dict
5434        worker_msgs: dict
5435
5436        r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg)
5437        recommendations, client_msgs, worker_msgs = r
5438        parent._transitions(recommendations, client_msgs, worker_msgs)
5439
5440        self.send_all(client_msgs, worker_msgs)
5441
5442    def handle_task_erred(self, key=None, **msg):
5443        parent: SchedulerState = cast(SchedulerState, self)
5444        recommendations: dict
5445        client_msgs: dict
5446        worker_msgs: dict
5447        r: tuple = self.stimulus_task_erred(key=key, **msg)
5448        recommendations, client_msgs, worker_msgs = r
5449        parent._transitions(recommendations, client_msgs, worker_msgs)
5450
5451        self.send_all(client_msgs, worker_msgs)
5452
5453    def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
5454        parent: SchedulerState = cast(SchedulerState, self)
5455        logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
5456        self.log.append(("missing", key, errant_worker))
5457
5458        ts: TaskState = parent._tasks.get(key)
5459        if ts is None:
5460            return
5461        ws: WorkerState = parent._workers_dv.get(errant_worker)
5462        if ws is not None and ws in ts._who_has:
5463            parent.remove_replica(ts, ws)
5464        if not ts._who_has:
5465            if ts._run_spec:
5466                self.transitions({key: "released"})
5467            else:
5468                self.transitions({key: "forgotten"})
5469
5470    def release_worker_data(self, comm=None, key=None, worker=None):
5471        parent: SchedulerState = cast(SchedulerState, self)
5472        ws: WorkerState = parent._workers_dv.get(worker)
5473        ts: TaskState = parent._tasks.get(key)
5474        if not ws or not ts:
5475            return
5476        recommendations: dict = {}
5477        if ws in ts._who_has:
5478            parent.remove_replica(ts, ws)
5479            if not ts._who_has:
5480                recommendations[ts._key] = "released"
5481        if recommendations:
5482            self.transitions(recommendations)
5483
5484    def handle_long_running(self, key=None, worker=None, compute_duration=None):
5485        """A task has seceded from the thread pool
5486
5487        We stop the task from being stolen in the future, and change task
5488        duration accounting as if the task has stopped.
5489        """
5490        parent: SchedulerState = cast(SchedulerState, self)
5491        if key not in parent._tasks:
5492            logger.debug("Skipping long_running since key %s was already released", key)
5493            return
5494        ts: TaskState = parent._tasks[key]
5495        steal = parent._extensions.get("stealing")
5496        if steal is not None:
5497            steal.remove_key_from_stealable(ts)
5498
5499        ws: WorkerState = ts._processing_on
5500        if ws is None:
5501            logger.debug("Received long-running signal from duplicate task. Ignoring.")
5502            return
5503
5504        if compute_duration:
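            # Fold the reported compute time into the prefix's running duration
            # average: taken verbatim for the first sample, otherwise blended
            # with the previous estimate at equal weight.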
5505            old_duration: double = ts._prefix._duration_average
5506            new_duration: double = compute_duration
5507            avg_duration: double
5508            if old_duration < 0:
5509                avg_duration = new_duration
5510            else:
5511                avg_duration = 0.5 * old_duration + 0.5 * new_duration
5512
5513            ts._prefix._duration_average = avg_duration
5514
5515        occ: double = ws._processing[ts]
5516        ws._occupancy -= occ
5517        parent._total_occupancy -= occ
5518        ws._processing[ts] = 0
5519        self.check_idle_saturated(ws)
5520
5521    def handle_worker_status_change(self, status: str, worker: str) -> None:
5522        parent: SchedulerState = cast(SchedulerState, self)
5523        ws: WorkerState = parent._workers_dv.get(worker)  # type: ignore
5524        if not ws:
5525            return
5526        prev_status = ws._status
5527        ws._status = Status.lookup[status]  # type: ignore
5528        if ws._status == prev_status:
5529            return
5530
5531        self.log_event(
5532            ws._address,
5533            {
5534                "action": "worker-status-change",
5535                "prev-status": prev_status.name,
5536                "status": status,
5537            },
5538        )
5539
5540        if ws._status == Status.running:
5541            parent._running.add(ws)
5542
5543            recs = {}
5544            ts: TaskState
5545            for ts in parent._unrunnable:
5546                valid: set = self.valid_workers(ts)
5547                if valid is None or ws in valid:
5548                    recs[ts._key] = "waiting"
5549            if recs:
5550                client_msgs: dict = {}
5551                worker_msgs: dict = {}
5552                parent._transitions(recs, client_msgs, worker_msgs)
5553                self.send_all(client_msgs, worker_msgs)
5554
5555        else:
5556            parent._running.discard(ws)
5557
5558    async def handle_worker(self, comm=None, worker=None):
5559        """
5560        Listen to responses from a single worker
5561
5562        This is the main loop for scheduler-worker interaction
5563
5564        See Also
5565        --------
5566        Scheduler.handle_client: Equivalent coroutine for clients
5567        """
5568        comm.name = "Scheduler connection to worker"
5569        worker_comm = self.stream_comms[worker]
5570        worker_comm.start(comm)
5571        logger.info("Starting worker compute stream, %s", worker)
5572        try:
5573            await self.handle_stream(comm=comm, extra={"worker": worker})
5574        finally:
5575            if worker in self.stream_comms:
5576                worker_comm.abort()
5577                await self.remove_worker(address=worker)
5578
5579    def add_plugin(
5580        self,
5581        plugin: SchedulerPlugin,
5582        *,
5583        idempotent: bool = False,
5584        name: "str | None" = None,
5585        **kwargs,
5586    ):
5587        """Add external plugin to scheduler.
5588
5589        See https://distributed.readthedocs.io/en/latest/plugins.html
5590
5591        Parameters
5592        ----------
5593        plugin : SchedulerPlugin
5594            SchedulerPlugin instance to add
        idempotent : bool
            If true and a plugin with the same name is already registered,
            no action is taken.
        name : str
            A name for the plugin; if None, the ``name`` attribute is
            checked on the plugin instance and generated if not discovered.
        **kwargs
            Deprecated; additional arguments passed to the `plugin` class if it
            is not already an instance.
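
        Examples
        --------
        A minimal sketch, assuming ``MyPlugin`` is a user-defined
        ``SchedulerPlugin`` subclass (not part of this module)::

            plugin = MyPlugin()
            scheduler.add_plugin(plugin, name="my-plugin")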
5605        """
5606        if isinstance(plugin, type):
5607            warnings.warn(
5608                "Adding plugins by class is deprecated and will be disabled in a "
5609                "future release. Please add plugins by instance instead.",
5610                category=FutureWarning,
5611            )
5612            plugin = plugin(self, **kwargs)  # type: ignore
5613        elif kwargs:
5614            raise ValueError("kwargs provided but plugin is already an instance")
5615
5616        if name is None:
5617            name = _get_plugin_name(plugin)
5618
5619        if name in self.plugins:
5620            if idempotent:
5621                return
5622            warnings.warn(
5623                f"Scheduler already contains a plugin with name {name}; overwriting.",
5624                category=UserWarning,
5625            )
5626
5627        self.plugins[name] = plugin
5628
5629    def remove_plugin(
5630        self,
5631        name: "str | None" = None,
5632        plugin: "SchedulerPlugin | None" = None,
5633    ) -> None:
5634        """Remove external plugin from scheduler
5635
5636        Parameters
5637        ----------
5638        name : str
5639            Name of the plugin to remove
        plugin : SchedulerPlugin
            Deprecated; use the `name` argument instead. Instance of a
            SchedulerPlugin class to remove.
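
        Examples
        --------
        Removing the hypothetical plugin registered in the ``add_plugin``
        example:

        >>> scheduler.remove_plugin(name="my-plugin")  # doctest: +SKIP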
5643        """
5644        # TODO: Remove this block of code once removing plugins by value is disabled
5645        if bool(name) == bool(plugin):
5646            raise ValueError("Must provide plugin or name (mutually exclusive)")
5647        if isinstance(name, SchedulerPlugin):
5648            # Backwards compatibility - the sig used to be (plugin, name)
5649            plugin = name
5650            name = None
5651        if plugin is not None:
5652            warnings.warn(
5653                "Removing scheduler plugins by value is deprecated and will be disabled "
5654                "in a future release. Please remove scheduler plugins by name instead.",
5655                category=FutureWarning,
5656            )
5657            if hasattr(plugin, "name"):
5658                name = plugin.name  # type: ignore
5659            else:
5660                names = [k for k, v in self.plugins.items() if v is plugin]
5661                if not names:
5662                    raise ValueError(
5663                        f"Could not find {plugin} among the current scheduler plugins"
5664                    )
5665                if len(names) > 1:
5666                    raise ValueError(
5667                        f"Multiple instances of {plugin} were found in the current "
5668                        "scheduler plugins; we cannot remove this plugin."
5669                    )
5670                name = names[0]
5671        assert name is not None
5672        # End deprecated code
5673
5674        try:
5675            del self.plugins[name]
5676        except KeyError:
5677            raise ValueError(
5678                f"Could not find plugin {name!r} among the current scheduler plugins"
5679            )
5680
5681    async def register_scheduler_plugin(self, comm=None, plugin=None, name=None):
5682        """Register a plugin on the scheduler."""
5683        if not dask.config.get("distributed.scheduler.pickle"):
5684            raise ValueError(
5685                "Cannot register a scheduler plugin as the scheduler "
5686                "has been explicitly disallowed from deserializing "
5687                "arbitrary bytestrings using pickle via the "
5688                "'distributed.scheduler.pickle' configuration setting."
5689            )
5690        plugin = loads(plugin)
5691
5692        if hasattr(plugin, "start"):
5693            result = plugin.start(self)
5694            if inspect.isawaitable(result):
5695                await result
5696
5697        self.add_plugin(plugin, name=name)
5698
5699    def worker_send(self, worker, msg):
5700        """Send message to worker
5701
5702        This also handles connection failures by adding a callback to remove
5703        the worker on the next cycle.
5704        """
5705        stream_comms: dict = self.stream_comms
5706        try:
5707            stream_comms[worker].send(msg)
5708        except (CommClosedError, AttributeError):
5709            self.loop.add_callback(self.remove_worker, address=worker)
5710
5711    def client_send(self, client, msg):
5712        """Send message to client"""
5713        client_comms: dict = self.client_comms
5714        c = client_comms.get(client)
5715        if c is None:
5716            return
5717        try:
5718            c.send(msg)
5719        except CommClosedError:
5720            if self.status == Status.running:
5721                logger.critical(
5722                    "Closed comm %r while trying to write %s", c, msg, exc_info=True
5723                )
5724
5725    def send_all(self, client_msgs: dict, worker_msgs: dict):
5726        """Send messages to client and workers"""
5727        client_comms: dict = self.client_comms
5728        stream_comms: dict = self.stream_comms
5729        msgs: list
5730
5731        for client, msgs in client_msgs.items():
5732            c = client_comms.get(client)
5733            if c is None:
5734                continue
5735            try:
5736                c.send(*msgs)
5737            except CommClosedError:
5738                if self.status == Status.running:
5739                    logger.critical(
5740                        "Closed comm %r while trying to write %s",
5741                        c,
5742                        msgs,
5743                        exc_info=True,
5744                    )
5745
5746        for worker, msgs in worker_msgs.items():
5747            try:
5748                w = stream_comms[worker]
5749                w.send(*msgs)
5750            except KeyError:
5751                # worker already gone
5752                pass
5753            except (CommClosedError, AttributeError):
5754                self.loop.add_callback(self.remove_worker, address=worker)
5755
5756    ############################
5757    # Less common interactions #
5758    ############################
5759
5760    async def scatter(
5761        self,
5762        comm=None,
5763        data=None,
5764        workers=None,
5765        client=None,
5766        broadcast=False,
5767        timeout=2,
5768    ):
5769        """Send data out to workers
5770
5771        See also
5772        --------
5773        Scheduler.broadcast:
5774        """
5775        parent: SchedulerState = cast(SchedulerState, self)
5776        ws: WorkerState
5777
5778        start = time()
5779        while True:
5780            if workers is None:
5781                wss = parent._running
5782            else:
5783                workers = [self.coerce_address(w) for w in workers]
5784                wss = {parent._workers_dv[w] for w in workers}
5785                wss = {ws for ws in wss if ws._status == Status.running}
5786
5787            if wss:
5788                break
5789            if time() > start + timeout:
5790                raise TimeoutError("No valid workers found")
5791            await asyncio.sleep(0.1)
5792
5793        nthreads = {ws._address: ws.nthreads for ws in wss}
5794
5795        assert isinstance(data, dict)
5796
5797        keys, who_has, nbytes = await scatter_to_workers(
5798            nthreads, data, rpc=self.rpc, report=False
5799        )
5800
5801        self.update_data(who_has=who_has, nbytes=nbytes, client=client)
5802
5803        if broadcast:
5804            n = len(nthreads) if broadcast is True else broadcast
5805            await self.replicate(keys=keys, workers=workers, n=n)
5806
5807        self.log_event(
5808            [client, "all"], {"action": "scatter", "client": client, "count": len(data)}
5809        )
5810        return keys
5811
5812    async def gather(self, comm=None, keys=None, serializers=None):
5813        """Collect data from workers to the scheduler"""
5814        parent: SchedulerState = cast(SchedulerState, self)
5815        ws: WorkerState
5816        keys = list(keys)
5817        who_has = {}
5818        for key in keys:
5819            ts: TaskState = parent._tasks.get(key)
5820            if ts is not None:
5821                who_has[key] = [ws._address for ws in ts._who_has]
5822            else:
5823                who_has[key] = []
5824
5825        data, missing_keys, missing_workers = await gather_from_workers(
5826            who_has, rpc=self.rpc, close=False, serializers=serializers
5827        )
5828        if not missing_keys:
5829            result = {"status": "OK", "data": data}
5830        else:
5831            missing_states = [
5832                (parent._tasks[key].state if key in parent._tasks else None)
5833                for key in missing_keys
5834            ]
5835            logger.exception(
5836                "Couldn't gather keys %s state: %s workers: %s",
5837                missing_keys,
5838                missing_states,
5839                missing_workers,
5840            )
5841            result = {"status": "error", "keys": missing_keys}
5842            with log_errors():
5843                # Remove suspicious workers from the scheduler but allow them to
5844                # reconnect.
5845                await asyncio.gather(
5846                    *(
5847                        self.remove_worker(address=worker, close=False)
5848                        for worker in missing_workers
5849                    )
5850                )
5851                recommendations: dict
5852                client_msgs: dict = {}
5853                worker_msgs: dict = {}
5854                for key, workers in missing_keys.items():
5855                    # Task may already be gone if it was held by a
5856                    # `missing_worker`
5857                    ts: TaskState = parent._tasks.get(key)
5858                    logger.exception(
5859                        "Workers don't have promised key: %s, %s",
5860                        str(workers),
5861                        str(key),
5862                    )
5863                    if not workers or ts is None:
5864                        continue
5865                    recommendations: dict = {key: "released"}
5866                    for worker in workers:
5867                        ws = parent._workers_dv.get(worker)
5868                        if ws is not None and ws in ts._who_has:
5869                            parent.remove_replica(ts, ws)
5870                            parent._transitions(
5871                                recommendations, client_msgs, worker_msgs
5872                            )
5873                self.send_all(client_msgs, worker_msgs)
5874
5875        self.log_event("all", {"action": "gather", "count": len(keys)})
5876        return result
5877
5878    def clear_task_state(self):
5879        # XXX what about nested state such as ClientState.wants_what
5880        # (see also fire-and-forget...)
5881        logger.info("Clear task state")
5882        for collection in self._task_state_collections:
5883            collection.clear()
5884
5885    async def restart(self, client=None, timeout=30):
5886        """Restart all workers. Reset local state."""
5887        parent: SchedulerState = cast(SchedulerState, self)
5888        with log_errors():
5889
5890            n_workers = len(parent._workers_dv)
5891
5892            logger.info("Send lost future signal to clients")
5893            cs: ClientState
5894            ts: TaskState
5895            for cs in parent._clients.values():
5896                self.client_releases_keys(
5897                    keys=[ts._key for ts in cs._wants_what], client=cs._client_key
5898                )
5899
5900            ws: WorkerState
5901            nannies = {addr: ws._nanny for addr, ws in parent._workers_dv.items()}
5902
5903            for addr in list(parent._workers_dv):
5904                try:
5905                    # Ask the worker to close if it doesn't have a nanny,
5906                    # otherwise the nanny will kill it anyway
5907                    await self.remove_worker(address=addr, close=addr not in nannies)
5908                except Exception:
5909                    logger.info(
5910                        "Exception while restarting.  This is normal", exc_info=True
5911                    )
5912
5913            self.clear_task_state()
5914
5915            for plugin in list(self.plugins.values()):
5916                try:
5917                    plugin.restart(self)
5918                except Exception as e:
5919                    logger.exception(e)
5920
5921            logger.debug("Send kill signal to nannies: %s", nannies)
5922
5923            nannies = [
5924                rpc(nanny_address, connection_args=self.connection_args)
5925                for nanny_address in nannies.values()
5926                if nanny_address is not None
5927            ]
5928
5929            resps = All(
5930                [
5931                    nanny.restart(
5932                        close=True, timeout=timeout * 0.8, executor_wait=False
5933                    )
5934                    for nanny in nannies
5935                ]
5936            )
5937            try:
5938                resps = await asyncio.wait_for(resps, timeout)
5939            except TimeoutError:
5940                logger.error(
5941                    "Nannies didn't report back restarted within "
                    "timeout.  Continuing with restart process"
5943                )
5944            else:
5945                if not all(resp == "OK" for resp in resps):
5946                    logger.error(
5947                        "Not all workers responded positively: %s", resps, exc_info=True
5948                    )
5949            finally:
5950                await asyncio.gather(*[nanny.close_rpc() for nanny in nannies])
5951
5952            self.clear_task_state()
5953
5954            with suppress(AttributeError):
5955                for c in self._worker_coroutines:
5956                    c.cancel()
5957
5958            self.log_event([client, "all"], {"action": "restart", "client": client})
5959            start = time()
5960            while time() < start + 10 and len(parent._workers_dv) < n_workers:
5961                await asyncio.sleep(0.01)
5962
5963            self.report({"op": "restart"})
5964
5965    async def broadcast(
5966        self,
5967        comm=None,
5968        msg=None,
5969        workers=None,
5970        hosts=None,
5971        nanny=False,
5972        serializers=None,
5973    ):
5974        """Broadcast message to workers, return all results"""
5975        parent: SchedulerState = cast(SchedulerState, self)
5976        if workers is None or workers is True:
5977            if hosts is None:
5978                workers = list(parent._workers_dv)
5979            else:
5980                workers = []
5981        if hosts is not None:
5982            for host in hosts:
5983                dh: dict = parent._host_info.get(host)
5984                if dh is not None:
5985                    workers.extend(dh["addresses"])
5986        # TODO replace with worker_list
5987
5988        if nanny:
5989            addresses = [parent._workers_dv[w].nanny for w in workers]
5990        else:
5991            addresses = workers
5992
5993        async def send_message(addr):
5994            comm = await self.rpc.connect(addr)
5995            comm.name = "Scheduler Broadcast"
5996            try:
5997                resp = await send_recv(comm, close=True, serializers=serializers, **msg)
5998            finally:
5999                self.rpc.reuse(addr, comm)
6000            return resp
6001
6002        results = await All(
6003            [send_message(address) for address in addresses if address is not None]
6004        )
6005
6006        return dict(zip(workers, results))
6007
6008    async def proxy(self, comm=None, msg=None, worker=None, serializers=None):
6009        """Proxy a communication through the scheduler to some other worker"""
6010        d = await self.broadcast(
6011            comm=comm, msg=msg, workers=[worker], serializers=serializers
6012        )
6013        return d[worker]
6014
6015    async def gather_on_worker(
6016        self, worker_address: str, who_has: "dict[str, list[str]]"
6017    ) -> set:
6018        """Peer-to-peer copy of keys from multiple workers to a single worker
6019
6020        Parameters
6021        ----------
6022        worker_address: str
6023            Recipient worker address to copy keys to
        who_has: dict[str, list[str]]
            {key: [sender address, sender address, ...], key: ...}

        Returns
        -------
        set
            Set of keys that failed to be copied
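
        Examples
        --------
        An illustrative sketch of the expected ``who_has`` shape (addresses
        here are hypothetical)::

            await scheduler.gather_on_worker(
                "tcp://10.0.0.1:40000",
                {"x": ["tcp://10.0.0.2:40001", "tcp://10.0.0.3:40002"]},
            )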
6031        """
6032        try:
6033            result = await retry_operation(
6034                self.rpc(addr=worker_address).gather, who_has=who_has
6035            )
6036        except OSError as e:
6037            # This can happen e.g. if the worker is going through controlled shutdown;
6038            # it doesn't necessarily mean that it went unexpectedly missing
6039            logger.warning(
6040                f"Communication with worker {worker_address} failed during "
6041                f"replication: {e.__class__.__name__}: {e}"
6042            )
6043            return set(who_has)
6044
6045        parent: SchedulerState = cast(SchedulerState, self)
6046        ws: WorkerState = parent._workers_dv.get(worker_address)  # type: ignore
6047
6048        if ws is None:
6049            logger.warning(f"Worker {worker_address} lost during replication")
6050            return set(who_has)
6051        elif result["status"] == "OK":
6052            keys_failed = set()
6053            keys_ok: Set = who_has.keys()
6054        elif result["status"] == "partial-fail":
6055            keys_failed = set(result["keys"])
6056            keys_ok = who_has.keys() - keys_failed
6057            logger.warning(
6058                f"Worker {worker_address} failed to acquire keys: {result['keys']}"
6059            )
6060        else:  # pragma: nocover
6061            raise ValueError(f"Unexpected message from {worker_address}: {result}")
6062
6063        for key in keys_ok:
6064            ts: TaskState = parent._tasks.get(key)  # type: ignore
6065            if ts is None or ts._state != "memory":
6066                logger.warning(f"Key lost during replication: {key}")
6067                continue
6068            if ws not in ts._who_has:
6069                parent.add_replica(ts, ws)
6070
6071        return keys_failed
6072
6073    async def delete_worker_data(
6074        self, worker_address: str, keys: "Collection[str]"
6075    ) -> None:
6076        """Delete data from a worker and update the corresponding worker/task states
6077
6078        Parameters
6079        ----------
6080        worker_address: str
6081            Worker address to delete keys from
        keys: Collection[str]
            Keys to delete on the specified worker
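
        Examples
        --------
        An illustrative sketch (worker address and keys are hypothetical)::

            await scheduler.delete_worker_data("tcp://10.0.0.1:40000", ["x", "y"])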
6084        """
6085        parent: SchedulerState = cast(SchedulerState, self)
6086
6087        try:
6088            await retry_operation(
6089                self.rpc(addr=worker_address).free_keys,
6090                keys=list(keys),
6091                stimulus_id=f"delete-data-{time()}",
6092            )
6093        except OSError as e:
6094            # This can happen e.g. if the worker is going through controlled shutdown;
6095            # it doesn't necessarily mean that it went unexpectedly missing
6096            logger.warning(
6097                f"Communication with worker {worker_address} failed during "
6098                f"replication: {e.__class__.__name__}: {e}"
6099            )
6100            return
6101
6102        ws: WorkerState = parent._workers_dv.get(worker_address)  # type: ignore
6103        if ws is None:
6104            return
6105
6106        for key in keys:
6107            ts: TaskState = parent._tasks.get(key)  # type: ignore
6108            if ts is not None and ws in ts._who_has:
6109                assert ts._state == "memory"
6110                parent.remove_replica(ts, ws)
6111                if not ts._who_has:
6112                    # Last copy deleted
6113                    self.transitions({key: "released"})
6114
6115        self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys})
6116
6117    async def rebalance(
6118        self,
6119        comm=None,
6120        keys: "Iterable[Hashable]" = None,
6121        workers: "Iterable[str]" = None,
6122    ) -> dict:
6123        """Rebalance keys so that each worker ends up with roughly the same process
6124        memory (managed+unmanaged).
6125
6126        .. warning::
6127           This operation is generally not well tested against normal operation of the
6128           scheduler. It is not recommended to use it while waiting on computations.
6129
6130        **Algorithm**
6131
6132        #. Find the mean occupancy of the cluster, defined as data managed by dask +
6133           unmanaged process memory that has been there for at least 30 seconds
6134           (``distributed.worker.memory.recent-to-old-time``).
6135           This lets us ignore temporary spikes caused by task heap usage.
6136
6137           Alternatively, you may change how memory is measured both for the individual
6138           workers as well as to calculate the mean through
6139           ``distributed.worker.memory.rebalance.measure``. Namely, this can be useful
6140           to disregard inaccurate OS memory measurements.
6141
6142        #. Discard workers whose occupancy is within 5% of the mean cluster occupancy
6143           (``distributed.worker.memory.rebalance.sender-recipient-gap`` / 2).
6144           This helps avoid data from bouncing around the cluster repeatedly.
6145        #. Workers above the mean are senders; those below are recipients.
6146        #. Discard senders whose absolute occupancy is below 30%
6147           (``distributed.worker.memory.rebalance.sender-min``). In other words, no data
           is moved regardless of imbalance as long as all workers are below 30%.
6149        #. Discard recipients whose absolute occupancy is above 60%
6150           (``distributed.worker.memory.rebalance.recipient-max``).
6151           Note that this threshold by default is the same as
6152           ``distributed.worker.memory.target`` to prevent workers from accepting data
6153           and immediately spilling it out to disk.
6154        #. Iteratively pick the sender and recipient that are farthest from the mean and
6155           move the *least recently inserted* key between the two, until either all
6156           senders or all recipients fall within 5% of the mean.
6157
6158           A recipient will be skipped if it already has a copy of the data. In other
6159           words, this method does not degrade replication.
6160           A key will be skipped if there are no recipients available with enough memory
6161           to accept the key and that don't already hold a copy.
6162
6163        The least recently inserted (LRI) policy is a greedy choice: it is O(1), trivial
6164        to implement (it relies on Python dicts preserving insertion order) and hopefully
6165        good enough in most cases. Discarded alternative policies were:
6166
6167        - Largest first. It is O(n*log(n)) unless one maintains non-trivial additional
6168          data structures, and it risks causing the largest chunks of data to repeatedly
6169          move around the cluster like pinballs.
6170        - Least recently used (LRU). This information is currently available on the
6171          workers only and not trivial to replicate on the scheduler; transmitting it
6172          over the network would be very expensive. Also, note that dask will go out of
6173          its way to minimise the amount of time intermediate keys are held in memory,
6174          so in such a case LRI is a close approximation of LRU.
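
        **Example**

        A rough worked illustration of the thresholds above, with made-up numbers
        (not taken from a real cluster): two workers each have a 10 GiB
        ``memory_limit``; one uses 6 GiB and the other 2 GiB.

        - mean occupancy = (6 + 2) / 2 = 4 GiB; the 5% half-gap works out to
          0.5 GiB of the 10 GiB limit
        - the first worker is a sender: 6 >= 4 + 0.5 and 6 >= 3 (30% of 10)
        - the second worker is a recipient: 2 < 4 - 0.5 and 2 < 6 (60% of 10)
        - keys move from the first worker to the second until both fall within
          0.5 GiB of the mean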
6175
6176        Parameters
6177        ----------
6178        keys: optional
6179            whitelist of dask keys that should be considered for moving. All other keys
6180            will be ignored. Note that this offers no guarantee that a key will actually
6181            be moved (e.g. because it is unnecessary or because there are no viable
6182            recipient workers for it).
6183        workers: optional
6184            whitelist of worker addresses to be considered as senders or recipients.
6185            All other workers will be ignored. The mean cluster occupancy will be
6186            calculated only using the whitelisted workers.
6187        """
6188        parent: SchedulerState = cast(SchedulerState, self)
6189
6190        with log_errors():
6191            wss: "Collection[WorkerState]"
6192            if workers is not None:
6193                wss = [parent._workers_dv[w] for w in workers]
6194            else:
6195                wss = parent._workers_dv.values()
6196            if not wss:
6197                return {"status": "OK"}
6198
6199            if keys is not None:
6200                if not isinstance(keys, Set):
6201                    keys = set(keys)  # unless already a set-like
6202                if not keys:
6203                    return {"status": "OK"}
6204                missing_data = [
6205                    k
6206                    for k in keys
6207                    if k not in parent._tasks or not parent._tasks[k].who_has
6208                ]
6209                if missing_data:
6210                    return {"status": "partial-fail", "keys": missing_data}
6211
6212            msgs = self._rebalance_find_msgs(keys, wss)
6213            if not msgs:
6214                return {"status": "OK"}
6215
6216            async with self._lock:
6217                result = await self._rebalance_move_data(msgs)
6218                if result["status"] == "partial-fail" and keys is None:
6219                    # Only return failed keys if the client explicitly asked for them
6220                    result = {"status": "OK"}
6221                return result
6222
6223    def _rebalance_find_msgs(
6224        self,
6225        keys: "Set[Hashable] | None",
6226        workers: "Iterable[WorkerState]",
6227    ) -> "list[tuple[WorkerState, WorkerState, TaskState]]":
6228        """Identify workers that need to lose keys and those that can receive them,
6229        together with how many bytes each needs to lose/receive. Then, pair a sender
6230        worker with a recipient worker for each key, until the cluster is rebalanced.
6231
6232        This method only defines the work to be performed; it does not start any network
6233        transfers itself.
6234
6235        The big-O complexity is O(wt + ke*log(we)), where
6236
6237        - wt is the total number of workers on the cluster (or the number of whitelisted
6238          workers, if explicitly stated by the user)
6239        - we is the number of workers that are eligible to be senders or recipients
6240        - kt is the total number of keys on the cluster (or on the whitelisted workers)
6241        - ke is the number of keys that need to be moved in order to achieve a balanced
6242          cluster
6243
6244        There is a degenerate edge case of O(wt + kt*log(we)) when kt is much greater than
6245        the number of whitelisted keys, or when most keys are replicated or cannot be
6246        moved for some other reason.
6247
6248        Returns list of tuples to feed into _rebalance_move_data:
6249
6250        - sender worker
6251        - recipient worker
6252        - task to be transferred
6253        """
6254        parent: SchedulerState = cast(SchedulerState, self)
6255        ts: TaskState
6256        ws: WorkerState
6257
6258        # Heaps of workers, managed by the heapq module, that need to send/receive data,
6259        # with how many bytes each needs to send/receive.
6260        #
6261        # Each element of the heap is a tuple constructed as follows:
6262        # - snd_bytes_max/rec_bytes_max: maximum number of bytes to send or receive.
6263        #   This number is negative, so that the workers farthest from the cluster mean
6264        #   are at the top of the smallest-first heaps.
6265        # - snd_bytes_min/rec_bytes_min: minimum number of bytes after sending/receiving
6266        #   which the worker should not be considered anymore. This is also negative.
6267        # - arbitrary unique number, there just to make sure that WorkerState objects
6268        #   are never used for sorting in the unlikely event that two processes have
6269        #   exactly the same number of bytes allocated.
6270        # - WorkerState
6271        # - iterator of all tasks in memory on the worker (senders only), insertion
6272        #   sorted (least recently inserted first).
6273        #   Note that this iterator will typically *not* be exhausted. It will only be
6274        #   exhausted if moving away from the worker all of the keys that can be moved
6275        #   is still insufficient to bring snd_bytes_min up to 0.
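        #
        # For illustration only (all numbers are made up): a sender entry might
        # look like
        #     (-2_000_000, -1_500_000, 139912345678, <WorkerState>, <iterator>)
        # meaning the worker holds ~2 MB more than the cluster mean and will be
        # dropped from the heap once all but ~0.5 MB of that excess has been
        # moved away (i.e. once it is back within the half-gap of the mean).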
6276        senders: "list[tuple[int, int, int, WorkerState, Iterator[TaskState]]]" = []
6277        recipients: "list[tuple[int, int, int, WorkerState]]" = []
6278
6279        # Output: [(sender, recipient, task), ...]
6280        msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" = []
6281
6282        # By default, this is the optimistic memory, meaning total process memory minus
6283        # unmanaged memory that appeared over the last 30 seconds
6284        # (distributed.worker.memory.recent-to-old-time).
6285        # This lets us ignore temporary spikes caused by task heap usage.
6286        memory_by_worker = [
6287            (ws, getattr(ws.memory, parent.MEMORY_REBALANCE_MEASURE)) for ws in workers
6288        ]
6289        mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker)
6290
6291        for ws, ws_memory in memory_by_worker:
6292            if ws.memory_limit:
6293                half_gap = int(parent.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit)
6294                sender_min = parent.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit
6295                recipient_max = parent.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit
6296            else:
6297                half_gap = 0
6298                sender_min = 0.0
6299                recipient_max = math.inf
6300
6301            if (
6302                ws._has_what
6303                and ws_memory >= mean_memory + half_gap
6304                and ws_memory >= sender_min
6305            ):
6306                # This may send the worker below sender_min (by design)
6307                snd_bytes_max = mean_memory - ws_memory  # negative
6308                snd_bytes_min = snd_bytes_max + half_gap  # negative
6309                # See definition of senders above
6310                senders.append(
6311                    (snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what))
6312                )
6313            elif ws_memory < mean_memory - half_gap and ws_memory < recipient_max:
6314                # This may send the worker above recipient_max (by design)
6315                rec_bytes_max = ws_memory - mean_memory  # negative
6316                rec_bytes_min = rec_bytes_max + half_gap  # negative
6317                # See definition of recipients above
6318                recipients.append((rec_bytes_max, rec_bytes_min, id(ws), ws))
6319
6320        # Fast exit in case no transfers are necessary or possible
6321        if not senders or not recipients:
6322            self.log_event(
6323                "all",
6324                {
6325                    "action": "rebalance",
6326                    "senders": len(senders),
6327                    "recipients": len(recipients),
6328                    "moved_keys": 0,
6329                },
6330            )
6331            return []
6332
6333        heapq.heapify(senders)
6334        heapq.heapify(recipients)
6335
6336        snd_ws: WorkerState
6337        rec_ws: WorkerState
6338
6339        while senders and recipients:
6340            snd_bytes_max, snd_bytes_min, _, snd_ws, ts_iter = senders[0]
6341
6342            # Iterate through tasks in memory, least recently inserted first
6343            for ts in ts_iter:
6344                if keys is not None and ts.key not in keys:
6345                    continue
6346                nbytes = ts.nbytes
6347                if nbytes + snd_bytes_max > 0:
6348                    # Moving this task would cause the sender to go below mean and
6349                    # potentially risk becoming a recipient, which would cause tasks to
6350                    # bounce around. Move on to the next task of the same sender.
6351                    continue
6352
6353                # Find the recipient, farthest from the mean, which
6354                # 1. has enough available RAM for this task, and
6355                # 2. doesn't hold a copy of this task already
6356                # There may not be any that satisfies these conditions; in this case
6357                # this task won't be moved.
6358                skipped_recipients = []
6359                use_recipient = False
6360                while recipients and not use_recipient:
6361                    rec_bytes_max, rec_bytes_min, _, rec_ws = recipients[0]
6362                    if nbytes + rec_bytes_max > 0:
6363                        # recipients are sorted by rec_bytes_max.
6364                        # The next ones will be worse; no reason to continue iterating
6365                        break
6366                    use_recipient = ts not in rec_ws._has_what
6367                    if not use_recipient:
6368                        skipped_recipients.append(heapq.heappop(recipients))
6369
6370                for recipient in skipped_recipients:
6371                    heapq.heappush(recipients, recipient)
6372
6373                if not use_recipient:
6374                    # This task has no recipients available. Leave it on the sender and
6375                    # move on to the next task of the same sender.
6376                    continue
6377
6378                # Schedule task for transfer from sender to recipient
6379                msgs.append((snd_ws, rec_ws, ts))
6380
6381                # *_bytes_max/min are all negative for heap sorting
6382                snd_bytes_max += nbytes
6383                snd_bytes_min += nbytes
6384                rec_bytes_max += nbytes
6385                rec_bytes_min += nbytes
6386
6387                # Stop iterating on the tasks of this sender for now and, if it still
6388                # has bytes to lose, push it back into the senders heap; it may or may
6389                # not come back on top again.
6390                if snd_bytes_min < 0:
6391                    # See definition of senders above
6392                    heapq.heapreplace(
6393                        senders,
6394                        (snd_bytes_max, snd_bytes_min, id(snd_ws), snd_ws, ts_iter),
6395                    )
6396                else:
6397                    heapq.heappop(senders)
6398
6399                # If recipient still has bytes to gain, push it back into the recipients
6400                # heap; it may or may not come back on top again.
6401                if rec_bytes_min < 0:
6402                    # See definition of recipients above
6403                    heapq.heapreplace(
6404                        recipients,
6405                        (rec_bytes_max, rec_bytes_min, id(rec_ws), rec_ws),
6406                    )
6407                else:
6408                    heapq.heappop(recipients)
6409
6410                # Move to next sender with the most data to lose.
6411                # It may or may not be the same sender again.
6412                break
6413
6414            else:  # for ts in ts_iter
6415                # Exhausted tasks on this sender
6416                heapq.heappop(senders)
6417
6418        return msgs
6419
6420    async def _rebalance_move_data(
6421        self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]"
6422    ) -> dict:
6423        """Perform the actual transfer of data across the network in rebalance().
6424        Takes as input the output of _rebalance_find_msgs(), which is a list of tuples:
6425
6426        - sender worker
6427        - recipient worker
6428        - task to be transferred
6429
6430        FIXME this method is not robust when the cluster is not idle.
6431        """
6432        snd_ws: WorkerState
6433        rec_ws: WorkerState
6434        ts: TaskState
6435
6436        to_recipients: "defaultdict[str, defaultdict[str, list[str]]]" = defaultdict(
6437            lambda: defaultdict(list)
6438        )
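        # Shape of to_recipients (addresses and keys below are illustrative only):
        # {recipient address: {task key: [sender address, ...], ...}, ...}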
6439        for snd_ws, rec_ws, ts in msgs:
6440            to_recipients[rec_ws.address][ts._key].append(snd_ws.address)
6441        failed_keys_by_recipient = dict(
6442            zip(
6443                to_recipients,
6444                await asyncio.gather(
6445                    *(
6446                        # Note: this never raises exceptions
6447                        self.gather_on_worker(w, who_has)
6448                        for w, who_has in to_recipients.items()
6449                    )
6450                ),
6451            )
6452        )
6453
6454        to_senders = defaultdict(list)
6455        for snd_ws, rec_ws, ts in msgs:
6456            if ts._key not in failed_keys_by_recipient[rec_ws.address]:
6457                to_senders[snd_ws.address].append(ts._key)
6458
6459        # Note: this never raises exceptions
6460        await asyncio.gather(
6461            *(self.delete_worker_data(r, v) for r, v in to_senders.items())
6462        )
6463
6464        for r, v in to_recipients.items():
6465            self.log_event(r, {"action": "rebalance", "who_has": v})
6466        self.log_event(
6467            "all",
6468            {
6469                "action": "rebalance",
6470                "senders": valmap(len, to_senders),
6471                "recipients": valmap(len, to_recipients),
6472                "moved_keys": len(msgs),
6473            },
6474        )
6475
6476        missing_keys = {k for r in failed_keys_by_recipient.values() for k in r}
6477        if missing_keys:
6478            return {"status": "partial-fail", "keys": list(missing_keys)}
6479        else:
6480            return {"status": "OK"}
6481
6482    async def replicate(
6483        self,
6484        comm=None,
6485        keys=None,
6486        n=None,
6487        workers=None,
6488        branching_factor=2,
6489        delete=True,
6490        lock=True,
6491    ):
6492        """Replicate data throughout cluster
6493
6494        This performs a tree copy of the data throughout the network,
6495        handled individually for each piece of data.
6496
6497        Parameters
6498        ----------
6499        keys: Iterable
6500            list of keys to replicate
6501        n: int
6502            Number of replications we expect to see within the cluster
6503        branching_factor: int, optional
6504            The number of workers that can copy data in each generation.
6505            The larger the branching factor, the more data we copy in
6506            a single step, but the more a given worker risks being
6507            swamped by data requests.
6508
6509        See also
6510        --------
6511        Scheduler.rebalance
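
        Examples
        --------
        A minimal sketch of typical usage through the client API; ``client`` and
        ``futures`` are assumed to already exist and are not defined here:

        >>> client.replicate(futures, n=2)  # doctest: +SKIP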
6512        """
6513        parent: SchedulerState = cast(SchedulerState, self)
6514        ws: WorkerState
6515        wws: WorkerState
6516        ts: TaskState
6517
6518        assert branching_factor > 0
6519        async with self._lock if lock else empty_context:
6520            if workers is not None:
6521                workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
6522                workers = {ws for ws in workers if ws._status == Status.running}
6523            else:
6524                workers = parent._running
6525
6526            if n is None:
6527                n = len(workers)
6528            else:
6529                n = min(n, len(workers))
6530            if n == 0:
6531                raise ValueError("Can not use replicate to delete data")
6532
6533            tasks = {parent._tasks[k] for k in keys}
6534            missing_data = [ts._key for ts in tasks if not ts._who_has]
6535            if missing_data:
6536                return {"status": "partial-fail", "keys": missing_data}
6537
6538            # Delete extraneous data
6539            if delete:
6540                del_worker_tasks = defaultdict(set)
6541                for ts in tasks:
6542                    del_candidates = tuple(ts._who_has & workers)
6543                    if len(del_candidates) > n:
6544                        for ws in random.sample(
6545                            del_candidates, len(del_candidates) - n
6546                        ):
6547                            del_worker_tasks[ws].add(ts)
6548
6549                # Note: this never raises exceptions
6550                await asyncio.gather(
6551                    *[
6552                        self.delete_worker_data(ws._address, [t.key for t in tasks])
6553                        for ws, tasks in del_worker_tasks.items()
6554                    ]
6555                )
6556
6557            # Copy not-yet-filled data
6558            while tasks:
6559                gathers = defaultdict(dict)
6560                for ts in list(tasks):
6561                    if ts._state == "forgotten":
6562                        # task is no longer needed by any client or dependent task
6563                        tasks.remove(ts)
6564                        continue
6565                    n_missing = n - len(ts._who_has & workers)
6566                    if n_missing <= 0:
6567                        # Already replicated enough
6568                        tasks.remove(ts)
6569                        continue
6570
6571                    count = min(n_missing, branching_factor * len(ts._who_has))
6572                    assert count > 0
6573
6574                    for ws in random.sample(tuple(workers - ts._who_has), count):
6575                        gathers[ws._address][ts._key] = [
6576                            wws._address for wws in ts._who_has
6577                        ]
6578
6579                await asyncio.gather(
6580                    *(
6581                        # Note: this never raises exceptions
6582                        self.gather_on_worker(w, who_has)
6583                        for w, who_has in gathers.items()
6584                    )
6585                )
6586                for r, v in gathers.items():
6587                    self.log_event(r, {"action": "replicate-add", "who_has": v})
6588
6589            self.log_event(
6590                "all",
6591                {
6592                    "action": "replicate",
6593                    "workers": list(workers),
6594                    "key-count": len(keys),
6595                    "branching-factor": branching_factor,
6596                },
6597            )
6598
6599    def workers_to_close(
6600        self,
6601        comm=None,
6602        memory_ratio: "int | float | None" = None,
6603        n: "int | None" = None,
6604        key: "Callable[[WorkerState], Hashable] | None" = None,
6605        minimum: "int | None" = None,
6606        target: "int | None" = None,
6607        attribute: str = "address",
6608    ) -> "list[str]":
6609        """
6610        Find workers that we can close with low cost
6611
6612        This returns a list of workers that are good candidates to retire.
6613        These workers are not running anything and are storing
6614        relatively little data compared to their peers.  If all workers are
6615        idle then we still maintain enough workers to have enough RAM to store
6616        our data, with a comfortable buffer.
6617
6618        This is for use with systems like ``distributed.deploy.adaptive``.
6619
6620        Parameters
6621        ----------
6622        memory_ratio : Number
6623            Amount of extra space we want to have for our stored data.
6624            Defaults to 2, i.e. we want to have twice as much memory as we
6625            currently have data.
6626        n : int
6627            Number of workers to close
6628        minimum : int
6629            Minimum number of workers to keep around
6630        key : Callable(WorkerState)
6631            An optional callable mapping a WorkerState object to a group
6632            affiliation. Groups will be closed together. This is useful when
6633            closing workers must be done collectively, such as by hostname.
6634        target : int
6635            Target number of workers to have after we close
6636        attribute : str
6637            The attribute of the WorkerState object to return, like "address"
6638            or "name".  Defaults to "address".
6639
6640        Examples
6641        --------
6642        >>> scheduler.workers_to_close()
6643        ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234']
6644
6645        Group workers by hostname prior to closing
6646
6647        >>> scheduler.workers_to_close(key=lambda ws: ws.host)
6648        ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567']
6649
6650        Remove two workers
6651
6652        >>> scheduler.workers_to_close(n=2)
6653
6654        Keep enough workers to have twice as much memory as we need.
6655
6656        >>> scheduler.workers_to_close(memory_ratio=2)
6657
6658        Returns
6659        -------
6660        to_close: list of worker addresses that are OK to close
6661
6662        See Also
6663        --------
6664        Scheduler.retire_workers
6665        """
6666        parent: SchedulerState = cast(SchedulerState, self)
6667        if target is not None and n is None:
6668            n = len(parent._workers_dv) - target
6669        if n is not None:
6670            if n < 0:
6671                n = 0
6672            target = len(parent._workers_dv) - n
6673
6674        if n is None and memory_ratio is None:
6675            memory_ratio = 2
6676
6677        ws: WorkerState
6678        with log_errors():
6679            if not n and all([ws._processing for ws in parent._workers_dv.values()]):
6680                return []
6681
6682            if key is None:
6683                key = operator.attrgetter("address")
6684            if isinstance(key, bytes) and dask.config.get(
6685                "distributed.scheduler.pickle"
6686            ):
6687                key = pickle.loads(key)
6688
6689            groups = groupby(key, parent._workers.values())
6690
6691            limit_bytes = {
6692                k: sum([ws._memory_limit for ws in v]) for k, v in groups.items()
6693            }
6694            group_bytes = {k: sum([ws._nbytes for ws in v]) for k, v in groups.items()}
6695
6696            limit = sum(limit_bytes.values())
6697            total = sum(group_bytes.values())
6698
6699            def _key(group):
6700                wws: WorkerState
6701                is_idle = not any([wws._processing for wws in groups[group]])
6702                nbytes = -group_bytes[group]
6703                return (is_idle, nbytes)
6704
6705            idle = sorted(groups, key=_key)
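            # Note: idle.pop() below therefore yields fully idle groups first
            # (True sorts after False) and, among those, the groups holding the
            # least data first (the byte count is negated in _key above).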
6706
6707            to_close = []
6708            n_remain = len(parent._workers_dv)
6709
6710            while idle:
6711                group = idle.pop()
6712                if n is None and any([ws._processing for ws in groups[group]]):
6713                    break
6714
6715                if minimum and n_remain - len(groups[group]) < minimum:
6716                    break
6717
6718                limit -= limit_bytes[group]
6719
6720                if (
6721                    n is not None and n_remain - len(groups[group]) >= cast(int, target)
6722                ) or (memory_ratio is not None and limit >= memory_ratio * total):
6723                    to_close.append(group)
6724                    n_remain -= len(groups[group])
6725
6726                else:
6727                    break
6728
6729            result = [getattr(ws, attribute) for g in to_close for ws in groups[g]]
6730            if result:
6731                logger.debug("Suggest closing workers: %s", result)
6732
6733            return result
6734
6735    async def retire_workers(
6736        self,
6737        comm=None,
6738        workers=None,
6739        remove=True,
6740        close_workers=False,
6741        names=None,
6742        lock=True,
6743        **kwargs,
6744    ) -> dict:
6745        """Gracefully retire workers from cluster
6746
6747        Parameters
6748        ----------
6749        workers: list (optional)
6750            List of worker addresses to retire.
6751            If not provided we call ``workers_to_close`` which finds a good set.
6752        names: list (optional)
6753            List of worker names to retire.
6754        remove: bool (defaults to True)
6755            Whether or not to remove the worker metadata immediately or else
6756            wait for the worker to contact us
6757        close_workers: bool (defaults to False)
6758            Whether or not to actually close the worker explicitly from here.
6759            Otherwise we expect some external job scheduler to finish off the
6760            worker.
6761        **kwargs: dict
6762            Extra options to pass to workers_to_close to determine which
6763            workers we should drop
6764
6765        Returns
6766        -------
6767        Dictionary mapping worker address to a dictionary of information about
6768        that worker, for each retired worker.
6769
6770        See Also
6771        --------
6772        Scheduler.workers_to_close
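
        Examples
        --------
        A minimal sketch of typical usage through the client API; ``client`` and
        the worker address are illustrative assumptions:

        >>> client.retire_workers(workers=["tcp://127.0.0.1:36345"])  # doctest: +SKIP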
6773        """
6774        parent: SchedulerState = cast(SchedulerState, self)
6775        ws: WorkerState
6776        ts: TaskState
6777        with log_errors():
6778            async with self._lock if lock else empty_context:
6779                if names is not None:
6780                    if workers is not None:
6781                        raise TypeError("names and workers are mutually exclusive")
6782                    if names:
6783                        logger.info("Retire worker names %s", names)
6784                    names = set(map(str, names))
6785                    workers = {
6786                        ws._address
6787                        for ws in parent._workers_dv.values()
6788                        if str(ws._name) in names
6789                    }
6790                elif workers is None:
6791                    while True:
6792                        try:
6793                            workers = self.workers_to_close(**kwargs)
6794                            if not workers:
6795                                return {}
6796                            return await self.retire_workers(
6797                                workers=workers,
6798                                remove=remove,
6799                                close_workers=close_workers,
6800                                lock=False,
6801                            )
6802                        except KeyError:  # keys left during replicate
6803                            pass
6804
6805                workers = {
6806                    parent._workers_dv[w] for w in workers if w in parent._workers_dv
6807                }
6808                if not workers:
6809                    return {}
6810                logger.info("Retire workers %s", workers)
6811
6812                # Keys orphaned by retiring those workers
6813                keys = {k for w in workers for k in w.has_what}
6814                keys = {ts._key for ts in keys if ts._who_has.issubset(workers)}
6815
6816                if keys:
6817                    other_workers = set(parent._workers_dv.values()) - workers
6818                    if not other_workers:
6819                        return {}
6820                    logger.info("Moving %d keys to other workers", len(keys))
6821                    await self.replicate(
6822                        keys=keys,
6823                        workers=[ws._address for ws in other_workers],
6824                        n=1,
6825                        delete=False,
6826                        lock=False,
6827                    )
6828
6829                worker_keys = {ws._address: ws.identity() for ws in workers}
6830                if close_workers:
6831                    await asyncio.gather(
6832                        *[self.close_worker(worker=w, safe=True) for w in worker_keys]
6833                    )
6834                if remove:
6835                    await asyncio.gather(
6836                        *[self.remove_worker(address=w, safe=True) for w in worker_keys]
6837                    )
6838
6839                self.log_event(
6840                    "all",
6841                    {
6842                        "action": "retire-workers",
6843                        "workers": worker_keys,
6844                        "moved-keys": len(keys),
6845                    },
6846                )
6847                self.log_event(list(worker_keys), {"action": "retired"})
6848
6849                return worker_keys
6850
6851    def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None):
6852        """
6853        Learn that a worker has certain keys
6854
6855        This should not be used in practice and is mostly here for legacy
6856        reasons.  However, it is sent by workers from time to time.
6857        """
6858        parent: SchedulerState = cast(SchedulerState, self)
6859        if worker not in parent._workers_dv:
6860            return "not found"
6861        ws: WorkerState = parent._workers_dv[worker]
6862        redundant_replicas = []
6863        for key in keys:
6864            ts: TaskState = parent._tasks.get(key)
6865            if ts is not None and ts._state == "memory":
6866                if ws not in ts._who_has:
6867                    parent.add_replica(ts, ws)
6868            else:
6869                redundant_replicas.append(key)
6870
6871        if redundant_replicas:
6872            if not stimulus_id:
6873                stimulus_id = f"redundant-replicas-{time()}"
6874            self.worker_send(
6875                worker,
6876                {
6877                    "op": "remove-replicas",
6878                    "keys": redundant_replicas,
6879                    "stimulus_id": stimulus_id,
6880                },
6881            )
6882
6883        return "OK"
6884
6885    def update_data(
6886        self,
6887        comm=None,
6888        *,
6889        who_has: dict,
6890        nbytes: dict,
6891        client=None,
6892        serializers=None,
6893    ):
6894        """
6895        Learn that new data has entered the network from an external source
6896
6897        See Also
6898        --------
6899        Scheduler.mark_key_in_memory
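
        Notes
        -----
        A rough sketch of the expected argument shapes; the key, address and size
        below are made up for illustration::

            who_has = {"x": ["tcp://10.0.0.1:40001"]}
            nbytes = {"x": 8}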
6900        """
6901        parent: SchedulerState = cast(SchedulerState, self)
6902        with log_errors():
6903            who_has = {
6904                k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items()
6905            }
6906            logger.debug("Update data %s", who_has)
6907
6908            for key, workers in who_has.items():
6909                ts: TaskState = parent._tasks.get(key)  # type: ignore
6910                if ts is None:
6911                    ts = parent.new_task(key, None, "memory")
6912                ts.state = "memory"
6913                ts_nbytes = nbytes.get(key, -1)
6914                if ts_nbytes >= 0:
6915                    ts.set_nbytes(ts_nbytes)
6916
6917                for w in workers:
6918                    ws: WorkerState = parent._workers_dv[w]
6919                    if ws not in ts._who_has:
6920                        parent.add_replica(ts, ws)
6921                self.report(
6922                    {"op": "key-in-memory", "key": key, "workers": list(workers)}
6923                )
6924
6925            if client:
6926                self.client_desires_keys(keys=list(who_has), client=client)
6927
6928    def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None):
6929        parent: SchedulerState = cast(SchedulerState, self)
6930        if ts is None:
6931            ts = parent._tasks.get(key)
6932        elif key is None:
6933            key = ts._key
6934        else:
6935            assert False, (key, ts)
6936            return
6937
6938        report_msg: dict
6939        if ts is None:
6940            report_msg = {"op": "cancelled-key", "key": key}
6941        else:
6942            report_msg = _task_to_report_msg(parent, ts)
6943        if report_msg is not None:
6944            self.report(report_msg, ts=ts, client=client)
6945
6946    async def feed(
6947        self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs
6948    ):
6949        """
6950        Provides a data Comm to an external requester
6951
6952        Caution: this runs arbitrary Python code on the scheduler.  This should
6953        eventually be phased out.  It is mostly used by diagnostics.
6954        """
6955        if not dask.config.get("distributed.scheduler.pickle"):
6956            logger.warning(
6957                "Tried to call 'feed' route with custom functions, but "
6958                "pickle is disallowed.  Set the 'distributed.scheduler.pickle' "
6959                "config value to True to use the 'feed' route (this is most "
6960                "commonly used with progress bars)."
6961            )
6962            return
6963
6964        interval = parse_timedelta(interval)
6965        with log_errors():
6966            if function:
6967                function = pickle.loads(function)
6968            if setup:
6969                setup = pickle.loads(setup)
6970            if teardown:
6971                teardown = pickle.loads(teardown)
6972            state = setup(self) if setup else None
6973            if inspect.isawaitable(state):
6974                state = await state
6975            try:
6976                while self.status == Status.running:
6977                    if state is None:
6978                        response = function(self)
6979                    else:
6980                        response = function(self, state)
6981                    await comm.write(response)
6982                    await asyncio.sleep(interval)
6983            except OSError:
6984                pass
6985            finally:
6986                if teardown:
6987                    teardown(self, state)
6988
6989    def log_worker_event(self, worker=None, topic=None, msg=None):
6990        self.log_event(topic, msg)
6991
6992    def subscribe_worker_status(self, comm=None):
6993        WorkerStatusPlugin(self, comm)
6994        ident = self.identity()
6995        for v in ident["workers"].values():
6996            del v["metrics"]
6997            del v["last_seen"]
6998        return ident
6999
7000    def get_processing(self, comm=None, workers=None):
7001        parent: SchedulerState = cast(SchedulerState, self)
7002        ws: WorkerState
7003        ts: TaskState
7004        if workers is not None:
7005            workers = set(map(self.coerce_address, workers))
7006            return {
7007                w: [ts._key for ts in parent._workers_dv[w].processing] for w in workers
7008            }
7009        else:
7010            return {
7011                w: [ts._key for ts in ws._processing]
7012                for w, ws in parent._workers_dv.items()
7013            }
7014
7015    def get_who_has(self, comm=None, keys=None):
7016        parent: SchedulerState = cast(SchedulerState, self)
7017        ws: WorkerState
7018        ts: TaskState
7019        if keys is not None:
7020            return {
7021                k: [ws._address for ws in parent._tasks[k].who_has]
7022                if k in parent._tasks
7023                else []
7024                for k in keys
7025            }
7026        else:
7027            return {
7028                key: [ws._address for ws in ts._who_has]
7029                for key, ts in parent._tasks.items()
7030            }
7031
7032    def get_has_what(self, comm=None, workers=None):
7033        parent: SchedulerState = cast(SchedulerState, self)
7034        ws: WorkerState
7035        ts: TaskState
7036        if workers is not None:
7037            workers = map(self.coerce_address, workers)
7038            return {
7039                w: [ts._key for ts in parent._workers_dv[w].has_what]
7040                if w in parent._workers_dv
7041                else []
7042                for w in workers
7043            }
7044        else:
7045            return {
7046                w: [ts._key for ts in ws.has_what]
7047                for w, ws in parent._workers_dv.items()
7048            }
7049
7050    def get_ncores(self, comm=None, workers=None):
7051        parent: SchedulerState = cast(SchedulerState, self)
7052        ws: WorkerState
7053        if workers is not None:
7054            workers = map(self.coerce_address, workers)
7055            return {
7056                w: parent._workers_dv[w].nthreads
7057                for w in workers
7058                if w in parent._workers_dv
7059            }
7060        else:
7061            return {w: ws._nthreads for w, ws in parent._workers_dv.items()}
7062
7063    def get_ncores_running(self, comm=None, workers=None):
7064        parent: SchedulerState = cast(SchedulerState, self)
7065        ncores = self.get_ncores(workers=workers)
7066        return {
7067            w: n
7068            for w, n in ncores.items()
7069            if parent._workers_dv[w].status == Status.running
7070        }
7071
7072    async def get_call_stack(self, comm=None, keys=None):
7073        parent: SchedulerState = cast(SchedulerState, self)
7074        ts: TaskState
7075        dts: TaskState
7076        if keys is not None:
7077            stack = list(keys)
7078            processing = set()
7079            while stack:
7080                key = stack.pop()
7081                ts = parent._tasks[key]
7082                if ts._state == "waiting":
7083                    stack.extend([dts._key for dts in ts._dependencies])
7084                elif ts._state == "processing":
7085                    processing.add(ts)
7086
7087            workers = defaultdict(list)
7088            for ts in processing:
7089                if ts._processing_on:
7090                    workers[ts._processing_on.address].append(ts._key)
7091        else:
7092            workers = {w: None for w in parent._workers_dv}
7093
7094        if not workers:
7095            return {}
7096
7097        results = await asyncio.gather(
7098            *(self.rpc(w).call_stack(keys=v) for w, v in workers.items())
7099        )
7100        response = {w: r for w, r in zip(workers, results) if r}
7101        return response
7102
7103    def get_nbytes(self, comm=None, keys=None, summary=True):
7104        parent: SchedulerState = cast(SchedulerState, self)
7105        ts: TaskState
7106        with log_errors():
7107            if keys is not None:
7108                result = {k: parent._tasks[k].nbytes for k in keys}
7109            else:
7110                result = {
7111                    k: ts._nbytes for k, ts in parent._tasks.items() if ts._nbytes >= 0
7112                }
7113
7114            if summary:
7115                out = defaultdict(int)
7116                for k, v in result.items():
7117                    out[key_split(k)] += v
7118                result = dict(out)
7119
7120            return result
7121
7122    def run_function(self, stream, function, args=(), kwargs={}, wait=True):
7123        """Run a function within this process
7124
7125        See Also
7126        --------
7127        Client.run_on_scheduler
7128        """
7129        from .worker import run
7130
7131        if not dask.config.get("distributed.scheduler.pickle"):
7132            raise ValueError(
7133                "Cannot run function as the scheduler has been explicitly disallowed from "
7134                "deserializing arbitrary bytestrings using pickle via the "
7135                "'distributed.scheduler.pickle' configuration setting."
7136            )
7137
7138        self.log_event("all", {"action": "run-function", "function": function})
7139        return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait)
7140
7141    def set_metadata(self, comm=None, keys=None, value=None):
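        # Sets a value at a (possibly nested) key path inside the task metadata.
        # Illustrative only (keys and value are made up): starting from empty
        # metadata, set_metadata(keys=["a", "b"], value=1) leaves {"a": {"b": 1}}
        # behind, and get_metadata(keys=["a", "b"]) then returns 1.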
7142        parent: SchedulerState = cast(SchedulerState, self)
7143        metadata = parent._task_metadata
7144        for key in keys[:-1]:
7145            if key not in metadata or not isinstance(metadata[key], (dict, list)):
7146                metadata[key] = {}
7147            metadata = metadata[key]
7148        metadata[keys[-1]] = value
7149
7150    def get_metadata(self, comm=None, keys=None, default=no_default):
7151        parent: SchedulerState = cast(SchedulerState, self)
7152        metadata = parent._task_metadata
7153        for key in keys[:-1]:
7154            metadata = metadata[key]
7155        try:
7156            return metadata[keys[-1]]
7157        except KeyError:
7158            if default != no_default:
7159                return default
7160            else:
7161                raise
7162
7163    def set_restrictions(self, comm=None, worker=None):
7164        ts: TaskState
7165        for key, restrictions in worker.items():
7166            ts = self.tasks[key]
7167            if isinstance(restrictions, str):
7168                restrictions = {restrictions}
7169            ts._worker_restrictions = set(restrictions)
7170
7171    def get_task_status(self, comm=None, keys=None):
7172        parent: SchedulerState = cast(SchedulerState, self)
7173        return {
7174            key: (parent._tasks[key].state if key in parent._tasks else None)
7175            for key in keys
7176        }
7177
7178    def get_task_stream(self, comm=None, start=None, stop=None, count=None):
7179        from distributed.diagnostics.task_stream import TaskStreamPlugin
7180
7181        if TaskStreamPlugin.name not in self.plugins:
7182            self.add_plugin(TaskStreamPlugin(self))
7183
7184        plugin = self.plugins[TaskStreamPlugin.name]
7185
7186        return plugin.collect(start=start, stop=stop, count=count)
7187
7188    def start_task_metadata(self, comm=None, name=None):
7189        plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name)
7190        self.add_plugin(plugin)
7191
7192    def stop_task_metadata(self, comm=None, name=None):
7193        plugins = [
7194            p
7195            for p in list(self.plugins.values())
7196            if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name
7197        ]
7198        if len(plugins) != 1:
7199            raise ValueError(
7200                "Expected to find exactly one CollectTaskMetaDataPlugin "
7201                f"with name {name} but found {len(plugins)}."
7202            )
7203
7204        plugin = plugins[0]
7205        self.remove_plugin(name=plugin.name)
7206        return {"metadata": plugin.metadata, "state": plugin.state}
7207
7208    async def register_worker_plugin(self, comm, plugin, name=None):
7209        """Registers a setup function, and calls it on every worker"""
7210        self.worker_plugins[name] = plugin
7211
7212        responses = await self.broadcast(
7213            msg=dict(op="plugin-add", plugin=plugin, name=name)
7214        )
7215        return responses
7216
7217    async def unregister_worker_plugin(self, comm, name):
7218        """Unregisters a worker plugin"""
7219        try:
7220            self.worker_plugins.pop(name)
7221        except KeyError:
7222            raise ValueError(f"The worker plugin {name} does not exist")
7223
7224        responses = await self.broadcast(msg=dict(op="plugin-remove", name=name))
7225        return responses
7226
7227    async def register_nanny_plugin(self, comm, plugin, name=None):
7228        """Registers a setup function, and calls it on every nanny"""
7229        self.nanny_plugins[name] = plugin
7230
7231        responses = await self.broadcast(
7232            msg=dict(op="plugin_add", plugin=plugin, name=name),
7233            nanny=True,
7234        )
7235        return responses
7236
7237    async def unregister_nanny_plugin(self, comm, name):
7238        """Unregisters a nanny plugin"""
7239        try:
7240            self.nanny_plugins.pop(name)
7241        except KeyError:
7242            raise ValueError(f"The nanny plugin {name} does not exist")
7243
7244        responses = await self.broadcast(
7245            msg=dict(op="plugin_remove", name=name), nanny=True
7246        )
7247        return responses
7248
7249    def transition(self, key, finish: str, *args, **kwargs):
7250        """Transition a key from its current state to the finish state
7251
7252        Examples
7253        --------
7254        >>> self.transition('x', 'waiting')
7255        {'x': 'processing'}
7256
7257        Returns
7258        -------
7259        Dictionary of recommendations for future transitions
7260
7261        See Also
7262        --------
7263        Scheduler.transitions: transitive version of this function
7264        """
7265        parent: SchedulerState = cast(SchedulerState, self)
7266        recommendations: dict
7267        worker_msgs: dict
7268        client_msgs: dict
7269        a: tuple = parent._transition(key, finish, *args, **kwargs)
7270        recommendations, client_msgs, worker_msgs = a
7271        self.send_all(client_msgs, worker_msgs)
7272        return recommendations
7273
7274    def transitions(self, recommendations: dict):
7275        """Process transitions until none are left
7276
7277        This includes feedback from previous transitions and continues until we
7278        reach a steady state
7279        """
7280        parent: SchedulerState = cast(SchedulerState, self)
7281        client_msgs: dict = {}
7282        worker_msgs: dict = {}
7283        parent._transitions(recommendations, client_msgs, worker_msgs)
7284        self.send_all(client_msgs, worker_msgs)
7285
7286    def story(self, *keys):
7287        """Get all transitions that touch one of the input keys"""
7288        keys = {key.key if isinstance(key, TaskState) else key for key in keys}
7289        return [
7290            t for t in self.transition_log if t[0] in keys or keys.intersection(t[3])
7291        ]
7292
7293    transition_story = story
7294
7295    def reschedule(self, key=None, worker=None):
7296        """Reschedule a task
7297
7298        Things may have shifted and this task may now be better suited to run
7299        elsewhere
7300        """
7301        parent: SchedulerState = cast(SchedulerState, self)
7302        ts: TaskState
7303        try:
7304            ts = parent._tasks[key]
7305        except KeyError:
7306            logger.warning(
7307                "Attempting to reschedule task {}, which was not "
7308                "found on the scheduler. Aborting reschedule.".format(key)
7309            )
7310            return
7311        if ts._state != "processing":
7312            return
7313        if worker and ts._processing_on.address != worker:
7314            return
7315        self.transitions({key: "released"})
7316
7317    #####################
7318    # Utility functions #
7319    #####################
7320
7321    def add_resources(self, comm=None, worker=None, resources=None):
7322        parent: SchedulerState = cast(SchedulerState, self)
7323        ws: WorkerState = parent._workers_dv[worker]
7324        if resources:
7325            ws._resources.update(resources)
7326        ws._used_resources = {}
7327        for resource, quantity in ws._resources.items():
7328            ws._used_resources[resource] = 0
7329            dr: dict = parent._resources.get(resource, None)
7330            if dr is None:
7331                parent._resources[resource] = dr = {}
7332            dr[worker] = quantity
7333        return "OK"
7334
7335    def remove_resources(self, worker):
7336        parent: SchedulerState = cast(SchedulerState, self)
7337        ws: WorkerState = parent._workers_dv[worker]
7338        for resource, quantity in ws._resources.items():
7339            dr: dict = parent._resources.get(resource, None)
7340            if dr is None:
7341                parent._resources[resource] = dr = {}
7342            del dr[worker]
7343
7344    def coerce_address(self, addr, resolve=True):
7345        """
7346        Coerce possible input addresses to canonical form.
7347        *resolve* can be disabled for testing with fake hostnames.
7348
7349        Handles strings, tuples, or aliases.
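
        For illustration (the address is made up): ``("127.0.0.1", 8786)`` becomes
        ``"127.0.0.1:8786"`` and then, under the default TCP backend, resolves to
        ``"tcp://127.0.0.1:8786"``; a known worker alias is first translated to
        its address.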
7350        """
7351        # XXX how many address-parsing routines do we have?
7352        parent: SchedulerState = cast(SchedulerState, self)
7353        if addr in parent._aliases:
7354            addr = parent._aliases[addr]
7355        if isinstance(addr, tuple):
7356            addr = unparse_host_port(*addr)
7357        if not isinstance(addr, str):
7358            raise TypeError(f"addresses should be strings or tuples, got {addr!r}")
7359
7360        if resolve:
7361            addr = resolve_address(addr)
7362        else:
7363            addr = normalize_address(addr)
7364
7365        return addr
7366
7367    def workers_list(self, workers):
7368        """
7369        List of qualifying workers
7370
7371        Takes a list of worker addresses or hostnames.
7372        Returns a list of all worker addresses that match.
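
        For illustration (addresses are made up): passing ``["10.5.0.1"]`` matches
        every known worker address that contains that substring, while a full
        ``"host:port"`` address is passed through as-is.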
7373        """
7374        parent: SchedulerState = cast(SchedulerState, self)
7375        if workers is None:
7376            return list(parent._workers)
7377
7378        out = set()
7379        for w in workers:
7380            if ":" in w:
7381                out.add(w)
7382            else:
7383                out.update({ww for ww in parent._workers if w in ww})  # TODO: quadratic
7384        return list(out)
7385
7386    def start_ipython(self, comm=None):
7387        """Start an IPython kernel
7388
7389        Returns Jupyter connection info dictionary.
7390        """
7391        from ._ipython_utils import start_ipython
7392
7393        if self._ipython_kernel is None:
7394            self._ipython_kernel = start_ipython(
7395                ip=self.ip, ns={"scheduler": self}, log=logger
7396            )
7397        return self._ipython_kernel.get_connection_info()
7398
7399    async def get_profile(
7400        self,
7401        comm=None,
7402        workers=None,
7403        scheduler=False,
7404        server=False,
7405        merge_workers=True,
7406        start=None,
7407        stop=None,
7408        key=None,
7409    ):
7410        parent: SchedulerState = cast(SchedulerState, self)
7411        if workers is None:
7412            workers = parent._workers_dv
7413        else:
7414            workers = set(parent._workers_dv) & set(workers)
7415
7416        if scheduler:
7417            return profile.get_profile(self.io_loop.profile, start=start, stop=stop)
7418
7419        results = await asyncio.gather(
7420            *(
7421                self.rpc(w).profile(start=start, stop=stop, key=key, server=server)
7422                for w in workers
7423            ),
7424            return_exceptions=True,
7425        )
7426
7427        results = [r for r in results if not isinstance(r, Exception)]
7428
7429        if merge_workers:
7430            response = profile.merge(*results)
7431        else:
7432            response = dict(zip(workers, results))
7433        return response
7434
7435    async def get_profile_metadata(
7436        self,
7437        comm=None,
7438        workers=None,
7439        merge_workers=True,
7440        start=None,
7441        stop=None,
7442        profile_cycle_interval=None,
7443    ):
7444        parent: SchedulerState = cast(SchedulerState, self)
7445        dt = profile_cycle_interval or dask.config.get(
7446            "distributed.worker.profile.cycle"
7447        )
7448        dt = parse_timedelta(dt, default="ms")
7449
7450        if workers is None:
7451            workers = parent._workers_dv
7452        else:
7453            workers = set(parent._workers_dv) & set(workers)
7454        results = await asyncio.gather(
7455            *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers),
7456            return_exceptions=True,
7457        )
7458
7459        results = [r for r in results if not isinstance(r, Exception)]
7460        counts = [v["counts"] for v in results]
7461        counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt)
7462        counts = [(time, sum(pluck(1, group))) for time, group in counts]
7463
7464        keys = set()
7465        for v in results:
7466            for t, d in v["keys"]:
7467                for k in d:
7468                    keys.add(k)
7469        keys = {k: [] for k in keys}
7470
7471        groups1 = [v["keys"] for v in results]
7472        groups2 = list(merge_sorted(*groups1, key=first))
7473
7474        last = 0
7475        for t, d in groups2:
7476            tt = t // dt * dt
7477            if tt > last:
7478                last = tt
7479                for k, v in keys.items():
7480                    v.append([tt, 0])
7481            for k, v in d.items():
7482                keys[k][-1][1] += v
7483
7484        return {"counts": counts, "keys": keys}
7485
7486    async def performance_report(
7487        self, comm=None, start=None, last_count=None, code="", mode=None
7488    ):
7489        parent: SchedulerState = cast(SchedulerState, self)
7490        stop = time()
7491        # Profiles
7492        compute, scheduler, workers = await asyncio.gather(
7493            *[
7494                self.get_profile(start=start),
7495                self.get_profile(scheduler=True, start=start),
7496                self.get_profile(server=True, start=start),
7497            ]
7498        )
7499        from . import profile
7500
7501        def profile_to_figure(state):
7502            data = profile.plot_data(state)
7503            figure, source = profile.plot_figure(data, sizing_mode="stretch_both")
7504            return figure
7505
7506        compute, scheduler, workers = map(
7507            profile_to_figure, (compute, scheduler, workers)
7508        )
7509
7510        # Task stream
7511        task_stream = self.get_task_stream(start=start)
7512        total_tasks = len(task_stream)
7513        timespent = defaultdict(int)
7514        for d in task_stream:
7515            for x in d.get("startstops", []):
7516                timespent[x["action"]] += x["stop"] - x["start"]
7517        tasks_timings = ""
7518        for k in sorted(timespent.keys()):
7519            tasks_timings += f"\n<li> {k} time: {format_time(timespent[k])} </li>"
7520
7521        from .dashboard.components.scheduler import task_stream_figure
7522        from .diagnostics.task_stream import rectangles
7523
7524        rects = rectangles(task_stream)
7525        source, task_stream = task_stream_figure(sizing_mode="stretch_both")
7526        source.data.update(rects)
7527
7528        # Bandwidth
7529        from distributed.dashboard.components.scheduler import (
7530            BandwidthTypes,
7531            BandwidthWorkers,
7532        )
7533
7534        bandwidth_workers = BandwidthWorkers(self, sizing_mode="stretch_both")
7535        bandwidth_workers.update()
7536        bandwidth_types = BandwidthTypes(self, sizing_mode="stretch_both")
7537        bandwidth_types.update()
7538
7539        # System monitor
7540        from distributed.dashboard.components.shared import SystemMonitor
7541
7542        sysmon = SystemMonitor(self, last_count=last_count, sizing_mode="stretch_both")
7543        sysmon.update()
7544
7545        # Scheduler logs
7546        from distributed.dashboard.components.scheduler import SchedulerLogs
7547
7548        logs = SchedulerLogs(self)
7549
7550        from bokeh.models import Div, Panel, Tabs
7551
7552        import distributed
7553
7554        # HTML
7555        ws: WorkerState
7556        html = """
7557        <h1> Dask Performance Report </h1>
7558
        <i> Select different tabs at the top for additional information </i>
7560
7561        <h2> Duration: {time} </h2>
7562        <h2> Tasks Information </h2>
7563        <ul>
7564         <li> number of tasks: {ntasks} </li>
7565         {tasks_timings}
7566        </ul>
7567
7568        <h2> Scheduler Information </h2>
7569        <ul>
7570          <li> Address: {address} </li>
7571          <li> Workers: {nworkers} </li>
7572          <li> Threads: {threads} </li>
7573          <li> Memory: {memory} </li>
7574          <li> Dask Version: {dask_version} </li>
7575          <li> Dask.Distributed Version: {distributed_version} </li>
7576        </ul>
7577
7578        <h2> Calling Code </h2>
7579        <pre>
7580{code}
7581        </pre>
7582        """.format(
7583            time=format_time(stop - start),
7584            ntasks=total_tasks,
7585            tasks_timings=tasks_timings,
7586            address=self.address,
7587            nworkers=len(parent._workers_dv),
7588            threads=sum([ws._nthreads for ws in parent._workers_dv.values()]),
7589            memory=format_bytes(
7590                sum([ws._memory_limit for ws in parent._workers_dv.values()])
7591            ),
7592            code=code,
7593            dask_version=dask.__version__,
7594            distributed_version=distributed.__version__,
7595        )
7596        html = Div(
7597            text=html,
7598            style={
7599                "width": "100%",
7600                "height": "100%",
7601                "max-width": "1920px",
7602                "max-height": "1080px",
7603                "padding": "12px",
7604                "border": "1px solid lightgray",
7605                "box-shadow": "inset 1px 0 8px 0 lightgray",
7606                "overflow": "auto",
7607            },
7608        )
7609
7610        html = Panel(child=html, title="Summary")
7611        compute = Panel(child=compute, title="Worker Profile (compute)")
7612        workers = Panel(child=workers, title="Worker Profile (administrative)")
7613        scheduler = Panel(child=scheduler, title="Scheduler Profile (administrative)")
7614        task_stream = Panel(child=task_stream, title="Task Stream")
7615        bandwidth_workers = Panel(
7616            child=bandwidth_workers.root, title="Bandwidth (Workers)"
7617        )
7618        bandwidth_types = Panel(child=bandwidth_types.root, title="Bandwidth (Types)")
7619        system = Panel(child=sysmon.root, title="System")
7620        logs = Panel(child=logs.root, title="Scheduler Logs")
7621
7622        tabs = Tabs(
7623            tabs=[
7624                html,
7625                task_stream,
7626                system,
7627                logs,
7628                compute,
7629                workers,
7630                scheduler,
7631                bandwidth_workers,
7632                bandwidth_types,
7633            ]
7634        )
7635
7636        from bokeh.core.templates import get_env
7637        from bokeh.plotting import output_file, save
7638
7639        with tmpfile(extension=".html") as fn:
7640            output_file(filename=fn, title="Dask Performance Report", mode=mode)
7641            template_directory = os.path.join(
7642                os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates"
7643            )
7644            template_environment = get_env()
7645            template_environment.loader.searchpath.append(template_directory)
7646            template = template_environment.get_template("performance_report.html")
7647            save(tabs, filename=fn, template=template)
7648
7649            with open(fn) as f:
7650                data = f.read()
7651
7652        return data
7653
7654    async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
7655        results = await self.broadcast(
7656            msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny
7657        )
7658        return results
7659
7660    def log_event(self, name, msg):
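        """Append *msg* to the event log(s) for ``name`` and notify any
        subscribed clients.

        A rough client-side sketch of producing and reading such events
        (``Client.log_event`` and ``Client.get_events`` are the usual entry
        points)::

            client.log_event("my-topic", {"status": "ok"})
            client.get_events("my-topic")
        """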
7661        event = (time(), msg)
7662        if isinstance(name, (list, tuple)):
7663            for n in name:
7664                self.events[n].append(event)
7665                self.event_counts[n] += 1
7666                self._report_event(n, event)
7667        else:
7668            self.events[name].append(event)
7669            self.event_counts[name] += 1
7670            self._report_event(name, event)
7671
7672    def _report_event(self, name, event):
7673        for client in self.event_subscriber[name]:
7674            self.report(
7675                {
7676                    "op": "event",
7677                    "topic": name,
7678                    "event": event,
7679                },
7680                client=client,
7681            )
7682
7683    def subscribe_topic(self, topic, client):
7684        self.event_subscriber[topic].add(client)
7685
7686    def unsubscribe_topic(self, topic, client):
7687        self.event_subscriber[topic].discard(client)
7688
7689    def get_events(self, comm=None, topic=None):
7690        if topic is not None:
7691            return tuple(self.events[topic])
7692        else:
7693            return valmap(tuple, self.events)
7694
7695    async def get_worker_monitor_info(self, recent=False, starts=None):
7696        parent: SchedulerState = cast(SchedulerState, self)
7697        if starts is None:
7698            starts = {}
7699        results = await asyncio.gather(
7700            *(
7701                self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0))
7702                for w in parent._workers_dv
7703            )
7704        )
7705        return dict(zip(parent._workers_dv, results))
7706
7707    ###########
7708    # Cleanup #
7709    ###########
7710
7711    def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0):
7712        """Periodically reassess task duration time
7713
7714        The expected duration of a task can change over time.  Unfortunately we
7715        don't have a good constant-time way to propagate the effects of these
7716        changes out to the summaries that they affect, like the total expected
7717        runtime of each of the workers, or what tasks are stealable.
7718
7719        In this coroutine we walk through all of the workers and re-align their
7720        estimates with the current state of tasks.  We do this periodically
7721        rather than at every transition, and we only do it if the scheduler
7722        process isn't under load (using psutil.Process.cpu_percent()).  This
7723        lets us avoid this fringe optimization when we have better things to
7724        think about.
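
        As a rough worked example of the back-off below: if re-estimating a
        batch of workers takes 10 ms, the next pass is scheduled ~50 ms later
        (``duration * 5``), so this housekeeping stays around one sixth of
        the event loop's time; if the batch is cheap we simply poll again
        after 100 ms.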
7725        """
7726        parent: SchedulerState = cast(SchedulerState, self)
7727        try:
7728            if self.status == Status.closed:
7729                return
7730            last = time()
7731            next_time = timedelta(seconds=0.1)
7732
7733            if self.proc.cpu_percent() < 50:
7734                workers: list = list(parent._workers.values())
7735                nworkers: Py_ssize_t = len(workers)
7736                i: Py_ssize_t
7737                for i in range(nworkers):
7738                    ws: WorkerState = workers[worker_index % nworkers]
7739                    worker_index += 1
7740                    try:
7741                        if ws is None or not ws._processing:
7742                            continue
7743                        _reevaluate_occupancy_worker(parent, ws)
7744                    finally:
7745                        del ws  # lose ref
7746
7747                    duration = time() - last
7748                    if duration > 0.005:  # 5ms since last release
7749                        next_time = timedelta(seconds=duration * 5)  # 25ms gap
7750                        break
7751
7752            self.loop.add_timeout(
7753                next_time, self.reevaluate_occupancy, worker_index=worker_index
7754            )
7755
7756        except Exception:
7757            logger.error("Error in reevaluate occupancy", exc_info=True)
7758            raise
7759
7760    async def check_worker_ttl(self):
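        """Remove workers whose last heartbeat is older than ``worker_ttl``
        and at least ten regular heartbeat intervals.
        """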
7761        parent: SchedulerState = cast(SchedulerState, self)
7762        ws: WorkerState
7763        now = time()
7764        for ws in parent._workers_dv.values():
7765            if (ws._last_seen < now - self.worker_ttl) and (
7766                ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers_dv))
7767            ):
7768                logger.warning(
7769                    "Worker failed to heartbeat within %s seconds. Closing: %s",
7770                    self.worker_ttl,
7771                    ws,
7772                )
7773                await self.remove_worker(address=ws._address)
7774
7775    def check_idle(self):
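        """Close the scheduler if it has had no processing or unrunnable
        tasks for longer than ``idle_timeout``.
        """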
7776        parent: SchedulerState = cast(SchedulerState, self)
7777        ws: WorkerState
7778        if (
7779            any([ws._processing for ws in parent._workers_dv.values()])
7780            or parent._unrunnable
7781        ):
7782            self.idle_since = None
7783            return
7784        elif not self.idle_since:
7785            self.idle_since = time()
7786
7787        if time() > self.idle_since + self.idle_timeout:
7788            logger.info(
7789                "Scheduler closing after being idle for %s",
7790                format_time(self.idle_timeout),
7791            )
7792            self.loop.add_callback(self.close)
7793
7794    def adaptive_target(self, comm=None, target_duration=None):
7795        """Desired number of workers based on the current workload
7796
7797        This looks at the current running tasks and memory use, and returns a
7798        number of desired workers.  This is often used by adaptive scheduling.
7799
7800        Parameters
7801        ----------
7802        target_duration : str
            A desired duration for computations to take.  This affects
7804            how rapidly the scheduler will ask to scale.
7805
7806        See Also
7807        --------
7808        distributed.deploy.Adaptive
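
        Examples
        --------
        A rough sketch of how a cluster manager might consume this value
        (``scheduler`` and ``cluster`` are assumed to already exist):

        >>> n = scheduler.adaptive_target(target_duration="5s")  # doctest: +SKIP
        >>> cluster.scale(n)  # doctest: +SKIP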
7809        """
7810        parent: SchedulerState = cast(SchedulerState, self)
7811        if target_duration is None:
7812            target_duration = dask.config.get("distributed.adaptive.target-duration")
7813        target_duration = parse_timedelta(target_duration)
7814
7815        # CPU
7816        cpu = math.ceil(
7817            parent._total_occupancy / target_duration
7818        )  # TODO: threads per worker
7819
        # Prevent a few long tasks from requesting many cores
7821        ws: WorkerState
7822        tasks_processing = 0
7823        for ws in parent._workers_dv.values():
7824            tasks_processing += len(ws._processing)
7825
7826            if tasks_processing > cpu:
7827                break
7828        else:
7829            cpu = min(tasks_processing, cpu)
7830
7831        if parent._unrunnable and not parent._workers_dv:
7832            cpu = max(1, cpu)
7833
        # Add more workers if more than 60% of memory is used
7835        limit = sum([ws._memory_limit for ws in parent._workers_dv.values()])
7836        used = sum([ws._nbytes for ws in parent._workers_dv.values()])
7837        memory = 0
7838        if used > 0.6 * limit and limit > 0:
7839            memory = 2 * len(parent._workers_dv)
7840
7841        target = max(memory, cpu)
7842        if target >= len(parent._workers_dv):
7843            return target
7844        else:  # Scale down?
7845            to_close = self.workers_to_close()
7846            return len(parent._workers_dv) - len(to_close)
7847
7848
7849@cfunc
7850@exceptval(check=False)
7851def _remove_from_processing(
7852    state: SchedulerState, ts: TaskState
7853) -> str:  # -> str | None
7854    """
7855    Remove *ts* from the set of processing tasks.
7856
7857    See also ``Scheduler.set_duration_estimate``
7858    """
7859    ws: WorkerState = ts._processing_on
7860    ts._processing_on = None  # type: ignore
7861    w: str = ws._address
7862
7863    if w not in state._workers_dv:  # may have been removed
7864        return None  # type: ignore
7865
7866    duration: double = ws._processing.pop(ts)
7867    if not ws._processing:
7868        state._total_occupancy -= ws._occupancy
7869        ws._occupancy = 0
7870    else:
7871        state._total_occupancy -= duration
7872        ws._occupancy -= duration
7873
7874    state.check_idle_saturated(ws)
7875    state.release_resources(ts, ws)
7876
7877    return w
7878
7879
7880@cfunc
7881@exceptval(check=False)
7882def _add_to_memory(
7883    state: SchedulerState,
7884    ts: TaskState,
7885    ws: WorkerState,
7886    recommendations: dict,
7887    client_msgs: dict,
7888    type=None,
7889    typename: str = None,
7890):
7891    """
7892    Add *ts* to the set of in-memory tasks.
7893    """
7894    if state._validate:
7895        assert ts not in ws._has_what
7896
7897    state.add_replica(ts, ws)
7898
7899    deps: list = list(ts._dependents)
7900    if len(deps) > 1:
7901        deps.sort(key=operator.attrgetter("priority"), reverse=True)
7902
7903    dts: TaskState
7904    s: set
7905    for dts in deps:
7906        s = dts._waiting_on
7907        if ts in s:
7908            s.discard(ts)
7909            if not s:  # new task ready to run
7910                recommendations[dts._key] = "processing"
7911
7912    for dts in ts._dependencies:
7913        s = dts._waiters
7914        s.discard(ts)
7915        if not s and not dts._who_wants:
7916            recommendations[dts._key] = "released"
7917
7918    report_msg: dict = {}
7919    cs: ClientState
7920    if not ts._waiters and not ts._who_wants:
7921        recommendations[ts._key] = "released"
7922    else:
7923        report_msg["op"] = "key-in-memory"
7924        report_msg["key"] = ts._key
7925        if type is not None:
7926            report_msg["type"] = type
7927
7928        for cs in ts._who_wants:
7929            client_msgs[cs._client_key] = [report_msg]
7930
7931    ts.state = "memory"
7932    ts._type = typename  # type: ignore
7933    ts._group._types.add(typename)
7934
7935    cs = state._clients["fire-and-forget"]
7936    if ts in cs._wants_what:
7937        _client_releases_keys(
7938            state,
7939            cs=cs,
7940            keys=[ts._key],
7941            recommendations=recommendations,
7942        )
7943
7944
7945@cfunc
7946@exceptval(check=False)
7947def _propagate_forgotten(
7948    state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict
7949):
7950    ts.state = "forgotten"
7951    key: str = ts._key
7952    dts: TaskState
7953    for dts in ts._dependents:
7954        dts._has_lost_dependencies = True
7955        dts._dependencies.remove(ts)
7956        dts._waiting_on.discard(ts)
7957        if dts._state not in ("memory", "erred"):
7958            # Cannot compute task anymore
7959            recommendations[dts._key] = "forgotten"
7960    ts._dependents.clear()
7961    ts._waiters.clear()
7962
7963    for dts in ts._dependencies:
7964        dts._dependents.remove(ts)
7965        dts._waiters.discard(ts)
7966        if not dts._dependents and not dts._who_wants:
7967            # Task not needed anymore
7968            assert dts is not ts
7969            recommendations[dts._key] = "forgotten"
7970    ts._dependencies.clear()
7971    ts._waiting_on.clear()
7972
7973    ws: WorkerState
7974    for ws in ts._who_has:
7975        w: str = ws._address
7976        if w in state._workers_dv:  # in case worker has died
7977            worker_msgs[w] = [
7978                {
7979                    "op": "free-keys",
7980                    "keys": [key],
7981                    "stimulus_id": f"propagate-forgotten-{time()}",
7982                }
7983            ]
7984    state.remove_all_replicas(ts)
7985
7986
7987@cfunc
7988@exceptval(check=False)
7989def _client_releases_keys(
7990    state: SchedulerState, keys: list, cs: ClientState, recommendations: dict
7991):
7992    """Remove keys from client desired list"""
7993    logger.debug("Client %s releases keys: %s", cs._client_key, keys)
7994    ts: TaskState
7995    for key in keys:
7996        ts = state._tasks.get(key)  # type: ignore
7997        if ts is not None and ts in cs._wants_what:
7998            cs._wants_what.remove(ts)
7999            ts._who_wants.remove(cs)
8000            if not ts._who_wants:
8001                if not ts._dependents:
8002                    # No live dependents, can forget
8003                    recommendations[ts._key] = "forgotten"
8004                elif ts._state != "erred" and not ts._waiters:
8005                    recommendations[ts._key] = "released"
8006
8007
8008@cfunc
8009@exceptval(check=False)
8010def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> dict:
8011    """Convert a single computational task to a message"""
8012    ws: WorkerState
8013    dts: TaskState
8014
    # FIXME: The duration attribute is not used on the worker.  We could save
    # ourselves the time to compute and submit it.
8016    if duration < 0:
8017        duration = state.get_task_duration(ts)
8018
8019    msg: dict = {
8020        "op": "compute-task",
8021        "key": ts._key,
8022        "priority": ts._priority,
8023        "duration": duration,
8024        "stimulus_id": f"compute-task-{time()}",
8025        "who_has": {},
8026    }
8027    if ts._resource_restrictions:
8028        msg["resource_restrictions"] = ts._resource_restrictions
8029    if ts._actor:
8030        msg["actor"] = True
8031
8032    deps: set = ts._dependencies
8033    if deps:
8034        msg["who_has"] = {
8035            dts._key: [ws._address for ws in dts._who_has] for dts in deps
8036        }
8037        msg["nbytes"] = {dts._key: dts._nbytes for dts in deps}
8038
8039        if state._validate:
8040            assert all(msg["who_has"].values())
8041
8042    task = ts._run_spec
8043    if type(task) is dict:
8044        msg.update(task)
8045    else:
8046        msg["task"] = task
8047
8048    if ts._annotations:
8049        msg["annotations"] = ts._annotations
8050
8051    return msg
8052
8053
8054@cfunc
8055@exceptval(check=False)
8056def _task_to_report_msg(state: SchedulerState, ts: TaskState) -> dict:  # -> dict | None
8057    if ts._state == "forgotten":
8058        return {"op": "cancelled-key", "key": ts._key}
8059    elif ts._state == "memory":
8060        return {"op": "key-in-memory", "key": ts._key}
8061    elif ts._state == "erred":
8062        failing_ts: TaskState = ts._exception_blame
8063        return {
8064            "op": "task-erred",
8065            "key": ts._key,
8066            "exception": failing_ts._exception,
8067            "traceback": failing_ts._traceback,
8068        }
8069    else:
8070        return None  # type: ignore
8071
8072
8073@cfunc
8074@exceptval(check=False)
8075def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:
8076    if ts._who_wants:
8077        report_msg: dict = _task_to_report_msg(state, ts)
8078        if report_msg is not None:
8079            cs: ClientState
8080            return {cs._client_key: [report_msg] for cs in ts._who_wants}
8081    return {}
8082
8083
8084@cfunc
8085@exceptval(check=False)
8086def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
8087    """See reevaluate_occupancy"""
8088    ts: TaskState
8089    old = ws._occupancy
8090    for ts in ws._processing:
8091        state.set_duration_estimate(ts, ws)
8092
8093    state.check_idle_saturated(ws)
8094    steal = state.extensions.get("stealing")
8095    if not steal:
8096        return
8097    if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
8098        for ts in ws._processing:
8099            steal.recalculate_cost(ts)
8100
8101
8102@cfunc
8103@exceptval(check=False)
8104def decide_worker(
8105    ts: TaskState, all_workers, valid_workers: set, objective
8106) -> WorkerState:  # -> WorkerState | None
8107    """
8108    Decide which worker should take task *ts*.
8109
8110    We choose the worker that has the data on which *ts* depends.
8111
    If several workers hold those dependencies then we choose the least busy
    worker.

    Optionally provide *valid_workers*, the set of workers on which the task
    is allowed to run (pass None if every worker may take it).
8116
8117    If the task requires data communication because no eligible worker has
8118    all the dependencies already, then we choose to minimize the number
8119    of bytes sent between workers.  This is determined by calling the
8120    *objective* function.
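
    As a rough sketch, the scheduler calls this along the lines of
    (``state`` is a ``SchedulerState`` and ``ts`` a ready ``TaskState``)::

        ws = decide_worker(
            ts,
            state._workers_dv.values(),
            valid_workers=None,  # no worker/host/resource restrictions
            objective=partial(state.worker_objective, ts),
        )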
8121    """
8122    ws: WorkerState = None  # type: ignore
8123    wws: WorkerState
8124    dts: TaskState
8125    deps: set = ts._dependencies
8126    candidates: set
8127    assert all([dts._who_has for dts in deps])
8128    if ts._actor:
8129        candidates = set(all_workers)
8130    else:
8131        candidates = {wws for dts in deps for wws in dts._who_has}
8132    if valid_workers is None:
8133        if not candidates:
8134            candidates = set(all_workers)
8135    else:
8136        candidates &= valid_workers
8137        if not candidates:
8138            candidates = valid_workers
8139            if not candidates:
8140                if ts._loose_restrictions:
8141                    ws = decide_worker(ts, all_workers, None, objective)
8142                return ws
8143
8144    ncandidates: Py_ssize_t = len(candidates)
8145    if ncandidates == 0:
8146        pass
8147    elif ncandidates == 1:
8148        for ws in candidates:
8149            break
8150    else:
8151        ws = min(candidates, key=objective)
8152    return ws
8153
8154
8155def validate_task_state(ts: TaskState):
8156    """
8157    Validate the given TaskState.
8158    """
8159    ws: WorkerState
8160    dts: TaskState
8161
8162    assert ts._state in ALL_TASK_STATES or ts._state == "forgotten", ts
8163
8164    if ts._waiting_on:
8165        assert ts._waiting_on.issubset(ts._dependencies), (
8166            "waiting not subset of dependencies",
8167            str(ts._waiting_on),
8168            str(ts._dependencies),
8169        )
8170    if ts._waiters:
8171        assert ts._waiters.issubset(ts._dependents), (
8172            "waiters not subset of dependents",
8173            str(ts._waiters),
8174            str(ts._dependents),
8175        )
8176
8177    for dts in ts._waiting_on:
8178        assert not dts._who_has, ("waiting on in-memory dep", str(ts), str(dts))
8179        assert dts._state != "released", ("waiting on released dep", str(ts), str(dts))
8180    for dts in ts._dependencies:
8181        assert ts in dts._dependents, (
8182            "not in dependency's dependents",
8183            str(ts),
8184            str(dts),
8185            str(dts._dependents),
8186        )
8187        if ts._state in ("waiting", "processing"):
8188            assert dts in ts._waiting_on or dts._who_has, (
8189                "dep missing",
8190                str(ts),
8191                str(dts),
8192            )
8193        assert dts._state != "forgotten"
8194
8195    for dts in ts._waiters:
8196        assert dts._state in ("waiting", "processing"), (
8197            "waiter not in play",
8198            str(ts),
8199            str(dts),
8200        )
8201    for dts in ts._dependents:
8202        assert ts in dts._dependencies, (
8203            "not in dependent's dependencies",
8204            str(ts),
8205            str(dts),
8206            str(dts._dependencies),
8207        )
8208        assert dts._state != "forgotten"
8209
8210    assert (ts._processing_on is not None) == (ts._state == "processing")
8211    assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state)
8212
8213    if ts._state == "processing":
8214        assert all([dts._who_has for dts in ts._dependencies]), (
8215            "task processing without all deps",
8216            str(ts),
8217            str(ts._dependencies),
8218        )
8219        assert not ts._waiting_on
8220
8221    if ts._who_has:
8222        assert ts._waiters or ts._who_wants, (
8223            "unneeded task in memory",
8224            str(ts),
8225            str(ts._who_has),
8226        )
8227        if ts._run_spec:  # was computed
8228            assert ts._type
8229            assert isinstance(ts._type, str)
8230        assert not any([ts in dts._waiting_on for dts in ts._dependents])
8231        for ws in ts._who_has:
8232            assert ts in ws._has_what, (
8233                "not in who_has' has_what",
8234                str(ts),
8235                str(ws),
8236                str(ws._has_what),
8237            )
8238
8239    if ts._who_wants:
8240        cs: ClientState
8241        for cs in ts._who_wants:
8242            assert ts in cs._wants_what, (
8243                "not in who_wants' wants_what",
8244                str(ts),
8245                str(cs),
8246                str(cs._wants_what),
8247            )
8248
8249    if ts._actor:
8250        if ts._state == "memory":
8251            assert sum([ts in ws._actors for ws in ts._who_has]) == 1
8252        if ts._state == "processing":
8253            assert ts in ts._processing_on.actors
8254
8255
8256def validate_worker_state(ws: WorkerState):
8257    ts: TaskState
8258    for ts in ws._has_what:
8259        assert ws in ts._who_has, (
8260            "not in has_what' who_has",
8261            str(ws),
8262            str(ts),
8263            str(ts._who_has),
8264        )
8265
8266    for ts in ws._actors:
8267        assert ts._state in ("memory", "processing")
8268
8269
8270def validate_state(tasks, workers, clients):
8271    """
    Validate the current runtime state.

    This performs a sequence of checks on the entire graph, running in
    roughly linear time.  It raises an ``AssertionError`` if anything doesn't
    check out.
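
    An illustrative call, using the scheduler's public mappings, would be::

        validate_state(scheduler.tasks, scheduler.workers, scheduler.clients)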
8276    """
8277    ts: TaskState
8278    for ts in tasks.values():
8279        validate_task_state(ts)
8280
8281    ws: WorkerState
8282    for ws in workers.values():
8283        validate_worker_state(ws)
8284
8285    cs: ClientState
8286    for cs in clients.values():
8287        for ts in cs._wants_what:
8288            assert cs in ts._who_wants, (
8289                "not in wants_what' who_wants",
8290                str(cs),
8291                str(ts),
8292                str(ts._who_wants),
8293            )
8294
8295
8296_round_robin = [0]
8297
8298
8299def heartbeat_interval(n):
8300    """
    Interval in seconds at which we desire heartbeats, based on the number of
    workers.
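
    For example, 10 workers heartbeat every 0.5 s, 100 workers every 2 s and
    1000 workers every 6 s, keeping the aggregate rate under roughly 200
    heartbeats per second.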
8302    """
8303    if n <= 10:
8304        return 0.5
8305    elif n < 50:
8306        return 1
8307    elif n < 200:
8308        return 2
8309    else:
        # Cap the aggregate rate at roughly 200 heartbeats per second by
        # scaling the interval with the number of workers
8311        return n / 200 + 1
8312
8313
8314class KilledWorker(Exception):
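    """The worker(s) running this task were repeatedly lost.

    The scheduler marks a task as erred with this exception once its workers
    have died more often than the ``distributed.scheduler.allowed-failures``
    configuration allows; ``task`` and ``last_worker`` identify the task and
    the last worker it was assigned to.
    """
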
8315    def __init__(self, task, last_worker):
8316        super().__init__(task, last_worker)
8317        self.task = task
8318        self.last_worker = last_worker
8319
8320
8321class WorkerStatusPlugin(SchedulerPlugin):
8322    """
    A plugin to share worker status with a remote observer

    This is used by cluster managers to stay up to date on the status of the
    scheduler's workers.
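
    A minimal sketch of attaching it, assuming an already-open ``comm`` back
    to the observer (the constructor registers the plugin itself)::

        WorkerStatusPlugin(scheduler, comm)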
8327    """
8328
8329    name = "worker-status"
8330
8331    def __init__(self, scheduler, comm):
8332        self.bcomm = BatchedSend(interval="5ms")
8333        self.bcomm.start(comm)
8334
8335        self.scheduler = scheduler
8336        self.scheduler.add_plugin(self)
8337
8338    def add_worker(self, worker=None, **kwargs):
8339        ident = self.scheduler.workers[worker].identity()
8340        del ident["metrics"]
8341        del ident["last_seen"]
8342        try:
8343            self.bcomm.send(["add", {"workers": {worker: ident}}])
8344        except CommClosedError:
8345            self.scheduler.remove_plugin(name=self.name)
8346
8347    def remove_worker(self, worker=None, **kwargs):
8348        try:
8349            self.bcomm.send(["remove", worker])
8350        except CommClosedError:
8351            self.scheduler.remove_plugin(name=self.name)
8352
8353    def teardown(self):
8354        self.bcomm.close()
8355
8356
8357class CollectTaskMetaDataPlugin(SchedulerPlugin):
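    """Collect the ``metadata`` and final state of selected task keys.

    ``update_graph`` records the submitted keys of interest and
    ``transition`` captures each key's metadata and final state once it
    reaches ``memory`` or ``erred``.
    """
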
8358    def __init__(self, scheduler, name):
8359        self.scheduler = scheduler
8360        self.name = name
8361        self.keys = set()
8362        self.metadata = {}
8363        self.state = {}
8364
    def update_graph(
        self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs
    ):
8366        self.keys.update(keys)
8367
8368    def transition(self, key, start, finish, *args, **kwargs):
8369        if finish == "memory" or finish == "erred":
8370            ts: TaskState = self.scheduler.tasks.get(key)
8371            if ts is not None and ts._key in self.keys:
8372                self.metadata[key] = ts._metadata
8373                self.state[key] = finish
8374                self.keys.discard(key)
8375