1import asyncio 2import heapq 3import inspect 4import itertools 5import json 6import logging 7import math 8import operator 9import os 10import random 11import sys 12import uuid 13import warnings 14import weakref 15from collections import defaultdict, deque 16from collections.abc import ( 17 Callable, 18 Collection, 19 Hashable, 20 Iterable, 21 Iterator, 22 Mapping, 23 Set, 24) 25from contextlib import suppress 26from datetime import timedelta 27from functools import partial 28from numbers import Number 29from typing import Any, ClassVar, Container 30from typing import cast as pep484_cast 31 32import psutil 33from sortedcontainers import SortedDict, SortedSet 34from tlz import ( 35 compose, 36 first, 37 groupby, 38 merge, 39 merge_sorted, 40 merge_with, 41 pluck, 42 second, 43 valmap, 44) 45from tornado.ioloop import IOLoop, PeriodicCallback 46 47import dask 48from dask.highlevelgraph import HighLevelGraph 49from dask.utils import format_bytes, format_time, parse_bytes, parse_timedelta, tmpfile 50from dask.widgets import get_template 51 52from distributed.utils import recursive_to_dict 53 54from . import preloading, profile 55from . import versions as version_module 56from .active_memory_manager import ActiveMemoryManagerExtension 57from .batched import BatchedSend 58from .comm import ( 59 Comm, 60 get_address_host, 61 normalize_address, 62 resolve_address, 63 unparse_host_port, 64) 65from .comm.addressing import addresses_from_user_args 66from .core import CommClosedError, Status, clean_exception, rpc, send_recv 67from .diagnostics.memory_sampler import MemorySamplerExtension 68from .diagnostics.plugin import SchedulerPlugin, _get_plugin_name 69from .event import EventExtension 70from .http import get_handlers 71from .lock import LockExtension 72from .metrics import time 73from .multi_lock import MultiLockExtension 74from .node import ServerNode 75from .proctitle import setproctitle 76from .protocol.pickle import loads 77from .publish import PublishExtension 78from .pubsub import PubSubSchedulerExtension 79from .queues import QueueExtension 80from .recreate_tasks import ReplayTaskScheduler 81from .security import Security 82from .semaphore import SemaphoreExtension 83from .stealing import WorkStealing 84from .utils import ( 85 All, 86 TimeoutError, 87 empty_context, 88 get_fileno_limit, 89 key_split, 90 key_split_group, 91 log_errors, 92 no_default, 93 validate_key, 94) 95from .utils_comm import gather_from_workers, retry_operation, scatter_to_workers 96from .utils_perf import disable_gc_diagnosis, enable_gc_diagnosis 97from .variable import VariableExtension 98 99try: 100 from cython import compiled 101except ImportError: 102 compiled = False 103 104if compiled: 105 from cython import ( 106 Py_hash_t, 107 Py_ssize_t, 108 bint, 109 cast, 110 ccall, 111 cclass, 112 cfunc, 113 declare, 114 double, 115 exceptval, 116 final, 117 inline, 118 nogil, 119 ) 120else: 121 from ctypes import c_double as double 122 from ctypes import c_ssize_t as Py_hash_t 123 from ctypes import c_ssize_t as Py_ssize_t 124 125 bint = bool 126 127 def cast(T, v, *a, **k): 128 return v 129 130 def ccall(func): 131 return func 132 133 def cclass(cls): 134 return cls 135 136 def cfunc(func): 137 return func 138 139 def declare(*a, **k): 140 if len(a) == 2: 141 return a[1] 142 else: 143 pass 144 145 def exceptval(*a, **k): 146 def wrapper(func): 147 return func 148 149 return wrapper 150 151 def final(cls): 152 return cls 153 154 def inline(func): 155 return func 156 157 def nogil(func): 158 return func 159 160 161if 
sys.version_info < (3, 8): 162 try: 163 import pickle5 as pickle 164 except ImportError: 165 import pickle 166else: 167 import pickle 168 169 170logger = logging.getLogger(__name__) 171 172 173LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") 174DEFAULT_DATA_SIZE = declare( 175 Py_ssize_t, parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) 176) 177 178DEFAULT_EXTENSIONS = [ 179 LockExtension, 180 MultiLockExtension, 181 PublishExtension, 182 ReplayTaskScheduler, 183 QueueExtension, 184 VariableExtension, 185 PubSubSchedulerExtension, 186 SemaphoreExtension, 187 EventExtension, 188 ActiveMemoryManagerExtension, 189 MemorySamplerExtension, 190] 191 192ALL_TASK_STATES = declare( 193 set, {"released", "waiting", "no-worker", "processing", "erred", "memory"} 194) 195globals()["ALL_TASK_STATES"] = ALL_TASK_STATES 196COMPILED = declare(bint, compiled) 197globals()["COMPILED"] = COMPILED 198 199 200@final 201@cclass 202class ClientState: 203 """ 204 A simple object holding information about a client. 205 206 .. attribute:: client_key: str 207 208 A unique identifier for this client. This is generally an opaque 209 string generated by the client itself. 210 211 .. attribute:: wants_what: {TaskState} 212 213 A set of tasks this client wants kept in memory, so that it can 214 download its result when desired. This is the reverse mapping of 215 :class:`TaskState.who_wants`. 216 217 Tasks are typically removed from this set when the corresponding 218 object in the client's space (for example a ``Future`` or a Dask 219 collection) gets garbage-collected. 220 221 """ 222 223 _client_key: str 224 _hash: Py_hash_t 225 _wants_what: set 226 _last_seen: double 227 _versions: dict 228 229 __slots__ = ("_client_key", "_hash", "_wants_what", "_last_seen", "_versions") 230 231 def __init__(self, client: str, versions: dict = None): 232 self._client_key = client 233 self._hash = hash(client) 234 self._wants_what = set() 235 self._last_seen = time() 236 self._versions = versions or {} 237 238 def __hash__(self): 239 return self._hash 240 241 def __eq__(self, other): 242 typ_self: type = type(self) 243 typ_other: type = type(other) 244 if typ_self == typ_other: 245 other_cs: ClientState = other 246 return self._client_key == other_cs._client_key 247 else: 248 return False 249 250 def __repr__(self): 251 return "<Client '%s'>" % self._client_key 252 253 def __str__(self): 254 return self._client_key 255 256 @property 257 def client_key(self): 258 return self._client_key 259 260 @property 261 def wants_what(self): 262 return self._wants_what 263 264 @property 265 def last_seen(self): 266 return self._last_seen 267 268 @property 269 def versions(self): 270 return self._versions 271 272 273@final 274@cclass 275class MemoryState: 276 """Memory readings on a worker or on the whole cluster. 277 278 managed 279 Sum of the output of sizeof() for all dask keys held by the worker, both in 280 memory and spilled to disk 281 managed_in_memory 282 Sum of the output of sizeof() for the dask keys held in RAM 283 managed_spilled 284 Sum of the output of sizeof() for the dask keys spilled to the hard drive. 285 Note that this is the size in memory; serialized size may be different. 286 process 287 Total RSS memory measured by the OS on the worker process. 288 This is always exactly equal to managed_in_memory + unmanaged. 289 unmanaged 290 process - managed_in_memory. 
This is the sum of 291 292 - Python interpreter and modules 293 - global variables 294 - memory temporarily allocated by the dask tasks that are currently running 295 - memory fragmentation 296 - memory leaks 297 - memory not yet garbage collected 298 - memory not yet free()'d by the Python memory manager to the OS 299 300 unmanaged_old 301 Minimum of the 'unmanaged' measures over the last 302 ``distributed.memory.recent-to-old-time`` seconds 303 unmanaged_recent 304 unmanaged - unmanaged_old; in other words process memory that has been recently 305 allocated but is not accounted for by dask; hopefully it's mostly a temporary 306 spike. 307 optimistic 308 managed_in_memory + unmanaged_old; in other words the memory held long-term by 309 the process under the hopeful assumption that all unmanaged_recent memory is a 310 temporary spike 311 """ 312 313 __slots__ = ("_process", "_managed_in_memory", "_managed_spilled", "_unmanaged_old") 314 315 _process: Py_ssize_t 316 _managed_in_memory: Py_ssize_t 317 _managed_spilled: Py_ssize_t 318 _unmanaged_old: Py_ssize_t 319 320 def __init__( 321 self, 322 *, 323 process: Py_ssize_t, 324 unmanaged_old: Py_ssize_t, 325 managed: Py_ssize_t, 326 managed_spilled: Py_ssize_t, 327 ): 328 # Some data arrives with the heartbeat, some other arrives in realtime as the 329 # tasks progress. Also, sizeof() is not guaranteed to return correct results. 330 # This can cause glitches where a partial measure is larger than the whole, so 331 # we need to force all numbers to add up exactly by definition. 332 self._process = process 333 self._managed_spilled = min(managed_spilled, managed) 334 # Subtractions between unsigned ints guaranteed by construction to be >= 0 335 self._managed_in_memory = min(managed - self._managed_spilled, process) 336 self._unmanaged_old = min(unmanaged_old, process - self._managed_in_memory) 337 338 @property 339 def process(self) -> Py_ssize_t: 340 return self._process 341 342 @property 343 def managed_in_memory(self) -> Py_ssize_t: 344 return self._managed_in_memory 345 346 @property 347 def managed_spilled(self) -> Py_ssize_t: 348 return self._managed_spilled 349 350 @property 351 def unmanaged_old(self) -> Py_ssize_t: 352 return self._unmanaged_old 353 354 @classmethod 355 def sum(cls, *infos: "MemoryState") -> "MemoryState": 356 out = MemoryState(process=0, unmanaged_old=0, managed=0, managed_spilled=0) 357 ms: MemoryState 358 for ms in infos: 359 out._process += ms._process 360 out._managed_spilled += ms._managed_spilled 361 out._managed_in_memory += ms._managed_in_memory 362 out._unmanaged_old += ms._unmanaged_old 363 return out 364 365 @property 366 def managed(self) -> Py_ssize_t: 367 return self._managed_in_memory + self._managed_spilled 368 369 @property 370 def unmanaged(self) -> Py_ssize_t: 371 # This is never negative thanks to __init__ 372 return self._process - self._managed_in_memory 373 374 @property 375 def unmanaged_recent(self) -> Py_ssize_t: 376 # This is never negative thanks to __init__ 377 return self._process - self._managed_in_memory - self._unmanaged_old 378 379 @property 380 def optimistic(self) -> Py_ssize_t: 381 return self._managed_in_memory + self._unmanaged_old 382 383 def __repr__(self) -> str: 384 return ( 385 f"Process memory (RSS) : {format_bytes(self._process)}\n" 386 f" - managed by Dask : {format_bytes(self._managed_in_memory)}\n" 387 f" - unmanaged (old) : {format_bytes(self._unmanaged_old)}\n" 388 f" - unmanaged (recent): {format_bytes(self.unmanaged_recent)}\n" 389 f"Spilled to disk : 
{format_bytes(self._managed_spilled)}\n"
        )


@final
@cclass
class WorkerState:
    """
    A simple object holding information about a worker.

    .. attribute:: address: str

       This worker's unique key. This can be its connected address
       (such as ``'tcp://127.0.0.1:8891'``) or an alias (such as ``'alice'``).

    .. attribute:: processing: {TaskState: cost}

       A dictionary of tasks that have been submitted to this worker.
       Each task state is associated with the expected cost in seconds
       of running that task, summing both the task's expected computation
       time and the expected communication time of its result.

       If a task is already executing on the worker and the execution time is
       twice the learned average TaskGroup duration, this will be set to twice
       the current executing time. If the task is unknown, the default task
       duration is used instead of the TaskGroup average.

       Multiple tasks may be submitted to a worker in advance and the worker
       will run them eventually, depending on its execution resources
       (but see :doc:`work-stealing`).

       All the tasks here are in the "processing" state.

       This attribute is kept in sync with :attr:`TaskState.processing_on`.

    .. attribute:: executing: {TaskState: duration}

       A dictionary of tasks that are currently being run on this worker.
       Each task state is associated with the duration in seconds which
       the task has been running.

    .. attribute:: has_what: {TaskState}

       An insertion-sorted set-like of tasks which currently reside on this worker.
       All the tasks here are in the "memory" state.

       This is the reverse mapping of :class:`TaskState.who_has`.

    .. attribute:: nbytes: int

       The total memory size, in bytes, used by the tasks this worker
       holds in memory (i.e. the tasks in this worker's :attr:`has_what`).

    .. attribute:: nthreads: int

       The number of CPU threads made available on this worker.

    .. attribute:: resources: {str: Number}

       The available resources on this worker like ``{'gpu': 2}``.
       These are abstract quantities that constrain certain tasks from
       running at the same time on this worker.

    .. attribute:: used_resources: {str: Number}

       The sum of each resource used by all tasks allocated to this worker.
       The numbers in this dictionary can only be less than or equal to
       those in this worker's :attr:`resources`.

    .. attribute:: occupancy: double

       The total expected runtime, in seconds, of all tasks currently
       processing on this worker. This is the sum of all the costs in
       this worker's :attr:`processing` dictionary.

    .. attribute:: status: Status

       Read-only worker status, synced one way from the remote Worker object

    .. attribute:: nanny: str

       Address of the associated Nanny, if present

    .. attribute:: last_seen: Py_ssize_t

       The last time we received a heartbeat from this worker, in local
       scheduler time.

    .. attribute:: actors: {TaskState}

       A set of all TaskStates on this worker that are actors. This only
       includes those actors whose state actually lives on this worker, not
       actors to which this worker has a reference.

    """

    # XXX need a state field to signal active/removed?
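    # Illustrative sketch (comments only, made-up numbers): how the MemoryState
    # class above reconciles the raw measures that feed the ``WorkerState.memory``
    # property further down.
    #
    #   m = MemoryState(
    #       process=100 * 2**20,         # RSS reported by the OS
    #       managed=60 * 2**20,          # sum of sizeof() over all keys
    #       managed_spilled=10 * 2**20,  # part of managed that sits on disk
    #       unmanaged_old=30 * 2**20,
    #   )
    #   # managed_in_memory = min(managed - managed_spilled, process) -> 50 MiB
    #   # unmanaged         = process - managed_in_memory             -> 50 MiB
    #   # unmanaged_recent  = unmanaged - unmanaged_old                -> 20 MiB
    #   # optimistic        = managed_in_memory + unmanaged_old        -> 80 MiB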
486 487 _actors: set 488 _address: str 489 _bandwidth: double 490 _executing: dict 491 _extra: dict 492 # _has_what is a dict with all values set to None as rebalance() relies on the 493 # property of Python >=3.7 dicts to be insertion-sorted. 494 _has_what: dict 495 _hash: Py_hash_t 496 _last_seen: double 497 _local_directory: str 498 _memory_limit: Py_ssize_t 499 _memory_other_history: "deque[tuple[float, Py_ssize_t]]" 500 _memory_unmanaged_old: Py_ssize_t 501 _metrics: dict 502 _name: object 503 _nanny: str 504 _nbytes: Py_ssize_t 505 _nthreads: Py_ssize_t 506 _occupancy: double 507 _pid: Py_ssize_t 508 _processing: dict 509 _resources: dict 510 _services: dict 511 _status: Status 512 _time_delay: double 513 _used_resources: dict 514 _versions: dict 515 516 __slots__ = ( 517 "_actors", 518 "_address", 519 "_bandwidth", 520 "_extra", 521 "_executing", 522 "_has_what", 523 "_hash", 524 "_last_seen", 525 "_local_directory", 526 "_memory_limit", 527 "_memory_other_history", 528 "_memory_unmanaged_old", 529 "_metrics", 530 "_name", 531 "_nanny", 532 "_nbytes", 533 "_nthreads", 534 "_occupancy", 535 "_pid", 536 "_processing", 537 "_resources", 538 "_services", 539 "_status", 540 "_time_delay", 541 "_used_resources", 542 "_versions", 543 ) 544 545 def __init__( 546 self, 547 *, 548 address: str, 549 status: Status, 550 pid: Py_ssize_t, 551 name: object, 552 nthreads: Py_ssize_t = 0, 553 memory_limit: Py_ssize_t, 554 local_directory: str, 555 nanny: str, 556 services: "dict | None" = None, 557 versions: "dict | None" = None, 558 extra: "dict | None" = None, 559 ): 560 self._address = address 561 self._pid = pid 562 self._name = name 563 self._nthreads = nthreads 564 self._memory_limit = memory_limit 565 self._local_directory = local_directory 566 self._services = services or {} 567 self._versions = versions or {} 568 self._nanny = nanny 569 self._status = status 570 571 self._hash = hash(address) 572 self._nbytes = 0 573 self._occupancy = 0 574 self._memory_unmanaged_old = 0 575 self._memory_other_history = deque() 576 self._metrics = {} 577 self._last_seen = 0 578 self._time_delay = 0 579 self._bandwidth = float( 580 parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) 581 ) 582 583 self._actors = set() 584 self._has_what = {} 585 self._processing = {} 586 self._executing = {} 587 self._resources = {} 588 self._used_resources = {} 589 590 self._extra = extra or {} 591 592 def __hash__(self): 593 return self._hash 594 595 def __eq__(self, other): 596 typ_self: type = type(self) 597 typ_other: type = type(other) 598 if typ_self == typ_other: 599 other_ws: WorkerState = other 600 return self._address == other_ws._address 601 else: 602 return False 603 604 @property 605 def actors(self): 606 return self._actors 607 608 @property 609 def address(self) -> str: 610 return self._address 611 612 @property 613 def bandwidth(self): 614 return self._bandwidth 615 616 @property 617 def executing(self): 618 return self._executing 619 620 @property 621 def extra(self): 622 return self._extra 623 624 @property 625 def has_what(self) -> "Set[TaskState]": 626 return self._has_what.keys() 627 628 @property 629 def host(self): 630 return get_address_host(self._address) 631 632 @property 633 def last_seen(self): 634 return self._last_seen 635 636 @property 637 def local_directory(self): 638 return self._local_directory 639 640 @property 641 def memory_limit(self): 642 return self._memory_limit 643 644 @property 645 def metrics(self): 646 return self._metrics 647 648 @property 649 def memory(self) -> 
MemoryState: 650 return MemoryState( 651 # metrics["memory"] is None if the worker sent a heartbeat before its 652 # SystemMonitor ever had a chance to run 653 process=self._metrics["memory"] or 0, 654 managed=self._nbytes, 655 managed_spilled=self._metrics["spilled_nbytes"], 656 unmanaged_old=self._memory_unmanaged_old, 657 ) 658 659 @property 660 def name(self): 661 return self._name 662 663 @property 664 def nanny(self): 665 return self._nanny 666 667 @property 668 def nbytes(self): 669 return self._nbytes 670 671 @nbytes.setter 672 def nbytes(self, v: Py_ssize_t): 673 self._nbytes = v 674 675 @property 676 def nthreads(self): 677 return self._nthreads 678 679 @property 680 def occupancy(self): 681 return self._occupancy 682 683 @occupancy.setter 684 def occupancy(self, v: double): 685 self._occupancy = v 686 687 @property 688 def pid(self): 689 return self._pid 690 691 @property 692 def processing(self): 693 return self._processing 694 695 @property 696 def resources(self): 697 return self._resources 698 699 @property 700 def services(self): 701 return self._services 702 703 @property 704 def status(self): 705 return self._status 706 707 @status.setter 708 def status(self, new_status): 709 if not isinstance(new_status, Status): 710 raise TypeError(f"Expected Status; got {new_status!r}") 711 self._status = new_status 712 713 @property 714 def time_delay(self): 715 return self._time_delay 716 717 @property 718 def used_resources(self): 719 return self._used_resources 720 721 @property 722 def versions(self): 723 return self._versions 724 725 @ccall 726 def clean(self): 727 """Return a version of this object that is appropriate for serialization""" 728 ws: WorkerState = WorkerState( 729 address=self._address, 730 status=self._status, 731 pid=self._pid, 732 name=self._name, 733 nthreads=self._nthreads, 734 memory_limit=self._memory_limit, 735 local_directory=self._local_directory, 736 services=self._services, 737 nanny=self._nanny, 738 extra=self._extra, 739 ) 740 ts: TaskState 741 ws._processing = {ts._key: cost for ts, cost in self._processing.items()} 742 ws._executing = {ts._key: duration for ts, duration in self._executing.items()} 743 return ws 744 745 def __repr__(self): 746 return "<WorkerState %r, name: %s, status: %s, memory: %d, processing: %d>" % ( 747 self._address, 748 self._name, 749 self._status.name, 750 len(self._has_what), 751 len(self._processing), 752 ) 753 754 def _repr_html_(self): 755 return get_template("worker_state.html.j2").render( 756 address=self.address, 757 name=self.name, 758 status=self.status.name, 759 has_what=self._has_what, 760 processing=self.processing, 761 ) 762 763 @ccall 764 @exceptval(check=False) 765 def identity(self) -> dict: 766 return { 767 "type": "Worker", 768 "id": self._name, 769 "host": self.host, 770 "resources": self._resources, 771 "local_directory": self._local_directory, 772 "name": self._name, 773 "nthreads": self._nthreads, 774 "memory_limit": self._memory_limit, 775 "last_seen": self._last_seen, 776 "services": self._services, 777 "metrics": self._metrics, 778 "nanny": self._nanny, 779 **self._extra, 780 } 781 782 @property 783 def ncores(self): 784 warnings.warn("WorkerState.ncores has moved to WorkerState.nthreads") 785 return self._nthreads 786 787 788@final 789@cclass 790class Computation: 791 """ 792 Collection tracking a single compute or persist call 793 794 See also 795 -------- 796 TaskPrefix 797 TaskGroup 798 TaskState 799 """ 800 801 _start: double 802 _groups: set 803 _code: object 804 _id: object 805 806 def 
__init__(self): 807 self._start = time() 808 self._groups = set() 809 self._code = SortedSet() 810 self._id = uuid.uuid4() 811 812 @property 813 def code(self): 814 return self._code 815 816 @property 817 def start(self): 818 return self._start 819 820 @property 821 def stop(self): 822 if self.groups: 823 return max(tg.stop for tg in self.groups) 824 else: 825 return -1 826 827 @property 828 def states(self): 829 tg: TaskGroup 830 return merge_with(sum, [tg._states for tg in self._groups]) 831 832 @property 833 def groups(self): 834 return self._groups 835 836 def __repr__(self): 837 return ( 838 f"<Computation {self._id}: " 839 + "Tasks: " 840 + ", ".join( 841 "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v 842 ) 843 + ">" 844 ) 845 846 def _repr_html_(self): 847 return get_template("computation.html.j2").render( 848 id=self._id, 849 start=self.start, 850 stop=self.stop, 851 groups=self.groups, 852 states=self.states, 853 code=self.code, 854 ) 855 856 857@final 858@cclass 859class TaskPrefix: 860 """Collection tracking all tasks within a group 861 862 Keys often have a structure like ``("x-123", 0)`` 863 A group takes the first section, like ``"x"`` 864 865 .. attribute:: name: str 866 867 The name of a group of tasks. 868 For a task like ``("x-123", 0)`` this is the text ``"x"`` 869 870 .. attribute:: states: Dict[str, int] 871 872 The number of tasks in each state, 873 like ``{"memory": 10, "processing": 3, "released": 4, ...}`` 874 875 .. attribute:: duration_average: float 876 877 An exponentially weighted moving average duration of all tasks with this prefix 878 879 .. attribute:: suspicious: int 880 881 Numbers of times a task was marked as suspicious with this prefix 882 883 884 See Also 885 -------- 886 TaskGroup 887 """ 888 889 _name: str 890 _all_durations: "defaultdict[str, float]" 891 _duration_average: double 892 _suspicious: Py_ssize_t 893 _groups: list 894 895 def __init__(self, name: str): 896 self._name = name 897 self._groups = [] 898 899 # store timings for each prefix-action 900 self._all_durations = defaultdict(float) 901 902 task_durations = dask.config.get("distributed.scheduler.default-task-durations") 903 if self._name in task_durations: 904 self._duration_average = parse_timedelta(task_durations[self._name]) 905 else: 906 self._duration_average = -1 907 self._suspicious = 0 908 909 @property 910 def name(self) -> str: 911 return self._name 912 913 @property 914 def all_durations(self) -> "defaultdict[str, float]": 915 return self._all_durations 916 917 @ccall 918 @exceptval(check=False) 919 def add_duration(self, action: str, start: double, stop: double): 920 duration = stop - start 921 self._all_durations[action] += duration 922 if action == "compute": 923 old = self._duration_average 924 if old < 0: 925 self._duration_average = duration 926 else: 927 self._duration_average = 0.5 * duration + 0.5 * old 928 929 @property 930 def duration_average(self) -> double: 931 return self._duration_average 932 933 @property 934 def suspicious(self) -> Py_ssize_t: 935 return self._suspicious 936 937 @property 938 def groups(self): 939 return self._groups 940 941 @property 942 def states(self): 943 tg: TaskGroup 944 return merge_with(sum, [tg._states for tg in self._groups]) 945 946 @property 947 def active(self) -> "list[TaskGroup]": 948 tg: TaskGroup 949 return [ 950 tg 951 for tg in self._groups 952 if any([v != 0 for k, v in tg._states.items() if k != "forgotten"]) 953 ] 954 955 @property 956 def active_states(self): 957 tg: TaskGroup 958 return 
merge_with(sum, [tg._states for tg in self.active]) 959 960 def __repr__(self): 961 return ( 962 "<" 963 + self._name 964 + ": " 965 + ", ".join( 966 "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v 967 ) 968 + ">" 969 ) 970 971 @property 972 def nbytes_total(self): 973 tg: TaskGroup 974 return sum([tg._nbytes_total for tg in self._groups]) 975 976 def __len__(self): 977 return sum(map(len, self._groups)) 978 979 @property 980 def duration(self): 981 tg: TaskGroup 982 return sum([tg._duration for tg in self._groups]) 983 984 @property 985 def types(self): 986 tg: TaskGroup 987 return set().union(*[tg._types for tg in self._groups]) 988 989 990@final 991@cclass 992class TaskGroup: 993 """Collection tracking all tasks within a group 994 995 Keys often have a structure like ``("x-123", 0)`` 996 A group takes the first section, like ``"x-123"`` 997 998 .. attribute:: name: str 999 1000 The name of a group of tasks. 1001 For a task like ``("x-123", 0)`` this is the text ``"x-123"`` 1002 1003 .. attribute:: states: Dict[str, int] 1004 1005 The number of tasks in each state, 1006 like ``{"memory": 10, "processing": 3, "released": 4, ...}`` 1007 1008 .. attribute:: dependencies: Set[TaskGroup] 1009 1010 The other TaskGroups on which this one depends 1011 1012 .. attribute:: nbytes_total: int 1013 1014 The total number of bytes that this task group has produced 1015 1016 .. attribute:: duration: float 1017 1018 The total amount of time spent on all tasks in this TaskGroup 1019 1020 .. attribute:: types: Set[str] 1021 1022 The result types of this TaskGroup 1023 1024 .. attribute:: last_worker: WorkerState 1025 1026 The worker most recently assigned a task from this group, or None when the group 1027 is not identified to be root-like by `SchedulerState.decide_worker`. 1028 1029 .. attribute:: last_worker_tasks_left: int 1030 1031 If `last_worker` is not None, the number of times that worker should be assigned 1032 subsequent tasks until a new worker is chosen. 
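
    The group name is derived from the task key with ``key_split_group``,
    while the matching :class:`TaskPrefix` name comes from ``key_split``
    (both imported from ``distributed.utils`` at the top of this module).
    A rough, illustrative doctest:

    >>> key_split_group("('x-123', 0)")  # doctest: +SKIP
    'x-123'
    >>> key_split("('x-123', 0)")  # doctest: +SKIP
    'x'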
1033 1034 See also 1035 -------- 1036 TaskPrefix 1037 """ 1038 1039 _name: str 1040 _prefix: TaskPrefix # TaskPrefix | None 1041 _states: dict 1042 _dependencies: set 1043 _nbytes_total: Py_ssize_t 1044 _duration: double 1045 _types: set 1046 _start: double 1047 _stop: double 1048 _all_durations: "defaultdict[str, float]" 1049 _last_worker: WorkerState # WorkerState | None 1050 _last_worker_tasks_left: Py_ssize_t 1051 1052 def __init__(self, name: str): 1053 self._name = name 1054 self._prefix = None # type: ignore 1055 self._states = {state: 0 for state in ALL_TASK_STATES} 1056 self._states["forgotten"] = 0 1057 self._dependencies = set() 1058 self._nbytes_total = 0 1059 self._duration = 0 1060 self._types = set() 1061 self._start = 0.0 1062 self._stop = 0.0 1063 self._all_durations = defaultdict(float) 1064 self._last_worker = None # type: ignore 1065 self._last_worker_tasks_left = 0 1066 1067 @property 1068 def name(self) -> str: 1069 return self._name 1070 1071 @property 1072 def prefix(self) -> "TaskPrefix | None": 1073 return self._prefix 1074 1075 @property 1076 def states(self) -> dict: 1077 return self._states 1078 1079 @property 1080 def dependencies(self) -> set: 1081 return self._dependencies 1082 1083 @property 1084 def nbytes_total(self): 1085 return self._nbytes_total 1086 1087 @property 1088 def duration(self) -> double: 1089 return self._duration 1090 1091 @ccall 1092 @exceptval(check=False) 1093 def add_duration(self, action: str, start: double, stop: double): 1094 duration = stop - start 1095 self._all_durations[action] += duration 1096 if action == "compute": 1097 if self._stop < stop: 1098 self._stop = stop 1099 self._start = self._start or start 1100 self._duration += duration 1101 self._prefix.add_duration(action, start, stop) 1102 1103 @property 1104 def types(self) -> set: 1105 return self._types 1106 1107 @property 1108 def all_durations(self) -> "defaultdict[str, float]": 1109 return self._all_durations 1110 1111 @property 1112 def start(self) -> double: 1113 return self._start 1114 1115 @property 1116 def stop(self) -> double: 1117 return self._stop 1118 1119 @property 1120 def last_worker(self) -> "WorkerState | None": 1121 return self._last_worker 1122 1123 @property 1124 def last_worker_tasks_left(self) -> int: 1125 return self._last_worker_tasks_left 1126 1127 @ccall 1128 def add(self, other: "TaskState"): 1129 self._states[other._state] += 1 1130 other._group = self 1131 1132 def __repr__(self): 1133 return ( 1134 "<" 1135 + (self._name or "no-group") 1136 + ": " 1137 + ", ".join( 1138 "%s: %d" % (k, v) for (k, v) in sorted(self._states.items()) if v 1139 ) 1140 + ">" 1141 ) 1142 1143 def __len__(self): 1144 return sum(self._states.values()) 1145 1146 1147@final 1148@cclass 1149class TaskState: 1150 """ 1151 A simple object holding information about a task. 1152 1153 .. attribute:: key: str 1154 1155 The key is the unique identifier of a task, generally formed 1156 from the name of the function, followed by a hash of the function 1157 and arguments, like ``'inc-ab31c010444977004d656610d2d421ec'``. 1158 1159 .. attribute:: prefix: TaskPrefix 1160 1161 The broad class of tasks to which this task belongs like "inc" or 1162 "read_csv" 1163 1164 .. attribute:: run_spec: object 1165 1166 A specification of how to run the task. The type and meaning of this 1167 value is opaque to the scheduler, as it is only interpreted by the 1168 worker to which the task is sent for executing. 
1169 1170 As a special case, this attribute may also be ``None``, in which case 1171 the task is "pure data" (such as, for example, a piece of data loaded 1172 in the scheduler using :meth:`Client.scatter`). A "pure data" task 1173 cannot be computed again if its value is lost. 1174 1175 .. attribute:: priority: tuple 1176 1177 The priority provides each task with a relative ranking which is used 1178 to break ties when many tasks are being considered for execution. 1179 1180 This ranking is generally a 2-item tuple. The first (and dominant) 1181 item corresponds to when it was submitted. Generally, earlier tasks 1182 take precedence. The second item is determined by the client, and is 1183 a way to prioritize tasks within a large graph that may be important, 1184 such as if they are on the critical path, or good to run in order to 1185 release many dependencies. This is explained further in 1186 :doc:`Scheduling Policy <scheduling-policies>`. 1187 1188 .. attribute:: state: str 1189 1190 This task's current state. Valid states include ``released``, 1191 ``waiting``, ``no-worker``, ``processing``, ``memory``, ``erred`` 1192 and ``forgotten``. If it is ``forgotten``, the task isn't stored 1193 in the ``tasks`` dictionary anymore and will probably disappear 1194 soon from memory. 1195 1196 .. attribute:: dependencies: {TaskState} 1197 1198 The set of tasks this task depends on for proper execution. Only 1199 tasks still alive are listed in this set. If, for whatever reason, 1200 this task also depends on a forgotten task, the 1201 :attr:`has_lost_dependencies` flag is set. 1202 1203 A task can only be executed once all its dependencies have already 1204 been successfully executed and have their result stored on at least 1205 one worker. This is tracked by progressively draining the 1206 :attr:`waiting_on` set. 1207 1208 .. attribute:: dependents: {TaskState} 1209 1210 The set of tasks which depend on this task. Only tasks still alive 1211 are listed in this set. 1212 1213 This is the reverse mapping of :attr:`dependencies`. 1214 1215 .. attribute:: has_lost_dependencies: bool 1216 1217 Whether any of the dependencies of this task has been forgotten. 1218 For memory consumption reasons, forgotten tasks are not kept in 1219 memory even though they may have dependent tasks. When a task is 1220 forgotten, therefore, each of its dependents has their 1221 :attr:`has_lost_dependencies` attribute set to ``True``. 1222 1223 If :attr:`has_lost_dependencies` is true, this task cannot go 1224 into the "processing" state anymore. 1225 1226 .. attribute:: waiting_on: {TaskState} 1227 1228 The set of tasks this task is waiting on *before* it can be executed. 1229 This is always a subset of :attr:`dependencies`. Each time one of the 1230 dependencies has finished processing, it is removed from the 1231 :attr:`waiting_on` set. 1232 1233 Once :attr:`waiting_on` becomes empty, this task can move from the 1234 "waiting" state to the "processing" state (unless one of the 1235 dependencies errored out, in which case this task is instead 1236 marked "erred"). 1237 1238 .. attribute:: waiters: {TaskState} 1239 1240 The set of tasks which need this task to remain alive. This is always 1241 a subset of :attr:`dependents`. Each time one of the dependents 1242 has finished processing, it is removed from the :attr:`waiters` 1243 set. 
1244 1245 Once both :attr:`waiters` and :attr:`who_wants` become empty, this 1246 task can be released (if it has a non-empty :attr:`run_spec`) or 1247 forgotten (otherwise) by the scheduler, and by any workers 1248 in :attr:`who_has`. 1249 1250 .. note:: Counter-intuitively, :attr:`waiting_on` and 1251 :attr:`waiters` are not reverse mappings of each other. 1252 1253 .. attribute:: who_wants: {ClientState} 1254 1255 The set of clients who want this task's result to remain alive. 1256 This is the reverse mapping of :attr:`ClientState.wants_what`. 1257 1258 When a client submits a graph to the scheduler it also specifies 1259 which output tasks it desires, such that their results are not released 1260 from memory. 1261 1262 Once a task has finished executing (i.e. moves into the "memory" 1263 or "erred" state), the clients in :attr:`who_wants` are notified. 1264 1265 Once both :attr:`waiters` and :attr:`who_wants` become empty, this 1266 task can be released (if it has a non-empty :attr:`run_spec`) or 1267 forgotten (otherwise) by the scheduler, and by any workers 1268 in :attr:`who_has`. 1269 1270 .. attribute:: who_has: {WorkerState} 1271 1272 The set of workers who have this task's result in memory. 1273 It is non-empty iff the task is in the "memory" state. There can be 1274 more than one worker in this set if, for example, :meth:`Client.scatter` 1275 or :meth:`Client.replicate` was used. 1276 1277 This is the reverse mapping of :attr:`WorkerState.has_what`. 1278 1279 .. attribute:: processing_on: WorkerState (or None) 1280 1281 If this task is in the "processing" state, which worker is currently 1282 processing it. Otherwise this is ``None``. 1283 1284 This attribute is kept in sync with :attr:`WorkerState.processing`. 1285 1286 .. attribute:: retries: int 1287 1288 The number of times this task can automatically be retried in case 1289 of failure. If a task fails executing (the worker returns with 1290 an error), its :attr:`retries` attribute is checked. If it is 1291 equal to 0, the task is marked "erred". If it is greater than 0, 1292 the :attr:`retries` attribute is decremented and execution is 1293 attempted again. 1294 1295 .. attribute:: nbytes: int (or None) 1296 1297 The number of bytes, as determined by ``sizeof``, of the result 1298 of a finished task. This number is used for diagnostics and to 1299 help prioritize work. 1300 1301 .. attribute:: type: str 1302 1303 The type of the object as a string. Only present for tasks that have 1304 been computed. 1305 1306 .. attribute:: exception: object 1307 1308 If this task failed executing, the exception object is stored here. 1309 Otherwise this is ``None``. 1310 1311 .. attribute:: traceback: object 1312 1313 If this task failed executing, the traceback object is stored here. 1314 Otherwise this is ``None``. 1315 1316 .. attribute:: exception_blame: TaskState (or None) 1317 1318 If this task or one of its dependencies failed executing, the 1319 failed task is stored here (possibly itself). Otherwise this 1320 is ``None``. 1321 1322 .. attribute:: erred_on: set(str) 1323 1324 Worker addresses on which errors appeared causing this task to be in an error state. 1325 1326 .. attribute:: suspicious: int 1327 1328 The number of times this task has been involved in a worker death. 1329 1330 Some tasks may cause workers to die (such as calling ``os._exit(0)``). 1331 When a worker dies, all of the tasks on that worker are reassigned 1332 to others. 
This combination of behaviors can cause a bad task to 1333 catastrophically destroy all workers on the cluster, one after 1334 another. Whenever a worker dies, we mark each task currently 1335 processing on that worker (as recorded by 1336 :attr:`WorkerState.processing`) as suspicious. 1337 1338 If a task is involved in three deaths (or some other fixed constant) 1339 then we mark the task as ``erred``. 1340 1341 .. attribute:: host_restrictions: {hostnames} 1342 1343 A set of hostnames where this task can be run (or ``None`` if empty). 1344 Usually this is empty unless the task has been specifically restricted 1345 to only run on certain hosts. A hostname may correspond to one or 1346 several connected workers. 1347 1348 .. attribute:: worker_restrictions: {worker addresses} 1349 1350 A set of complete worker addresses where this can be run (or ``None`` 1351 if empty). Usually this is empty unless the task has been specifically 1352 restricted to only run on certain workers. 1353 1354 Note this is tracking worker addresses, not worker states, since 1355 the specific workers may not be connected at this time. 1356 1357 .. attribute:: resource_restrictions: {resource: quantity} 1358 1359 Resources required by this task, such as ``{'gpu': 1}`` or 1360 ``{'memory': 1e9}`` (or ``None`` if empty). These are user-defined 1361 names and are matched against the contents of each 1362 :attr:`WorkerState.resources` dictionary. 1363 1364 .. attribute:: loose_restrictions: bool 1365 1366 If ``False``, each of :attr:`host_restrictions`, 1367 :attr:`worker_restrictions` and :attr:`resource_restrictions` is 1368 a hard constraint: if no worker is available satisfying those 1369 restrictions, the task cannot go into the "processing" state and 1370 will instead go into the "no-worker" state. 1371 1372 If ``True``, the above restrictions are mere preferences: if no worker 1373 is available satisfying those restrictions, the task can still go 1374 into the "processing" state and be sent for execution to another 1375 connected worker. 1376 1377 .. attribute:: metadata: dict 1378 1379 Metadata related to task. 1380 1381 .. attribute:: actor: bool 1382 1383 Whether or not this task is an Actor. 1384 1385 .. attribute:: group: TaskGroup 1386 1387 The group of tasks to which this one belongs. 1388 1389 .. 
attribute:: annotations: dict

       Task annotations
    """

    _key: str
    _hash: Py_hash_t
    _prefix: TaskPrefix
    _run_spec: object
    _priority: tuple  # tuple | None
    _state: str  # str | None
    _dependencies: set  # set[TaskState]
    _dependents: set  # set[TaskState]
    _has_lost_dependencies: bint
    _waiting_on: set  # set[TaskState]
    _waiters: set  # set[TaskState]
    _who_wants: set  # set[ClientState]
    _who_has: set  # set[WorkerState]
    _processing_on: WorkerState  # WorkerState | None
    _retries: Py_ssize_t
    _nbytes: Py_ssize_t
    _type: str  # str | None
    _exception: object
    _exception_text: str
    _traceback: object
    _traceback_text: str
    _exception_blame: "TaskState"  # TaskState | None
    _erred_on: set
    _suspicious: Py_ssize_t
    _host_restrictions: set  # set[str] | None
    _worker_restrictions: set  # set[str] | None
    _resource_restrictions: dict  # dict | None
    _loose_restrictions: bint
    _metadata: dict
    _annotations: dict
    _actor: bint
    _group: TaskGroup  # TaskGroup | None
    _group_key: str

    __slots__ = (
        # === General description ===
        "_actor",
        # Key name
        "_key",
        # Hash of the key name
        "_hash",
        # Key prefix (see key_split())
        "_prefix",
        # How to run the task (None if pure data)
        "_run_spec",
        # Alive dependents and dependencies
        "_dependencies",
        "_dependents",
        # Compute priority
        "_priority",
        # Restrictions
        "_host_restrictions",
        "_worker_restrictions",  # not WorkerStates but addresses
        "_resource_restrictions",
        "_loose_restrictions",
        # === Task state ===
        "_state",
        # Whether some dependencies were forgotten
        "_has_lost_dependencies",
        # If in 'waiting' state, which tasks need to complete
        # before we can run
        "_waiting_on",
        # If in 'waiting' or 'processing' state, which tasks need us
        # to complete before they can run
        "_waiters",
        # If in 'processing' state, which worker we are processing on
        "_processing_on",
        # If in 'memory' state, which workers have us
        "_who_has",
        # Which clients want us
        "_who_wants",
        "_exception",
        "_exception_text",
        "_traceback",
        "_traceback_text",
        "_erred_on",
        "_exception_blame",
        "_suspicious",
        "_retries",
        "_nbytes",
        "_type",
        "_group_key",
        "_group",
        "_metadata",
        "_annotations",
    )

    def __init__(self, key: str, run_spec: object):
        self._key = key
        self._hash = hash(key)
        self._run_spec = run_spec
        self._state = None  # type: ignore
        self._exception = None
        self._exception_blame = None  # type: ignore
        self._traceback = None
        self._exception_text = ""
        self._traceback_text = ""
        self._suspicious = 0
        self._retries = 0
        self._nbytes = -1
        self._priority = None  # type: ignore
        self._who_wants = set()
        self._dependencies = set()
        self._dependents = set()
        self._waiting_on = set()
        self._waiters = set()
        self._who_has = set()
        self._processing_on = None  # type: ignore
        self._has_lost_dependencies = False
        self._host_restrictions = None  # type: ignore
        self._worker_restrictions = None  # type: ignore
        self._resource_restrictions = None  # type: ignore
        self._loose_restrictions = False
        self._actor = False
        self._type = None  # type: ignore
        self._group_key = key_split_group(key)
        self._group
= None # type: ignore 1511 self._metadata = {} 1512 self._annotations = {} 1513 self._erred_on = set() 1514 1515 def __hash__(self): 1516 return self._hash 1517 1518 def __eq__(self, other): 1519 typ_self: type = type(self) 1520 typ_other: type = type(other) 1521 if typ_self == typ_other: 1522 other_ts: TaskState = other 1523 return self._key == other_ts._key 1524 else: 1525 return False 1526 1527 @property 1528 def key(self): 1529 return self._key 1530 1531 @property 1532 def prefix(self): 1533 return self._prefix 1534 1535 @property 1536 def run_spec(self): 1537 return self._run_spec 1538 1539 @property 1540 def priority(self) -> "tuple | None": 1541 return self._priority 1542 1543 @property 1544 def state(self) -> "str | None": 1545 return self._state 1546 1547 @state.setter 1548 def state(self, value: str): 1549 self._group._states[self._state] -= 1 1550 self._group._states[value] += 1 1551 self._state = value 1552 1553 @property 1554 def dependencies(self) -> "set[TaskState]": 1555 return self._dependencies 1556 1557 @property 1558 def dependents(self) -> "set[TaskState]": 1559 return self._dependents 1560 1561 @property 1562 def has_lost_dependencies(self): 1563 return self._has_lost_dependencies 1564 1565 @property 1566 def waiting_on(self) -> "set[TaskState]": 1567 return self._waiting_on 1568 1569 @property 1570 def waiters(self) -> "set[TaskState]": 1571 return self._waiters 1572 1573 @property 1574 def who_wants(self) -> "set[ClientState]": 1575 return self._who_wants 1576 1577 @property 1578 def who_has(self) -> "set[WorkerState]": 1579 return self._who_has 1580 1581 @property 1582 def processing_on(self) -> "WorkerState | None": 1583 return self._processing_on 1584 1585 @processing_on.setter 1586 def processing_on(self, v: WorkerState) -> None: 1587 self._processing_on = v 1588 1589 @property 1590 def retries(self): 1591 return self._retries 1592 1593 @property 1594 def nbytes(self): 1595 return self._nbytes 1596 1597 @nbytes.setter 1598 def nbytes(self, v: Py_ssize_t): 1599 self._nbytes = v 1600 1601 @property 1602 def type(self) -> "str | None": 1603 return self._type 1604 1605 @property 1606 def exception(self): 1607 return self._exception 1608 1609 @property 1610 def exception_text(self): 1611 return self._exception_text 1612 1613 @property 1614 def traceback(self): 1615 return self._traceback 1616 1617 @property 1618 def traceback_text(self): 1619 return self._traceback_text 1620 1621 @property 1622 def exception_blame(self) -> "TaskState | None": 1623 return self._exception_blame 1624 1625 @property 1626 def suspicious(self): 1627 return self._suspicious 1628 1629 @property 1630 def host_restrictions(self) -> "set[str] | None": 1631 return self._host_restrictions 1632 1633 @property 1634 def worker_restrictions(self) -> "set[str] | None": 1635 return self._worker_restrictions 1636 1637 @property 1638 def resource_restrictions(self) -> "dict | None": 1639 return self._resource_restrictions 1640 1641 @property 1642 def loose_restrictions(self): 1643 return self._loose_restrictions 1644 1645 @property 1646 def metadata(self): 1647 return self._metadata 1648 1649 @property 1650 def annotations(self): 1651 return self._annotations 1652 1653 @property 1654 def actor(self): 1655 return self._actor 1656 1657 @property 1658 def group(self) -> "TaskGroup | None": 1659 return self._group 1660 1661 @property 1662 def group_key(self) -> str: 1663 return self._group_key 1664 1665 @property 1666 def prefix_key(self): 1667 return self._prefix._name 1668 1669 @property 1670 def 
erred_on(self): 1671 return self._erred_on 1672 1673 @ccall 1674 def add_dependency(self, other: "TaskState"): 1675 """Add another task as a dependency of this task""" 1676 self._dependencies.add(other) 1677 self._group._dependencies.add(other._group) 1678 other._dependents.add(self) 1679 1680 @ccall 1681 @inline 1682 @nogil 1683 def get_nbytes(self) -> Py_ssize_t: 1684 return self._nbytes if self._nbytes >= 0 else DEFAULT_DATA_SIZE 1685 1686 @ccall 1687 def set_nbytes(self, nbytes: Py_ssize_t): 1688 diff: Py_ssize_t = nbytes 1689 old_nbytes: Py_ssize_t = self._nbytes 1690 if old_nbytes >= 0: 1691 diff -= old_nbytes 1692 self._group._nbytes_total += diff 1693 ws: WorkerState 1694 for ws in self._who_has: 1695 ws._nbytes += diff 1696 self._nbytes = nbytes 1697 1698 def __repr__(self): 1699 return f"<TaskState {self._key!r} {self._state}>" 1700 1701 def _repr_html_(self): 1702 return get_template("task_state.html.j2").render( 1703 state=self._state, 1704 nbytes=self._nbytes, 1705 key=self._key, 1706 ) 1707 1708 @ccall 1709 def validate(self): 1710 try: 1711 for cs in self._who_wants: 1712 assert isinstance(cs, ClientState), (repr(cs), self._who_wants) 1713 for ws in self._who_has: 1714 assert isinstance(ws, WorkerState), (repr(ws), self._who_has) 1715 for ts in self._dependencies: 1716 assert isinstance(ts, TaskState), (repr(ts), self._dependencies) 1717 for ts in self._dependents: 1718 assert isinstance(ts, TaskState), (repr(ts), self._dependents) 1719 validate_task_state(self) 1720 except Exception as e: 1721 logger.exception(e) 1722 if LOG_PDB: 1723 import pdb 1724 1725 pdb.set_trace() 1726 1727 def get_nbytes_deps(self): 1728 nbytes: Py_ssize_t = 0 1729 ts: TaskState 1730 for ts in self._dependencies: 1731 nbytes += ts.get_nbytes() 1732 return nbytes 1733 1734 @ccall 1735 def _to_dict(self, *, exclude: Container[str] = None): 1736 """ 1737 A very verbose dictionary representation for debugging purposes. 1738 Not type stable and not inteded for roundtrips. 1739 1740 Parameters 1741 ---------- 1742 exclude: 1743 A list of attributes which must not be present in the output. 1744 1745 See also 1746 -------- 1747 Client.dump_cluster_state 1748 """ 1749 1750 if not exclude: 1751 exclude = set() 1752 members = inspect.getmembers(self) 1753 return recursive_to_dict( 1754 {k: v for k, v in members if k not in exclude and not callable(v)}, 1755 exclude=exclude, 1756 ) 1757 1758 1759class _StateLegacyMapping(Mapping): 1760 """ 1761 A mapping interface mimicking the former Scheduler state dictionaries. 1762 """ 1763 1764 def __init__(self, states, accessor): 1765 self._states = states 1766 self._accessor = accessor 1767 1768 def __iter__(self): 1769 return iter(self._states) 1770 1771 def __len__(self): 1772 return len(self._states) 1773 1774 def __getitem__(self, key): 1775 return self._accessor(self._states[key]) 1776 1777 def __repr__(self): 1778 return f"{self.__class__}({dict(self)})" 1779 1780 1781class _OptionalStateLegacyMapping(_StateLegacyMapping): 1782 """ 1783 Similar to _StateLegacyMapping, but a false-y value is interpreted 1784 as a missing key. 1785 """ 1786 1787 # For tasks etc. 
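    # Illustrative (hypothetical) usage of these legacy views; the real call
    # sites build them elsewhere in the scheduler:
    #
    #   processing_on = _OptionalStateLegacyMapping(
    #       scheduler.tasks, operator.attrgetter("processing_on")
    #   )
    #
    # Entries whose accessor returns a false-y value (e.g. tasks that are not
    # currently processing on any worker) are treated as missing keys.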

    def __iter__(self):
        accessor = self._accessor
        for k, v in self._states.items():
            if accessor(v):
                yield k

    def __len__(self):
        accessor = self._accessor
        return sum(bool(accessor(v)) for v in self._states.values())

    def __getitem__(self, key):
        v = self._accessor(self._states[key])
        if v:
            return v
        else:
            raise KeyError


class _StateLegacySet(Set):
    """
    Similar to _StateLegacyMapping, but exposes a set containing
    all values with a true value.
    """

    # For loose_restrictions

    def __init__(self, states, accessor):
        self._states = states
        self._accessor = accessor

    def __iter__(self):
        return (k for k, v in self._states.items() if self._accessor(v))

    def __len__(self):
        return sum(map(bool, map(self._accessor, self._states.values())))

    def __contains__(self, k):
        st = self._states.get(k)
        return st is not None and bool(self._accessor(st))

    def __repr__(self):
        return f"{self.__class__}({set(self)})"


def _legacy_task_key_set(tasks):
    """
    Transform a set of task states into a set of task keys.
    """
    ts: TaskState
    return {ts._key for ts in tasks}


def _legacy_client_key_set(clients):
    """
    Transform a set of client states into a set of client keys.
    """
    cs: ClientState
    return {cs._client_key for cs in clients}


def _legacy_worker_key_set(workers):
    """
    Transform a set of worker states into a set of worker keys.
    """
    ws: WorkerState
    return {ws._address for ws in workers}


def _legacy_task_key_dict(task_dict: dict):
    """
    Transform a dict of {task state: value} into a dict of {task key: value}.
    """
    ts: TaskState
    return {ts._key: value for ts, value in task_dict.items()}


def _task_key_or_none(task: TaskState):
    return task._key if task is not None else None


@cclass
class SchedulerState:
    """Underlying task state of dynamic scheduler

    Tracks the current state of workers, data, and computations.

    Handles transitions between different task states. Notifies the
    Scheduler of changes by message passing through Queues, which the
    Scheduler listens to and responds to accordingly.

    All events are handled quickly, in linear time with respect to their
    input (which is often of constant size) and generally within a
    millisecond. Additionally when Cythonized, this can be faster still.
    To accomplish this the scheduler tracks a lot of state. Every
    operation maintains the consistency of this state.

    Users typically do not interact with ``Transitions`` directly. Instead
    users interact with the ``Client``, which in turn engages the
    ``Scheduler`` affecting different transitions here under-the-hood. In
    the background ``Worker``s also engage with the ``Scheduler``
    affecting these state transitions as well.

    **State**

    The ``Transitions`` object contains the following state variables.
    Each variable is listed along with what it stores and a brief
    description.
1896 1897 * **tasks:** ``{task key: TaskState}`` 1898 Tasks currently known to the scheduler 1899 * **unrunnable:** ``{TaskState}`` 1900 Tasks in the "no-worker" state 1901 1902 * **workers:** ``{worker key: WorkerState}`` 1903 Workers currently connected to the scheduler 1904 * **idle:** ``{WorkerState}``: 1905 Set of workers that are not fully utilized 1906 * **saturated:** ``{WorkerState}``: 1907 Set of workers that are not over-utilized 1908 * **running:** ``{WorkerState}``: 1909 Set of workers that are currently in running state 1910 1911 * **clients:** ``{client key: ClientState}`` 1912 Clients currently connected to the scheduler 1913 1914 * **task_duration:** ``{key-prefix: time}`` 1915 Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` 1916 """ 1917 1918 _aliases: dict 1919 _bandwidth: double 1920 _clients: dict # dict[str, ClientState] 1921 _computations: object 1922 _extensions: dict 1923 _host_info: dict 1924 _idle: "SortedDict[str, WorkerState]" 1925 _idle_dv: dict # dict[str, WorkerState] 1926 _n_tasks: Py_ssize_t 1927 _resources: dict 1928 _saturated: set # set[WorkerState] 1929 _running: set # set[WorkerState] 1930 _tasks: dict 1931 _task_groups: dict 1932 _task_prefixes: dict 1933 _task_metadata: dict 1934 _replicated_tasks: set 1935 _total_nthreads: Py_ssize_t 1936 _total_occupancy: double 1937 _transitions_table: dict 1938 _unknown_durations: dict 1939 _unrunnable: set 1940 _validate: bint 1941 _workers: "SortedDict[str, WorkerState]" 1942 _workers_dv: dict # dict[str, WorkerState] 1943 _transition_counter: Py_ssize_t 1944 _plugins: dict # dict[str, SchedulerPlugin] 1945 1946 # Variables from dask.config, cached by __init__ for performance 1947 UNKNOWN_TASK_DURATION: double 1948 MEMORY_RECENT_TO_OLD_TIME: double 1949 MEMORY_REBALANCE_MEASURE: str 1950 MEMORY_REBALANCE_SENDER_MIN: double 1951 MEMORY_REBALANCE_RECIPIENT_MAX: double 1952 MEMORY_REBALANCE_HALF_GAP: double 1953 1954 def __init__( 1955 self, 1956 aliases: dict, 1957 clients: "dict[str, ClientState]", 1958 workers: "SortedDict[str, WorkerState]", 1959 host_info: dict, 1960 resources: dict, 1961 tasks: dict, 1962 unrunnable: set, 1963 validate: bint, 1964 plugins: "Iterable[SchedulerPlugin]" = (), 1965 **kwargs, # Passed verbatim to Server.__init__() 1966 ): 1967 self._aliases = aliases 1968 self._bandwidth = parse_bytes( 1969 dask.config.get("distributed.scheduler.bandwidth") 1970 ) 1971 self._clients = clients 1972 self._clients["fire-and-forget"] = ClientState("fire-and-forget") 1973 self._extensions = {} 1974 self._host_info = host_info 1975 self._idle = SortedDict() 1976 # Note: cython.cast, not typing.cast! 
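        # When compiled with Cython, ``cast(dict, ...)`` is an unchecked cast:
        # SortedDict subclasses dict, so ``_idle_dv`` is the same object viewed
        # through a plain ``dict`` type, letting hot paths skip SortedDict's
        # Python-level method overhead. In the pure-Python fallback defined at
        # the top of this module, ``cast`` simply returns its argument.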
1977 self._idle_dv = cast(dict, self._idle) 1978 self._n_tasks = 0 1979 self._resources = resources 1980 self._saturated = set() 1981 self._tasks = tasks 1982 self._replicated_tasks = { 1983 ts for ts in self._tasks.values() if len(ts._who_has) > 1 1984 } 1985 self._computations = deque( 1986 maxlen=dask.config.get("distributed.diagnostics.computations.max-history") 1987 ) 1988 self._task_groups = {} 1989 self._task_prefixes = {} 1990 self._task_metadata = {} 1991 self._total_nthreads = 0 1992 self._total_occupancy = 0 1993 self._transitions_table = { 1994 ("released", "waiting"): self.transition_released_waiting, 1995 ("waiting", "released"): self.transition_waiting_released, 1996 ("waiting", "processing"): self.transition_waiting_processing, 1997 ("waiting", "memory"): self.transition_waiting_memory, 1998 ("processing", "released"): self.transition_processing_released, 1999 ("processing", "memory"): self.transition_processing_memory, 2000 ("processing", "erred"): self.transition_processing_erred, 2001 ("no-worker", "released"): self.transition_no_worker_released, 2002 ("no-worker", "waiting"): self.transition_no_worker_waiting, 2003 ("no-worker", "memory"): self.transition_no_worker_memory, 2004 ("released", "forgotten"): self.transition_released_forgotten, 2005 ("memory", "forgotten"): self.transition_memory_forgotten, 2006 ("erred", "released"): self.transition_erred_released, 2007 ("memory", "released"): self.transition_memory_released, 2008 ("released", "erred"): self.transition_released_erred, 2009 } 2010 self._unknown_durations = {} 2011 self._unrunnable = unrunnable 2012 self._validate = validate 2013 self._workers = workers 2014 # Note: cython.cast, not typing.cast! 2015 self._workers_dv = cast(dict, self._workers) 2016 self._running = { 2017 ws for ws in self._workers.values() if ws.status == Status.running 2018 } 2019 self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} 2020 2021 # Variables from dask.config, cached by __init__ for performance 2022 self.UNKNOWN_TASK_DURATION = parse_timedelta( 2023 dask.config.get("distributed.scheduler.unknown-task-duration") 2024 ) 2025 self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( 2026 dask.config.get("distributed.worker.memory.recent-to-old-time") 2027 ) 2028 self.MEMORY_REBALANCE_MEASURE = dask.config.get( 2029 "distributed.worker.memory.rebalance.measure" 2030 ) 2031 self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get( 2032 "distributed.worker.memory.rebalance.sender-min" 2033 ) 2034 self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( 2035 "distributed.worker.memory.rebalance.recipient-max" 2036 ) 2037 self.MEMORY_REBALANCE_HALF_GAP = ( 2038 dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") 2039 / 2.0 2040 ) 2041 self._transition_counter = 0 2042 2043 # Call Server.__init__() 2044 super().__init__(**kwargs) # type: ignore 2045 2046 @property 2047 def aliases(self): 2048 return self._aliases 2049 2050 @property 2051 def bandwidth(self): 2052 return self._bandwidth 2053 2054 @property 2055 def clients(self): 2056 return self._clients 2057 2058 @property 2059 def computations(self): 2060 return self._computations 2061 2062 @property 2063 def extensions(self): 2064 return self._extensions 2065 2066 @property 2067 def host_info(self): 2068 return self._host_info 2069 2070 @property 2071 def idle(self): 2072 return self._idle 2073 2074 @property 2075 def n_tasks(self): 2076 return self._n_tasks 2077 2078 @property 2079 def resources(self): 2080 return self._resources 2081 2082 
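    # Rough sketch (comment only) of how ``_transition`` below consumes
    # ``self._transitions_table``; the names here are shorthand for this comment:
    #
    #   func = self._transitions_table.get((start_state, finish_state))
    #   if func is not None:
    #       recommendations, client_msgs, worker_msgs = func(key)   # direct edge
    #   else:
    #       # no direct edge: first transition the key to "released",
    #       # then apply the ("released", finish_state) handler
    #       ...
    #
    # Each handler returns follow-up transition recommendations plus messages
    # destined for clients and workers.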
@property 2083 def saturated(self) -> "set[WorkerState]": 2084 return self._saturated 2085 2086 @property 2087 def running(self) -> "set[WorkerState]": 2088 return self._running 2089 2090 @property 2091 def tasks(self): 2092 return self._tasks 2093 2094 @property 2095 def task_groups(self): 2096 return self._task_groups 2097 2098 @property 2099 def task_prefixes(self): 2100 return self._task_prefixes 2101 2102 @property 2103 def task_metadata(self): 2104 return self._task_metadata 2105 2106 @property 2107 def replicated_tasks(self): 2108 return self._replicated_tasks 2109 2110 @property 2111 def total_nthreads(self): 2112 return self._total_nthreads 2113 2114 @property 2115 def total_occupancy(self): 2116 return self._total_occupancy 2117 2118 @total_occupancy.setter 2119 def total_occupancy(self, v: double): 2120 self._total_occupancy = v 2121 2122 @property 2123 def transition_counter(self): 2124 return self._transition_counter 2125 2126 @property 2127 def unknown_durations(self): 2128 return self._unknown_durations 2129 2130 @property 2131 def unrunnable(self): 2132 return self._unrunnable 2133 2134 @property 2135 def validate(self): 2136 return self._validate 2137 2138 @validate.setter 2139 def validate(self, v: bint): 2140 self._validate = v 2141 2142 @property 2143 def workers(self): 2144 return self._workers 2145 2146 @property 2147 def plugins(self) -> "dict[str, SchedulerPlugin]": 2148 return self._plugins 2149 2150 @property 2151 def memory(self) -> MemoryState: 2152 return MemoryState.sum(*(w.memory for w in self.workers.values())) 2153 2154 @property 2155 def __pdict__(self): 2156 return { 2157 "bandwidth": self._bandwidth, 2158 "resources": self._resources, 2159 "saturated": self._saturated, 2160 "unrunnable": self._unrunnable, 2161 "n_tasks": self._n_tasks, 2162 "unknown_durations": self._unknown_durations, 2163 "validate": self._validate, 2164 "tasks": self._tasks, 2165 "task_groups": self._task_groups, 2166 "task_prefixes": self._task_prefixes, 2167 "total_nthreads": self._total_nthreads, 2168 "total_occupancy": self._total_occupancy, 2169 "extensions": self._extensions, 2170 "clients": self._clients, 2171 "workers": self._workers, 2172 "idle": self._idle, 2173 "host_info": self._host_info, 2174 } 2175 2176 @ccall 2177 @exceptval(check=False) 2178 def new_task( 2179 self, key: str, spec: object, state: str, computation: Computation = None 2180 ) -> TaskState: 2181 """Create a new task, and associated states""" 2182 ts: TaskState = TaskState(key, spec) 2183 ts._state = state 2184 2185 tp: TaskPrefix 2186 prefix_key = key_split(key) 2187 tp = self._task_prefixes.get(prefix_key) # type: ignore 2188 if tp is None: 2189 self._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) 2190 ts._prefix = tp 2191 2192 group_key = ts._group_key 2193 tg: TaskGroup = self._task_groups.get(group_key) # type: ignore 2194 if tg is None: 2195 self._task_groups[group_key] = tg = TaskGroup(group_key) 2196 if computation: 2197 computation.groups.add(tg) 2198 tg._prefix = tp 2199 tp._groups.append(tg) 2200 tg.add(ts) 2201 2202 self._tasks[key] = ts 2203 2204 return ts 2205 2206 ##################### 2207 # State Transitions # 2208 ##################### 2209 2210 def _transition(self, key, finish: str, *args, **kwargs): 2211 """Transition a key from its current state to the finish state 2212 2213 Examples 2214 -------- 2215 >>> self._transition('x', 'waiting') 2216 {'x': 'processing'} 2217 2218 Returns 2219 ------- 2220 Dictionary of recommendations for future transitions 2221 2222 See Also 2223 
-------- 2224 Scheduler.transitions : transitive version of this function 2225 """ 2226 parent: SchedulerState = cast(SchedulerState, self) 2227 ts: TaskState 2228 start: str 2229 start_finish: tuple 2230 finish2: str 2231 recommendations: dict 2232 worker_msgs: dict 2233 client_msgs: dict 2234 msgs: list 2235 new_msgs: list 2236 dependents: set 2237 dependencies: set 2238 try: 2239 recommendations = {} 2240 worker_msgs = {} 2241 client_msgs = {} 2242 2243 ts = parent._tasks.get(key) # type: ignore 2244 if ts is None: 2245 return recommendations, client_msgs, worker_msgs 2246 start = ts._state 2247 if start == finish: 2248 return recommendations, client_msgs, worker_msgs 2249 2250 if self.plugins: 2251 dependents = set(ts._dependents) 2252 dependencies = set(ts._dependencies) 2253 2254 start_finish = (start, finish) 2255 func = self._transitions_table.get(start_finish) 2256 if func is not None: 2257 recommendations, client_msgs, worker_msgs = func(key, *args, **kwargs) 2258 self._transition_counter += 1 2259 elif "released" not in start_finish: 2260 assert not args and not kwargs, (args, kwargs, start_finish) 2261 a_recs: dict 2262 a_cmsgs: dict 2263 a_wmsgs: dict 2264 a: tuple = self._transition(key, "released") 2265 a_recs, a_cmsgs, a_wmsgs = a 2266 2267 v = a_recs.get(key, finish) 2268 func = self._transitions_table["released", v] 2269 b_recs: dict 2270 b_cmsgs: dict 2271 b_wmsgs: dict 2272 b: tuple = func(key) 2273 b_recs, b_cmsgs, b_wmsgs = b 2274 2275 recommendations.update(a_recs) 2276 for c, new_msgs in a_cmsgs.items(): 2277 msgs = client_msgs.get(c) # type: ignore 2278 if msgs is not None: 2279 msgs.extend(new_msgs) 2280 else: 2281 client_msgs[c] = new_msgs 2282 for w, new_msgs in a_wmsgs.items(): 2283 msgs = worker_msgs.get(w) # type: ignore 2284 if msgs is not None: 2285 msgs.extend(new_msgs) 2286 else: 2287 worker_msgs[w] = new_msgs 2288 2289 recommendations.update(b_recs) 2290 for c, new_msgs in b_cmsgs.items(): 2291 msgs = client_msgs.get(c) # type: ignore 2292 if msgs is not None: 2293 msgs.extend(new_msgs) 2294 else: 2295 client_msgs[c] = new_msgs 2296 for w, new_msgs in b_wmsgs.items(): 2297 msgs = worker_msgs.get(w) # type: ignore 2298 if msgs is not None: 2299 msgs.extend(new_msgs) 2300 else: 2301 worker_msgs[w] = new_msgs 2302 2303 start = "released" 2304 else: 2305 raise RuntimeError("Impossible transition from %r to %r" % start_finish) 2306 2307 finish2 = ts._state 2308 # FIXME downcast antipattern 2309 scheduler = pep484_cast(Scheduler, self) 2310 scheduler.transition_log.append( 2311 (key, start, finish2, recommendations, time()) 2312 ) 2313 if parent._validate: 2314 logger.debug( 2315 "Transitioned %r %s->%s (actual: %s). 
Consequence: %s", 2316 key, 2317 start, 2318 finish2, 2319 ts._state, 2320 dict(recommendations), 2321 ) 2322 if self.plugins: 2323 # Temporarily put back forgotten key for plugin to retrieve it 2324 if ts._state == "forgotten": 2325 ts._dependents = dependents 2326 ts._dependencies = dependencies 2327 parent._tasks[ts._key] = ts 2328 for plugin in list(self.plugins.values()): 2329 try: 2330 plugin.transition(key, start, finish2, *args, **kwargs) 2331 except Exception: 2332 logger.info("Plugin failed with exception", exc_info=True) 2333 if ts._state == "forgotten": 2334 del parent._tasks[ts._key] 2335 2336 tg: TaskGroup = ts._group 2337 if ts._state == "forgotten" and tg._name in parent._task_groups: 2338 # Remove TaskGroup if all tasks are in the forgotten state 2339 all_forgotten: bint = True 2340 for s in ALL_TASK_STATES: 2341 if tg._states.get(s): 2342 all_forgotten = False 2343 break 2344 if all_forgotten: 2345 ts._prefix._groups.remove(tg) 2346 del parent._task_groups[tg._name] 2347 2348 return recommendations, client_msgs, worker_msgs 2349 except Exception: 2350 logger.exception("Error transitioning %r from %r to %r", key, start, finish) 2351 if LOG_PDB: 2352 import pdb 2353 2354 pdb.set_trace() 2355 raise 2356 2357 def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): 2358 """Process transitions until none are left 2359 2360 This includes feedback from previous transitions and continues until we 2361 reach a steady state 2362 """ 2363 keys: set = set() 2364 recommendations = recommendations.copy() 2365 msgs: list 2366 new_msgs: list 2367 new: tuple 2368 new_recs: dict 2369 new_cmsgs: dict 2370 new_wmsgs: dict 2371 while recommendations: 2372 key, finish = recommendations.popitem() 2373 keys.add(key) 2374 2375 new = self._transition(key, finish) 2376 new_recs, new_cmsgs, new_wmsgs = new 2377 2378 recommendations.update(new_recs) 2379 for c, new_msgs in new_cmsgs.items(): 2380 msgs = client_msgs.get(c) # type: ignore 2381 if msgs is not None: 2382 msgs.extend(new_msgs) 2383 else: 2384 client_msgs[c] = new_msgs 2385 for w, new_msgs in new_wmsgs.items(): 2386 msgs = worker_msgs.get(w) # type: ignore 2387 if msgs is not None: 2388 msgs.extend(new_msgs) 2389 else: 2390 worker_msgs[w] = new_msgs 2391 2392 if self._validate: 2393 # FIXME downcast antipattern 2394 scheduler = pep484_cast(Scheduler, self) 2395 for key in keys: 2396 scheduler.validate_key(key) 2397 2398 def transition_released_waiting(self, key): 2399 try: 2400 ts: TaskState = self._tasks[key] 2401 dts: TaskState 2402 recommendations: dict = {} 2403 client_msgs: dict = {} 2404 worker_msgs: dict = {} 2405 2406 if self._validate: 2407 assert ts._run_spec 2408 assert not ts._waiting_on 2409 assert not ts._who_has 2410 assert not ts._processing_on 2411 assert not any([dts._state == "forgotten" for dts in ts._dependencies]) 2412 2413 if ts._has_lost_dependencies: 2414 recommendations[key] = "forgotten" 2415 return recommendations, client_msgs, worker_msgs 2416 2417 ts.state = "waiting" 2418 2419 dts: TaskState 2420 for dts in ts._dependencies: 2421 if dts._exception_blame: 2422 ts._exception_blame = dts._exception_blame 2423 recommendations[key] = "erred" 2424 return recommendations, client_msgs, worker_msgs 2425 2426 for dts in ts._dependencies: 2427 dep = dts._key 2428 if not dts._who_has: 2429 ts._waiting_on.add(dts) 2430 if dts._state == "released": 2431 recommendations[dep] = "waiting" 2432 else: 2433 dts._waiters.add(ts) 2434 2435 ts._waiters = {dts for dts in ts._dependents if dts._state == 
"waiting"} 2436 2437 if not ts._waiting_on: 2438 if self._workers_dv: 2439 recommendations[key] = "processing" 2440 else: 2441 self._unrunnable.add(ts) 2442 ts.state = "no-worker" 2443 2444 return recommendations, client_msgs, worker_msgs 2445 except Exception as e: 2446 logger.exception(e) 2447 if LOG_PDB: 2448 import pdb 2449 2450 pdb.set_trace() 2451 raise 2452 2453 def transition_no_worker_waiting(self, key): 2454 try: 2455 ts: TaskState = self._tasks[key] 2456 dts: TaskState 2457 recommendations: dict = {} 2458 client_msgs: dict = {} 2459 worker_msgs: dict = {} 2460 2461 if self._validate: 2462 assert ts in self._unrunnable 2463 assert not ts._waiting_on 2464 assert not ts._who_has 2465 assert not ts._processing_on 2466 2467 self._unrunnable.remove(ts) 2468 2469 if ts._has_lost_dependencies: 2470 recommendations[key] = "forgotten" 2471 return recommendations, client_msgs, worker_msgs 2472 2473 for dts in ts._dependencies: 2474 dep = dts._key 2475 if not dts._who_has: 2476 ts._waiting_on.add(dts) 2477 if dts._state == "released": 2478 recommendations[dep] = "waiting" 2479 else: 2480 dts._waiters.add(ts) 2481 2482 ts.state = "waiting" 2483 2484 if not ts._waiting_on: 2485 if self._workers_dv: 2486 recommendations[key] = "processing" 2487 else: 2488 self._unrunnable.add(ts) 2489 ts.state = "no-worker" 2490 2491 return recommendations, client_msgs, worker_msgs 2492 except Exception as e: 2493 logger.exception(e) 2494 if LOG_PDB: 2495 import pdb 2496 2497 pdb.set_trace() 2498 raise 2499 2500 def transition_no_worker_memory( 2501 self, key, nbytes=None, type=None, typename: str = None, worker=None 2502 ): 2503 try: 2504 ws: WorkerState = self._workers_dv[worker] 2505 ts: TaskState = self._tasks[key] 2506 recommendations: dict = {} 2507 client_msgs: dict = {} 2508 worker_msgs: dict = {} 2509 2510 if self._validate: 2511 assert not ts._processing_on 2512 assert not ts._waiting_on 2513 assert ts._state == "no-worker" 2514 2515 self._unrunnable.remove(ts) 2516 2517 if nbytes is not None: 2518 ts.set_nbytes(nbytes) 2519 2520 self.check_idle_saturated(ws) 2521 2522 _add_to_memory( 2523 self, ts, ws, recommendations, client_msgs, type=type, typename=typename 2524 ) 2525 ts.state = "memory" 2526 2527 return recommendations, client_msgs, worker_msgs 2528 except Exception as e: 2529 logger.exception(e) 2530 if LOG_PDB: 2531 import pdb 2532 2533 pdb.set_trace() 2534 raise 2535 2536 @ccall 2537 @exceptval(check=False) 2538 def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None 2539 """ 2540 Decide on a worker for task *ts*. Return a WorkerState. 2541 2542 If it's a root or root-like task, we place it with its relatives to 2543 reduce future data tansfer. 2544 2545 If it has dependencies or restrictions, we use 2546 `decide_worker_from_deps_and_restrictions`. 2547 2548 Otherwise, we pick the least occupied worker, or pick from all workers 2549 in a round-robin fashion. 2550 """ 2551 if not self._workers_dv: 2552 return None # type: ignore 2553 2554 ws: WorkerState 2555 tg: TaskGroup = ts._group 2556 valid_workers: set = self.valid_workers(ts) 2557 2558 if ( 2559 valid_workers is not None 2560 and not valid_workers 2561 and not ts._loose_restrictions 2562 ): 2563 self._unrunnable.add(ts) 2564 ts.state = "no-worker" 2565 return None # type: ignore 2566 2567 # Group is larger than cluster with few dependencies? 2568 # Minimize future data transfers. 
2569 if ( 2570 valid_workers is None 2571 and len(tg) > self._total_nthreads * 2 2572 and len(tg._dependencies) < 5 2573 and sum(map(len, tg._dependencies)) < 5 2574 ): 2575 ws = tg._last_worker 2576 2577 if not ( 2578 ws and tg._last_worker_tasks_left and ws._address in self._workers_dv 2579 ): 2580 # Last-used worker is full or unknown; pick a new worker for the next few tasks 2581 ws = min( 2582 (self._idle_dv or self._workers_dv).values(), 2583 key=partial(self.worker_objective, ts), 2584 ) 2585 tg._last_worker_tasks_left = math.floor( 2586 (len(tg) / self._total_nthreads) * ws._nthreads 2587 ) 2588 2589 # Record `last_worker`, or clear it on the final task 2590 tg._last_worker = ( 2591 ws if tg.states["released"] + tg.states["waiting"] > 1 else None 2592 ) 2593 tg._last_worker_tasks_left -= 1 2594 return ws 2595 2596 if ts._dependencies or valid_workers is not None: 2597 ws = decide_worker( 2598 ts, 2599 self._workers_dv.values(), 2600 valid_workers, 2601 partial(self.worker_objective, ts), 2602 ) 2603 else: 2604 # Fastpath when there are no related tasks or restrictions 2605 worker_pool = self._idle or self._workers 2606 # Note: cython.cast, not typing.cast! 2607 worker_pool_dv = cast(dict, worker_pool) 2608 wp_vals = worker_pool.values() 2609 n_workers: Py_ssize_t = len(worker_pool_dv) 2610 if n_workers < 20: # smart but linear in small case 2611 ws = min(wp_vals, key=operator.attrgetter("occupancy")) 2612 if ws._occupancy == 0: 2613 # special case to use round-robin; linear search 2614 # for next worker with zero occupancy (or just 2615 # land back where we started). 2616 wp_i: WorkerState 2617 start: Py_ssize_t = self._n_tasks % n_workers 2618 i: Py_ssize_t 2619 for i in range(n_workers): 2620 wp_i = wp_vals[(i + start) % n_workers] 2621 if wp_i._occupancy == 0: 2622 ws = wp_i 2623 break 2624 else: # dumb but fast in large case 2625 ws = wp_vals[self._n_tasks % n_workers] 2626 2627 if self._validate: 2628 assert ws is None or isinstance(ws, WorkerState), ( 2629 type(ws), 2630 ws, 2631 ) 2632 assert ws._address in self._workers_dv 2633 2634 return ws 2635 2636 @ccall 2637 def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: 2638 """Estimate task duration using worker state and task state. 2639 2640 If a task takes longer than twice the current average duration we 2641 estimate the task duration to be 2x current-runtime, otherwise we set it 2642 to be the average duration. 
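        For example, with a 10s average duration for this prefix: a task that has
        already been executing for 30s on the worker gets a new estimate of 60s,
        while a task that has only been executing for 5s keeps the 10s average
        plus the estimated communication cost of its dependencies.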
2643 2644 See also ``_remove_from_processing`` 2645 """ 2646 exec_time: double = ws._executing.get(ts, 0) 2647 duration: double = self.get_task_duration(ts) 2648 total_duration: double 2649 if exec_time > 2 * duration: 2650 total_duration = 2 * exec_time 2651 else: 2652 comm: double = self.get_comm_cost(ts, ws) 2653 total_duration = duration + comm 2654 old = ws._processing.get(ts, 0) 2655 ws._processing[ts] = total_duration 2656 self._total_occupancy += total_duration - old 2657 ws._occupancy += total_duration - old 2658 2659 return total_duration 2660 2661 def transition_waiting_processing(self, key): 2662 try: 2663 ts: TaskState = self._tasks[key] 2664 dts: TaskState 2665 recommendations: dict = {} 2666 client_msgs: dict = {} 2667 worker_msgs: dict = {} 2668 2669 if self._validate: 2670 assert not ts._waiting_on 2671 assert not ts._who_has 2672 assert not ts._exception_blame 2673 assert not ts._processing_on 2674 assert not ts._has_lost_dependencies 2675 assert ts not in self._unrunnable 2676 assert all([dts._who_has for dts in ts._dependencies]) 2677 2678 ws: WorkerState = self.decide_worker(ts) 2679 if ws is None: 2680 return recommendations, client_msgs, worker_msgs 2681 worker = ws._address 2682 2683 self.set_duration_estimate(ts, ws) 2684 ts._processing_on = ws 2685 ts.state = "processing" 2686 self.consume_resources(ts, ws) 2687 self.check_idle_saturated(ws) 2688 self._n_tasks += 1 2689 2690 if ts._actor: 2691 ws._actors.add(ts) 2692 2693 # logger.debug("Send job to worker: %s, %s", worker, key) 2694 2695 worker_msgs[worker] = [_task_to_msg(self, ts)] 2696 2697 return recommendations, client_msgs, worker_msgs 2698 except Exception as e: 2699 logger.exception(e) 2700 if LOG_PDB: 2701 import pdb 2702 2703 pdb.set_trace() 2704 raise 2705 2706 def transition_waiting_memory( 2707 self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs 2708 ): 2709 try: 2710 ws: WorkerState = self._workers_dv[worker] 2711 ts: TaskState = self._tasks[key] 2712 recommendations: dict = {} 2713 client_msgs: dict = {} 2714 worker_msgs: dict = {} 2715 2716 if self._validate: 2717 assert not ts._processing_on 2718 assert ts._waiting_on 2719 assert ts._state == "waiting" 2720 2721 ts._waiting_on.clear() 2722 2723 if nbytes is not None: 2724 ts.set_nbytes(nbytes) 2725 2726 self.check_idle_saturated(ws) 2727 2728 _add_to_memory( 2729 self, ts, ws, recommendations, client_msgs, type=type, typename=typename 2730 ) 2731 2732 if self._validate: 2733 assert not ts._processing_on 2734 assert not ts._waiting_on 2735 assert ts._who_has 2736 2737 return recommendations, client_msgs, worker_msgs 2738 except Exception as e: 2739 logger.exception(e) 2740 if LOG_PDB: 2741 import pdb 2742 2743 pdb.set_trace() 2744 raise 2745 2746 def transition_processing_memory( 2747 self, 2748 key, 2749 nbytes=None, 2750 type=None, 2751 typename: str = None, 2752 worker=None, 2753 startstops=None, 2754 **kwargs, 2755 ): 2756 ws: WorkerState 2757 wws: WorkerState 2758 recommendations: dict = {} 2759 client_msgs: dict = {} 2760 worker_msgs: dict = {} 2761 try: 2762 ts: TaskState = self._tasks[key] 2763 2764 assert worker 2765 assert isinstance(worker, str) 2766 2767 if self._validate: 2768 assert ts._processing_on 2769 ws = ts._processing_on 2770 assert ts in ws._processing 2771 assert not ts._waiting_on 2772 assert not ts._who_has, (ts, ts._who_has) 2773 assert not ts._exception_blame 2774 assert ts._state == "processing" 2775 2776 ws = self._workers_dv.get(worker) # type: ignore 2777 if ws is None: 2778 
recommendations[key] = "released" 2779 return recommendations, client_msgs, worker_msgs 2780 2781 if ws != ts._processing_on: # someone else has this task 2782 logger.info( 2783 "Unexpected worker completed task. Expected: %s, Got: %s, Key: %s", 2784 ts._processing_on, 2785 ws, 2786 key, 2787 ) 2788 worker_msgs[ts._processing_on.address] = [ 2789 { 2790 "op": "cancel-compute", 2791 "key": key, 2792 "reason": "Finished on different worker", 2793 } 2794 ] 2795 2796 ############################# 2797 # Update Timing Information # 2798 ############################# 2799 if startstops: 2800 startstop: dict 2801 for startstop in startstops: 2802 ts._group.add_duration( 2803 stop=startstop["stop"], 2804 start=startstop["start"], 2805 action=startstop["action"], 2806 ) 2807 2808 s: set = self._unknown_durations.pop(ts._prefix._name, set()) 2809 tts: TaskState 2810 steal = self.extensions.get("stealing") 2811 for tts in s: 2812 if tts._processing_on: 2813 self.set_duration_estimate(tts, tts._processing_on) 2814 if steal: 2815 steal.put_key_in_stealable(tts) 2816 2817 ############################ 2818 # Update State Information # 2819 ############################ 2820 if nbytes is not None: 2821 ts.set_nbytes(nbytes) 2822 2823 _remove_from_processing(self, ts) 2824 2825 _add_to_memory( 2826 self, ts, ws, recommendations, client_msgs, type=type, typename=typename 2827 ) 2828 2829 if self._validate: 2830 assert not ts._processing_on 2831 assert not ts._waiting_on 2832 2833 return recommendations, client_msgs, worker_msgs 2834 except Exception as e: 2835 logger.exception(e) 2836 if LOG_PDB: 2837 import pdb 2838 2839 pdb.set_trace() 2840 raise 2841 2842 def transition_memory_released(self, key, safe: bint = False): 2843 ws: WorkerState 2844 try: 2845 ts: TaskState = self._tasks[key] 2846 dts: TaskState 2847 recommendations: dict = {} 2848 client_msgs: dict = {} 2849 worker_msgs: dict = {} 2850 2851 if self._validate: 2852 assert not ts._waiting_on 2853 assert not ts._processing_on 2854 if safe: 2855 assert not ts._waiters 2856 2857 if ts._actor: 2858 for ws in ts._who_has: 2859 ws._actors.discard(ts) 2860 if ts._who_wants: 2861 ts._exception_blame = ts 2862 ts._exception = "Worker holding Actor was lost" 2863 recommendations[ts._key] = "erred" 2864 return ( 2865 recommendations, 2866 client_msgs, 2867 worker_msgs, 2868 ) # don't try to recreate 2869 2870 for dts in ts._waiters: 2871 if dts._state in ("no-worker", "processing"): 2872 recommendations[dts._key] = "waiting" 2873 elif dts._state == "waiting": 2874 dts._waiting_on.add(ts) 2875 2876 # XXX factor this out? 
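            # Every worker that still holds a replica is told to free the key; the
            # scheduler-side replica bookkeeping is dropped via
            # ``remove_all_replicas`` right after.  The ``stimulus_id`` ties the
            # message back to the event that triggered it, which helps when tracing
            # a key's history across the scheduler and the workers.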
2877 worker_msg = { 2878 "op": "free-keys", 2879 "keys": [key], 2880 "stimulus_id": f"memory-released-{time()}", 2881 } 2882 for ws in ts._who_has: 2883 worker_msgs[ws._address] = [worker_msg] 2884 self.remove_all_replicas(ts) 2885 2886 ts.state = "released" 2887 2888 report_msg = {"op": "lost-data", "key": key} 2889 cs: ClientState 2890 for cs in ts._who_wants: 2891 client_msgs[cs._client_key] = [report_msg] 2892 2893 if not ts._run_spec: # pure data 2894 recommendations[key] = "forgotten" 2895 elif ts._has_lost_dependencies: 2896 recommendations[key] = "forgotten" 2897 elif ts._who_wants or ts._waiters: 2898 recommendations[key] = "waiting" 2899 2900 if self._validate: 2901 assert not ts._waiting_on 2902 2903 return recommendations, client_msgs, worker_msgs 2904 except Exception as e: 2905 logger.exception(e) 2906 if LOG_PDB: 2907 import pdb 2908 2909 pdb.set_trace() 2910 raise 2911 2912 def transition_released_erred(self, key): 2913 try: 2914 ts: TaskState = self._tasks[key] 2915 dts: TaskState 2916 failing_ts: TaskState 2917 recommendations: dict = {} 2918 client_msgs: dict = {} 2919 worker_msgs: dict = {} 2920 2921 if self._validate: 2922 with log_errors(pdb=LOG_PDB): 2923 assert ts._exception_blame 2924 assert not ts._who_has 2925 assert not ts._waiting_on 2926 assert not ts._waiters 2927 2928 failing_ts = ts._exception_blame 2929 2930 for dts in ts._dependents: 2931 dts._exception_blame = failing_ts 2932 if not dts._who_has: 2933 recommendations[dts._key] = "erred" 2934 2935 report_msg = { 2936 "op": "task-erred", 2937 "key": key, 2938 "exception": failing_ts._exception, 2939 "traceback": failing_ts._traceback, 2940 } 2941 cs: ClientState 2942 for cs in ts._who_wants: 2943 client_msgs[cs._client_key] = [report_msg] 2944 2945 ts.state = "erred" 2946 2947 # TODO: waiting data? 
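            # As with every transition handler, the return value is a triple:
            # ``recommendations`` maps task keys to the state they should be moved
            # to next (and is fed back into ``_transitions``), while
            # ``client_msgs`` / ``worker_msgs`` map client keys and worker
            # addresses to lists of messages to send out.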
2948 return recommendations, client_msgs, worker_msgs 2949 except Exception as e: 2950 logger.exception(e) 2951 if LOG_PDB: 2952 import pdb 2953 2954 pdb.set_trace() 2955 raise 2956 2957 def transition_erred_released(self, key): 2958 try: 2959 ts: TaskState = self._tasks[key] 2960 dts: TaskState 2961 recommendations: dict = {} 2962 client_msgs: dict = {} 2963 worker_msgs: dict = {} 2964 2965 if self._validate: 2966 with log_errors(pdb=LOG_PDB): 2967 assert ts._exception_blame 2968 assert not ts._who_has 2969 assert not ts._waiting_on 2970 assert not ts._waiters 2971 2972 ts._exception = None 2973 ts._exception_blame = None 2974 ts._traceback = None 2975 2976 for dts in ts._dependents: 2977 if dts._state == "erred": 2978 recommendations[dts._key] = "waiting" 2979 2980 w_msg = { 2981 "op": "free-keys", 2982 "keys": [key], 2983 "stimulus_id": f"erred-released-{time()}", 2984 } 2985 for ws_addr in ts._erred_on: 2986 worker_msgs[ws_addr] = [w_msg] 2987 ts._erred_on.clear() 2988 2989 report_msg = {"op": "task-retried", "key": key} 2990 cs: ClientState 2991 for cs in ts._who_wants: 2992 client_msgs[cs._client_key] = [report_msg] 2993 2994 ts.state = "released" 2995 2996 return recommendations, client_msgs, worker_msgs 2997 except Exception as e: 2998 logger.exception(e) 2999 if LOG_PDB: 3000 import pdb 3001 3002 pdb.set_trace() 3003 raise 3004 3005 def transition_waiting_released(self, key): 3006 try: 3007 ts: TaskState = self._tasks[key] 3008 recommendations: dict = {} 3009 client_msgs: dict = {} 3010 worker_msgs: dict = {} 3011 3012 if self._validate: 3013 assert not ts._who_has 3014 assert not ts._processing_on 3015 3016 dts: TaskState 3017 for dts in ts._dependencies: 3018 if ts in dts._waiters: 3019 dts._waiters.discard(ts) 3020 if not dts._waiters and not dts._who_wants: 3021 recommendations[dts._key] = "released" 3022 ts._waiting_on.clear() 3023 3024 ts.state = "released" 3025 3026 if ts._has_lost_dependencies: 3027 recommendations[key] = "forgotten" 3028 elif not ts._exception_blame and (ts._who_wants or ts._waiters): 3029 recommendations[key] = "waiting" 3030 else: 3031 ts._waiters.clear() 3032 3033 return recommendations, client_msgs, worker_msgs 3034 except Exception as e: 3035 logger.exception(e) 3036 if LOG_PDB: 3037 import pdb 3038 3039 pdb.set_trace() 3040 raise 3041 3042 def transition_processing_released(self, key): 3043 try: 3044 ts: TaskState = self._tasks[key] 3045 dts: TaskState 3046 recommendations: dict = {} 3047 client_msgs: dict = {} 3048 worker_msgs: dict = {} 3049 3050 if self._validate: 3051 assert ts._processing_on 3052 assert not ts._who_has 3053 assert not ts._waiting_on 3054 assert self._tasks[key].state == "processing" 3055 3056 w: str = _remove_from_processing(self, ts) 3057 if w: 3058 worker_msgs[w] = [ 3059 { 3060 "op": "free-keys", 3061 "keys": [key], 3062 "stimulus_id": f"processing-released-{time()}", 3063 } 3064 ] 3065 3066 ts.state = "released" 3067 3068 if ts._has_lost_dependencies: 3069 recommendations[key] = "forgotten" 3070 elif ts._waiters or ts._who_wants: 3071 recommendations[key] = "waiting" 3072 3073 if recommendations.get(key) != "waiting": 3074 for dts in ts._dependencies: 3075 if dts._state != "released": 3076 dts._waiters.discard(ts) 3077 if not dts._waiters and not dts._who_wants: 3078 recommendations[dts._key] = "released" 3079 ts._waiters.clear() 3080 3081 if self._validate: 3082 assert not ts._processing_on 3083 3084 return recommendations, client_msgs, worker_msgs 3085 except Exception as e: 3086 logger.exception(e) 3087 if LOG_PDB: 3088 
import pdb 3089 3090 pdb.set_trace() 3091 raise 3092 3093 def transition_processing_erred( 3094 self, 3095 key: str, 3096 cause: str = None, 3097 exception=None, 3098 traceback=None, 3099 exception_text: str = None, 3100 traceback_text: str = None, 3101 worker: str = None, 3102 **kwargs, 3103 ): 3104 ws: WorkerState 3105 try: 3106 ts: TaskState = self._tasks[key] 3107 dts: TaskState 3108 failing_ts: TaskState 3109 recommendations: dict = {} 3110 client_msgs: dict = {} 3111 worker_msgs: dict = {} 3112 3113 if self._validate: 3114 assert cause or ts._exception_blame 3115 assert ts._processing_on 3116 assert not ts._who_has 3117 assert not ts._waiting_on 3118 3119 if ts._actor: 3120 ws = ts._processing_on 3121 ws._actors.remove(ts) 3122 3123 w = _remove_from_processing(self, ts) 3124 3125 ts._erred_on.add(w or worker) 3126 if exception is not None: 3127 ts._exception = exception 3128 ts._exception_text = exception_text # type: ignore 3129 if traceback is not None: 3130 ts._traceback = traceback 3131 ts._traceback_text = traceback_text # type: ignore 3132 if cause is not None: 3133 failing_ts = self._tasks[cause] 3134 ts._exception_blame = failing_ts 3135 else: 3136 failing_ts = ts._exception_blame # type: ignore 3137 3138 for dts in ts._dependents: 3139 dts._exception_blame = failing_ts 3140 recommendations[dts._key] = "erred" 3141 3142 for dts in ts._dependencies: 3143 dts._waiters.discard(ts) 3144 if not dts._waiters and not dts._who_wants: 3145 recommendations[dts._key] = "released" 3146 3147 ts._waiters.clear() # do anything with this? 3148 3149 ts.state = "erred" 3150 3151 report_msg = { 3152 "op": "task-erred", 3153 "key": key, 3154 "exception": failing_ts._exception, 3155 "traceback": failing_ts._traceback, 3156 } 3157 cs: ClientState 3158 for cs in ts._who_wants: 3159 client_msgs[cs._client_key] = [report_msg] 3160 3161 cs = self._clients["fire-and-forget"] 3162 if ts in cs._wants_what: 3163 _client_releases_keys( 3164 self, 3165 cs=cs, 3166 keys=[key], 3167 recommendations=recommendations, 3168 ) 3169 3170 if self._validate: 3171 assert not ts._processing_on 3172 3173 return recommendations, client_msgs, worker_msgs 3174 except Exception as e: 3175 logger.exception(e) 3176 if LOG_PDB: 3177 import pdb 3178 3179 pdb.set_trace() 3180 raise 3181 3182 def transition_no_worker_released(self, key): 3183 try: 3184 ts: TaskState = self._tasks[key] 3185 dts: TaskState 3186 recommendations: dict = {} 3187 client_msgs: dict = {} 3188 worker_msgs: dict = {} 3189 3190 if self._validate: 3191 assert self._tasks[key].state == "no-worker" 3192 assert not ts._who_has 3193 assert not ts._waiting_on 3194 3195 self._unrunnable.remove(ts) 3196 ts.state = "released" 3197 3198 for dts in ts._dependencies: 3199 dts._waiters.discard(ts) 3200 3201 ts._waiters.clear() 3202 3203 return recommendations, client_msgs, worker_msgs 3204 except Exception as e: 3205 logger.exception(e) 3206 if LOG_PDB: 3207 import pdb 3208 3209 pdb.set_trace() 3210 raise 3211 3212 @ccall 3213 def remove_key(self, key): 3214 ts: TaskState = self._tasks.pop(key) 3215 assert ts._state == "forgotten" 3216 self._unrunnable.discard(ts) 3217 cs: ClientState 3218 for cs in ts._who_wants: 3219 cs._wants_what.remove(ts) 3220 ts._who_wants.clear() 3221 ts._processing_on = None 3222 ts._exception_blame = ts._exception = ts._traceback = None 3223 self._task_metadata.pop(key, None) 3224 3225 def transition_memory_forgotten(self, key): 3226 ws: WorkerState 3227 try: 3228 ts: TaskState = self._tasks[key] 3229 recommendations: dict = {} 3230 
client_msgs: dict = {} 3231 worker_msgs: dict = {} 3232 3233 if self._validate: 3234 assert ts._state == "memory" 3235 assert not ts._processing_on 3236 assert not ts._waiting_on 3237 if not ts._run_spec: 3238 # It's ok to forget a pure data task 3239 pass 3240 elif ts._has_lost_dependencies: 3241 # It's ok to forget a task with forgotten dependencies 3242 pass 3243 elif not ts._who_wants and not ts._waiters and not ts._dependents: 3244 # It's ok to forget a task that nobody needs 3245 pass 3246 else: 3247 assert 0, (ts,) 3248 3249 if ts._actor: 3250 for ws in ts._who_has: 3251 ws._actors.discard(ts) 3252 3253 _propagate_forgotten(self, ts, recommendations, worker_msgs) 3254 3255 client_msgs = _task_to_client_msgs(self, ts) 3256 self.remove_key(key) 3257 3258 return recommendations, client_msgs, worker_msgs 3259 except Exception as e: 3260 logger.exception(e) 3261 if LOG_PDB: 3262 import pdb 3263 3264 pdb.set_trace() 3265 raise 3266 3267 def transition_released_forgotten(self, key): 3268 try: 3269 ts: TaskState = self._tasks[key] 3270 recommendations: dict = {} 3271 client_msgs: dict = {} 3272 worker_msgs: dict = {} 3273 3274 if self._validate: 3275 assert ts._state in ("released", "erred") 3276 assert not ts._who_has 3277 assert not ts._processing_on 3278 assert not ts._waiting_on, (ts, ts._waiting_on) 3279 if not ts._run_spec: 3280 # It's ok to forget a pure data task 3281 pass 3282 elif ts._has_lost_dependencies: 3283 # It's ok to forget a task with forgotten dependencies 3284 pass 3285 elif not ts._who_wants and not ts._waiters and not ts._dependents: 3286 # It's ok to forget a task that nobody needs 3287 pass 3288 else: 3289 assert 0, (ts,) 3290 3291 _propagate_forgotten(self, ts, recommendations, worker_msgs) 3292 3293 client_msgs = _task_to_client_msgs(self, ts) 3294 self.remove_key(key) 3295 3296 return recommendations, client_msgs, worker_msgs 3297 except Exception as e: 3298 logger.exception(e) 3299 if LOG_PDB: 3300 import pdb 3301 3302 pdb.set_trace() 3303 raise 3304 3305 ############################## 3306 # Assigning Tasks to Workers # 3307 ############################## 3308 3309 @ccall 3310 @exceptval(check=False) 3311 def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): 3312 """Update the status of the idle and saturated state 3313 3314 The scheduler keeps track of workers that are .. 3315 3316 - Saturated: have enough work to stay busy 3317 - Idle: do not have enough work to stay busy 3318 3319 They are considered saturated if they both have enough tasks to occupy 3320 all of their threads, and if the expected runtime of those tasks is 3321 large enough. 3322 3323 This is useful for load balancing and adaptivity. 3324 """ 3325 if self._total_nthreads == 0 or ws.status == Status.closed: 3326 return 3327 if occ < 0: 3328 occ = ws._occupancy 3329 3330 nc: Py_ssize_t = ws._nthreads 3331 p: Py_ssize_t = len(ws._processing) 3332 avg: double = self._total_occupancy / self._total_nthreads 3333 3334 idle = self._idle 3335 saturated: set = self._saturated 3336 if p < nc or occ < nc * avg / 2: 3337 idle[ws._address] = ws 3338 saturated.discard(ws) 3339 else: 3340 idle.pop(ws._address, None) 3341 3342 if p > nc: 3343 pending: double = occ * (p - nc) / (p * nc) 3344 if 0.4 < pending > 1.9 * avg: 3345 saturated.add(ws) 3346 return 3347 3348 saturated.discard(ws) 3349 3350 @ccall 3351 def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: 3352 """ 3353 Get the estimated communication cost (in s.) to compute the task 3354 on the given worker. 
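        This is the total ``nbytes`` of the task's dependencies that are not
        already resident on that worker, divided by the scheduler's current
        bandwidth estimate.  For example, 200 MB of missing dependencies at an
        estimated 100 MB/s works out to a 2 s communication cost.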
3355 """ 3356 dts: TaskState 3357 deps: set = ts._dependencies.difference(ws._has_what) 3358 nbytes: Py_ssize_t = 0 3359 for dts in deps: 3360 nbytes += dts._nbytes 3361 return nbytes / self._bandwidth 3362 3363 @ccall 3364 def get_task_duration(self, ts: TaskState) -> double: 3365 """Get the estimated computation cost of the given task (not including 3366 any communication cost). 3367 3368 If no data has been observed, value of 3369 `distributed.scheduler.default-task-durations` are used. If none is set 3370 for this task, `distributed.scheduler.unknown-task-duration` is used 3371 instead. 3372 """ 3373 duration: double = ts._prefix._duration_average 3374 if duration >= 0: 3375 return duration 3376 3377 s: set = self._unknown_durations.get(ts._prefix._name) # type: ignore 3378 if s is None: 3379 self._unknown_durations[ts._prefix._name] = s = set() 3380 s.add(ts) 3381 return self.UNKNOWN_TASK_DURATION 3382 3383 @ccall 3384 @exceptval(check=False) 3385 def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None 3386 """Return set of currently valid workers for key 3387 3388 If all workers are valid then this returns ``None``. 3389 This checks tracks the following state: 3390 3391 * worker_restrictions 3392 * host_restrictions 3393 * resource_restrictions 3394 """ 3395 s: set = None # type: ignore 3396 3397 if ts._worker_restrictions: 3398 s = {addr for addr in ts._worker_restrictions if addr in self._workers_dv} 3399 3400 if ts._host_restrictions: 3401 # Resolve the alias here rather than early, for the worker 3402 # may not be connected when host_restrictions is populated 3403 hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] 3404 # XXX need HostState? 3405 sl: list = [] 3406 for h in hr: 3407 dh: dict = self._host_info.get(h) # type: ignore 3408 if dh is not None: 3409 sl.append(dh["addresses"]) 3410 3411 ss: set = set.union(*sl) if sl else set() 3412 if s is None: 3413 s = ss 3414 else: 3415 s |= ss 3416 3417 if ts._resource_restrictions: 3418 dw: dict = {} 3419 for resource, required in ts._resource_restrictions.items(): 3420 dr: dict = self._resources.get(resource) # type: ignore 3421 if dr is None: 3422 self._resources[resource] = dr = {} 3423 3424 sw: set = set() 3425 for addr, supplied in dr.items(): 3426 if supplied >= required: 3427 sw.add(addr) 3428 3429 dw[resource] = sw 3430 3431 ww: set = set.intersection(*dw.values()) 3432 if s is None: 3433 s = ww 3434 else: 3435 s &= ww 3436 3437 if s is None: 3438 if len(self._running) < len(self._workers_dv): 3439 return self._running.copy() 3440 else: 3441 s = {self._workers_dv[addr] for addr in s} 3442 if len(self._running) < len(self._workers_dv): 3443 s &= self._running 3444 3445 return s 3446 3447 @ccall 3448 def consume_resources(self, ts: TaskState, ws: WorkerState): 3449 if ts._resource_restrictions: 3450 for r, required in ts._resource_restrictions.items(): 3451 ws._used_resources[r] += required 3452 3453 @ccall 3454 def release_resources(self, ts: TaskState, ws: WorkerState): 3455 if ts._resource_restrictions: 3456 for r, required in ts._resource_restrictions.items(): 3457 ws._used_resources[r] -= required 3458 3459 @ccall 3460 def coerce_hostname(self, host): 3461 """ 3462 Coerce the hostname of a worker. 
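        If *host* is a known worker alias, the host of the aliased worker is
        returned; otherwise *host* is returned unchanged.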
3463 """ 3464 alias = self._aliases.get(host) 3465 if alias is not None: 3466 ws: WorkerState = self._workers_dv[alias] 3467 return ws.host 3468 else: 3469 return host 3470 3471 @ccall 3472 @exceptval(check=False) 3473 def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: 3474 """ 3475 Objective function to determine which worker should get the task 3476 3477 Minimize expected start time. If a tie then break with data storage. 3478 """ 3479 dts: TaskState 3480 nbytes: Py_ssize_t 3481 comm_bytes: Py_ssize_t = 0 3482 for dts in ts._dependencies: 3483 if ws not in dts._who_has: 3484 nbytes = dts.get_nbytes() 3485 comm_bytes += nbytes 3486 3487 stack_time: double = ws._occupancy / ws._nthreads 3488 start_time: double = stack_time + comm_bytes / self._bandwidth 3489 3490 if ts._actor: 3491 return (len(ws._actors), start_time, ws._nbytes) 3492 else: 3493 return (start_time, ws._nbytes) 3494 3495 @ccall 3496 def add_replica(self, ts: TaskState, ws: WorkerState): 3497 """Note that a worker holds a replica of a task with state='memory'""" 3498 if self._validate: 3499 assert ws not in ts._who_has 3500 assert ts not in ws._has_what 3501 3502 ws._nbytes += ts.get_nbytes() 3503 ws._has_what[ts] = None 3504 ts._who_has.add(ws) 3505 if len(ts._who_has) == 2: 3506 self._replicated_tasks.add(ts) 3507 3508 @ccall 3509 def remove_replica(self, ts: TaskState, ws: WorkerState): 3510 """Note that a worker no longer holds a replica of a task""" 3511 ws._nbytes -= ts.get_nbytes() 3512 del ws._has_what[ts] 3513 ts._who_has.remove(ws) 3514 if len(ts._who_has) == 1: 3515 self._replicated_tasks.remove(ts) 3516 3517 @ccall 3518 def remove_all_replicas(self, ts: TaskState): 3519 """Remove all replicas of a task from all workers""" 3520 ws: WorkerState 3521 nbytes: Py_ssize_t = ts.get_nbytes() 3522 for ws in ts._who_has: 3523 ws._nbytes -= nbytes 3524 del ws._has_what[ts] 3525 if len(ts._who_has) > 1: 3526 self._replicated_tasks.remove(ts) 3527 ts._who_has.clear() 3528 3529 3530class Scheduler(SchedulerState, ServerNode): 3531 """Dynamic distributed task scheduler 3532 3533 The scheduler tracks the current state of workers, data, and computations. 3534 The scheduler listens for events and responds by controlling workers 3535 appropriately. It continuously tries to use the workers to execute an ever 3536 growing dask graph. 3537 3538 All events are handled quickly, in linear time with respect to their input 3539 (which is often of constant size) and generally within a millisecond. To 3540 accomplish this the scheduler tracks a lot of state. Every operation 3541 maintains the consistency of this state. 3542 3543 The scheduler communicates with the outside world through Comm objects. 3544 It maintains a consistent and valid view of the world even when listening 3545 to several clients at once. 3546 3547 A Scheduler is typically started either with the ``dask-scheduler`` 3548 executable:: 3549 3550 $ dask-scheduler 3551 Scheduler started at 127.0.0.1:8786 3552 3553 Or within a LocalCluster a Client starts up without connection 3554 information:: 3555 3556 >>> c = Client() # doctest: +SKIP 3557 >>> c.cluster.scheduler # doctest: +SKIP 3558 Scheduler(...) 3559 3560 Users typically do not interact with the scheduler directly but rather with 3561 the client object ``Client``. 3562 3563 **State** 3564 3565 The scheduler contains the following state variables. Each variable is 3566 listed along with what it stores and a brief description. 
3567 3568 * **tasks:** ``{task key: TaskState}`` 3569 Tasks currently known to the scheduler 3570 * **unrunnable:** ``{TaskState}`` 3571 Tasks in the "no-worker" state 3572 3573 * **workers:** ``{worker key: WorkerState}`` 3574 Workers currently connected to the scheduler 3575 * **idle:** ``{WorkerState}``: 3576 Set of workers that are not fully utilized 3577 * **saturated:** ``{WorkerState}``: 3578 Set of workers that are not over-utilized 3579 3580 * **host_info:** ``{hostname: dict}``: 3581 Information about each worker host 3582 3583 * **clients:** ``{client key: ClientState}`` 3584 Clients currently connected to the scheduler 3585 3586 * **services:** ``{str: port}``: 3587 Other services running on this scheduler, like Bokeh 3588 * **loop:** ``IOLoop``: 3589 The running Tornado IOLoop 3590 * **client_comms:** ``{client key: Comm}`` 3591 For each client, a Comm object used to receive task requests and 3592 report task status updates. 3593 * **stream_comms:** ``{worker key: Comm}`` 3594 For each worker, a Comm object from which we both accept stimuli and 3595 report results 3596 * **task_duration:** ``{key-prefix: time}`` 3597 Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` 3598 """ 3599 3600 default_port = 8786 3601 _instances: "ClassVar[weakref.WeakSet[Scheduler]]" = weakref.WeakSet() 3602 3603 def __init__( 3604 self, 3605 loop=None, 3606 delete_interval="500ms", 3607 synchronize_worker_interval="60s", 3608 services=None, 3609 service_kwargs=None, 3610 allowed_failures=None, 3611 extensions=None, 3612 validate=None, 3613 scheduler_file=None, 3614 security=None, 3615 worker_ttl=None, 3616 idle_timeout=None, 3617 interface=None, 3618 host=None, 3619 port=0, 3620 protocol=None, 3621 dashboard_address=None, 3622 dashboard=None, 3623 http_prefix="/", 3624 preload=None, 3625 preload_argv=(), 3626 plugins=(), 3627 **kwargs, 3628 ): 3629 self._setup_logging(logger) 3630 3631 # Attributes 3632 if allowed_failures is None: 3633 allowed_failures = dask.config.get("distributed.scheduler.allowed-failures") 3634 self.allowed_failures = allowed_failures 3635 if validate is None: 3636 validate = dask.config.get("distributed.scheduler.validate") 3637 self.proc = psutil.Process() 3638 self.delete_interval = parse_timedelta(delete_interval, default="ms") 3639 self.synchronize_worker_interval = parse_timedelta( 3640 synchronize_worker_interval, default="ms" 3641 ) 3642 self.digests = None 3643 self.service_specs = services or {} 3644 self.service_kwargs = service_kwargs or {} 3645 self.services = {} 3646 self.scheduler_file = scheduler_file 3647 worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") 3648 self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None 3649 idle_timeout = idle_timeout or dask.config.get( 3650 "distributed.scheduler.idle-timeout" 3651 ) 3652 if idle_timeout: 3653 self.idle_timeout = parse_timedelta(idle_timeout) 3654 else: 3655 self.idle_timeout = None 3656 self.idle_since = time() 3657 self.time_started = self.idle_since # compatibility for dask-gateway 3658 self._lock = asyncio.Lock() 3659 self.bandwidth_workers = defaultdict(float) 3660 self.bandwidth_types = defaultdict(float) 3661 3662 if not preload: 3663 preload = dask.config.get("distributed.scheduler.preload") 3664 if not preload_argv: 3665 preload_argv = dask.config.get("distributed.scheduler.preload-argv") 3666 self.preloads = preloading.process_preloads(self, preload, preload_argv) 3667 3668 if isinstance(security, dict): 3669 security = Security(**security) 3670 
self.security = security or Security() 3671 assert isinstance(self.security, Security) 3672 self.connection_args = self.security.get_connection_args("scheduler") 3673 self.connection_args["handshake_overrides"] = { # common denominator 3674 "pickle-protocol": 4 3675 } 3676 3677 self._start_address = addresses_from_user_args( 3678 host=host, 3679 port=port, 3680 interface=interface, 3681 protocol=protocol, 3682 security=security, 3683 default_port=self.default_port, 3684 ) 3685 3686 http_server_modules = dask.config.get("distributed.scheduler.http.routes") 3687 show_dashboard = dashboard or (dashboard is None and dashboard_address) 3688 # install vanilla route if show_dashboard but bokeh is not installed 3689 if show_dashboard: 3690 try: 3691 import distributed.dashboard.scheduler 3692 except ImportError: 3693 show_dashboard = False 3694 http_server_modules.append("distributed.http.scheduler.missing_bokeh") 3695 routes = get_handlers( 3696 server=self, modules=http_server_modules, prefix=http_prefix 3697 ) 3698 self.start_http_server(routes, dashboard_address, default_port=8787) 3699 if show_dashboard: 3700 distributed.dashboard.scheduler.connect( 3701 self.http_application, self.http_server, self, prefix=http_prefix 3702 ) 3703 3704 # Communication state 3705 self.loop = loop or IOLoop.current() 3706 self.client_comms = {} 3707 self.stream_comms = {} 3708 self._worker_coroutines = [] 3709 self._ipython_kernel = None 3710 3711 # Task state 3712 tasks = {} 3713 for old_attr, new_attr, wrap in [ 3714 ("priority", "priority", None), 3715 ("dependencies", "dependencies", _legacy_task_key_set), 3716 ("dependents", "dependents", _legacy_task_key_set), 3717 ("retries", "retries", None), 3718 ]: 3719 func = operator.attrgetter(new_attr) 3720 if wrap is not None: 3721 func = compose(wrap, func) 3722 setattr(self, old_attr, _StateLegacyMapping(tasks, func)) 3723 3724 for old_attr, new_attr, wrap in [ 3725 ("nbytes", "nbytes", None), 3726 ("who_wants", "who_wants", _legacy_client_key_set), 3727 ("who_has", "who_has", _legacy_worker_key_set), 3728 ("waiting", "waiting_on", _legacy_task_key_set), 3729 ("waiting_data", "waiters", _legacy_task_key_set), 3730 ("rprocessing", "processing_on", None), 3731 ("host_restrictions", "host_restrictions", None), 3732 ("worker_restrictions", "worker_restrictions", None), 3733 ("resource_restrictions", "resource_restrictions", None), 3734 ("suspicious_tasks", "suspicious", None), 3735 ("exceptions", "exception", None), 3736 ("tracebacks", "traceback", None), 3737 ("exceptions_blame", "exception_blame", _task_key_or_none), 3738 ]: 3739 func = operator.attrgetter(new_attr) 3740 if wrap is not None: 3741 func = compose(wrap, func) 3742 setattr(self, old_attr, _OptionalStateLegacyMapping(tasks, func)) 3743 3744 for old_attr, new_attr, wrap in [ 3745 ("loose_restrictions", "loose_restrictions", None) 3746 ]: 3747 func = operator.attrgetter(new_attr) 3748 if wrap is not None: 3749 func = compose(wrap, func) 3750 setattr(self, old_attr, _StateLegacySet(tasks, func)) 3751 3752 self.generation = 0 3753 self._last_client = None 3754 self._last_time = 0 3755 unrunnable = set() 3756 3757 self.datasets = {} 3758 3759 # Prefix-keyed containers 3760 3761 # Client state 3762 clients = {} 3763 for old_attr, new_attr, wrap in [ 3764 ("wants_what", "wants_what", _legacy_task_key_set) 3765 ]: 3766 func = operator.attrgetter(new_attr) 3767 if wrap is not None: 3768 func = compose(wrap, func) 3769 setattr(self, old_attr, _StateLegacyMapping(clients, func)) 3770 3771 # Worker state 3772 
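        # Same pattern as for tasks and clients above: each deprecated flat
        # attribute (e.g. ``self.nthreads`` or ``self.has_what``) becomes a lazy
        # mapping over the ``WorkerState`` objects.  ``operator.attrgetter`` pulls
        # the corresponding field off each state object and, where needed, a
        # wrapper such as ``_legacy_task_key_set`` converts sets of state objects
        # back into plain key sets, preserving the old dict-of-builtins interface.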
workers = SortedDict() 3773 for old_attr, new_attr, wrap in [ 3774 ("nthreads", "nthreads", None), 3775 ("worker_bytes", "nbytes", None), 3776 ("worker_resources", "resources", None), 3777 ("used_resources", "used_resources", None), 3778 ("occupancy", "occupancy", None), 3779 ("worker_info", "metrics", None), 3780 ("processing", "processing", _legacy_task_key_dict), 3781 ("has_what", "has_what", _legacy_task_key_set), 3782 ]: 3783 func = operator.attrgetter(new_attr) 3784 if wrap is not None: 3785 func = compose(wrap, func) 3786 setattr(self, old_attr, _StateLegacyMapping(workers, func)) 3787 3788 host_info = {} 3789 resources = {} 3790 aliases = {} 3791 3792 self._task_state_collections = [unrunnable] 3793 3794 self._worker_collections = [ 3795 workers, 3796 host_info, 3797 resources, 3798 aliases, 3799 ] 3800 3801 self.transition_log = deque( 3802 maxlen=dask.config.get("distributed.scheduler.transition-log-length") 3803 ) 3804 self.log = deque( 3805 maxlen=dask.config.get("distributed.scheduler.transition-log-length") 3806 ) 3807 self.events = defaultdict( 3808 partial( 3809 deque, maxlen=dask.config.get("distributed.scheduler.events-log-length") 3810 ) 3811 ) 3812 self.event_counts = defaultdict(int) 3813 self.event_subscriber = defaultdict(set) 3814 self.worker_plugins = {} 3815 self.nanny_plugins = {} 3816 3817 worker_handlers = { 3818 "task-finished": self.handle_task_finished, 3819 "task-erred": self.handle_task_erred, 3820 "release-worker-data": self.release_worker_data, 3821 "add-keys": self.add_keys, 3822 "missing-data": self.handle_missing_data, 3823 "long-running": self.handle_long_running, 3824 "reschedule": self.reschedule, 3825 "keep-alive": lambda *args, **kwargs: None, 3826 "log-event": self.log_worker_event, 3827 "worker-status-change": self.handle_worker_status_change, 3828 } 3829 3830 client_handlers = { 3831 "update-graph": self.update_graph, 3832 "update-graph-hlg": self.update_graph_hlg, 3833 "client-desires-keys": self.client_desires_keys, 3834 "update-data": self.update_data, 3835 "report-key": self.report_on_key, 3836 "client-releases-keys": self.client_releases_keys, 3837 "heartbeat-client": self.client_heartbeat, 3838 "close-client": self.remove_client, 3839 "restart": self.restart, 3840 "subscribe-topic": self.subscribe_topic, 3841 "unsubscribe-topic": self.unsubscribe_topic, 3842 } 3843 3844 self.handlers = { 3845 "register-client": self.add_client, 3846 "scatter": self.scatter, 3847 "register-worker": self.add_worker, 3848 "register_nanny": self.add_nanny, 3849 "unregister": self.remove_worker, 3850 "gather": self.gather, 3851 "cancel": self.stimulus_cancel, 3852 "retry": self.stimulus_retry, 3853 "feed": self.feed, 3854 "terminate": self.close, 3855 "broadcast": self.broadcast, 3856 "proxy": self.proxy, 3857 "ncores": self.get_ncores, 3858 "ncores_running": self.get_ncores_running, 3859 "has_what": self.get_has_what, 3860 "who_has": self.get_who_has, 3861 "processing": self.get_processing, 3862 "call_stack": self.get_call_stack, 3863 "profile": self.get_profile, 3864 "performance_report": self.performance_report, 3865 "get_logs": self.get_logs, 3866 "logs": self.get_logs, 3867 "worker_logs": self.get_worker_logs, 3868 "log_event": self.log_worker_event, 3869 "events": self.get_events, 3870 "nbytes": self.get_nbytes, 3871 "versions": self.versions, 3872 "add_keys": self.add_keys, 3873 "rebalance": self.rebalance, 3874 "replicate": self.replicate, 3875 "start_ipython": self.start_ipython, 3876 "run_function": self.run_function, 3877 "update_data": 
self.update_data, 3878 "set_resources": self.add_resources, 3879 "retire_workers": self.retire_workers, 3880 "get_metadata": self.get_metadata, 3881 "set_metadata": self.set_metadata, 3882 "set_restrictions": self.set_restrictions, 3883 "heartbeat_worker": self.heartbeat_worker, 3884 "get_task_status": self.get_task_status, 3885 "get_task_stream": self.get_task_stream, 3886 "register_scheduler_plugin": self.register_scheduler_plugin, 3887 "register_worker_plugin": self.register_worker_plugin, 3888 "unregister_worker_plugin": self.unregister_worker_plugin, 3889 "register_nanny_plugin": self.register_nanny_plugin, 3890 "unregister_nanny_plugin": self.unregister_nanny_plugin, 3891 "adaptive_target": self.adaptive_target, 3892 "workers_to_close": self.workers_to_close, 3893 "subscribe_worker_status": self.subscribe_worker_status, 3894 "start_task_metadata": self.start_task_metadata, 3895 "stop_task_metadata": self.stop_task_metadata, 3896 } 3897 3898 connection_limit = get_fileno_limit() / 2 3899 3900 super().__init__( 3901 # Arguments to SchedulerState 3902 aliases=aliases, 3903 clients=clients, 3904 workers=workers, 3905 host_info=host_info, 3906 resources=resources, 3907 tasks=tasks, 3908 unrunnable=unrunnable, 3909 validate=validate, 3910 plugins=plugins, 3911 # Arguments to ServerNode 3912 handlers=self.handlers, 3913 stream_handlers=merge(worker_handlers, client_handlers), 3914 io_loop=self.loop, 3915 connection_limit=connection_limit, 3916 deserialize=False, 3917 connection_args=self.connection_args, 3918 **kwargs, 3919 ) 3920 3921 if self.worker_ttl: 3922 pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl) 3923 self.periodic_callbacks["worker-ttl"] = pc 3924 3925 if self.idle_timeout: 3926 pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4) 3927 self.periodic_callbacks["idle-timeout"] = pc 3928 3929 if extensions is None: 3930 extensions = list(DEFAULT_EXTENSIONS) 3931 if dask.config.get("distributed.scheduler.work-stealing"): 3932 extensions.append(WorkStealing) 3933 for ext in extensions: 3934 ext(self) 3935 3936 setproctitle("dask-scheduler [not started]") 3937 Scheduler._instances.add(self) 3938 self.rpc.allow_offload = False 3939 self.status = Status.undefined 3940 3941 ################## 3942 # Administration # 3943 ################## 3944 3945 def __repr__(self): 3946 parent: SchedulerState = cast(SchedulerState, self) 3947 return '<Scheduler: "%s" workers: %d cores: %d, tasks: %d>' % ( 3948 self.address, 3949 len(parent._workers_dv), 3950 parent._total_nthreads, 3951 len(parent._tasks), 3952 ) 3953 3954 def _repr_html_(self): 3955 parent: SchedulerState = cast(SchedulerState, self) 3956 return get_template("scheduler.html.j2").render( 3957 address=self.address, 3958 workers=parent._workers_dv, 3959 threads=parent._total_nthreads, 3960 tasks=parent._tasks, 3961 ) 3962 3963 def identity(self, comm=None): 3964 """Basic information about ourselves and our cluster""" 3965 parent: SchedulerState = cast(SchedulerState, self) 3966 d = { 3967 "type": type(self).__name__, 3968 "id": str(self.id), 3969 "address": self.address, 3970 "services": {key: v.port for (key, v) in self.services.items()}, 3971 "started": self.time_started, 3972 "workers": { 3973 worker.address: worker.identity() 3974 for worker in parent._workers_dv.values() 3975 }, 3976 } 3977 return d 3978 3979 def _to_dict( 3980 self, comm: Comm = None, *, exclude: Container[str] = None 3981 ) -> "dict[str, Any]": 3982 """ 3983 A very verbose dictionary representation for debugging purposes. 
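        On top of the base ``super()._to_dict()`` payload it bundles the
        transition log, the scheduler event logs and the full task dictionary,
        all passed through ``recursive_to_dict``.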
3984 Not type stable and not inteded for roundtrips. 3985 3986 Parameters 3987 ---------- 3988 comm: 3989 exclude: 3990 A list of attributes which must not be present in the output. 3991 3992 See also 3993 -------- 3994 Server.identity 3995 Client.dump_cluster_state 3996 """ 3997 3998 info = super()._to_dict(exclude=exclude) 3999 extra = { 4000 "transition_log": self.transition_log, 4001 "log": self.log, 4002 "tasks": self.tasks, 4003 "events": self.events, 4004 } 4005 info.update(extra) 4006 extensions = {} 4007 for name, ex in self.extensions.items(): 4008 if hasattr(ex, "_to_dict"): 4009 extensions[name] = ex._to_dict() 4010 return recursive_to_dict(info, exclude=exclude) 4011 4012 def get_worker_service_addr(self, worker, service_name, protocol=False): 4013 """ 4014 Get the (host, port) address of the named service on the *worker*. 4015 Returns None if the service doesn't exist. 4016 4017 Parameters 4018 ---------- 4019 worker : address 4020 service_name : str 4021 Common services include 'bokeh' and 'nanny' 4022 protocol : boolean 4023 Whether or not to include a full address with protocol (True) 4024 or just a (host, port) pair 4025 """ 4026 parent: SchedulerState = cast(SchedulerState, self) 4027 ws: WorkerState = parent._workers_dv[worker] 4028 port = ws._services.get(service_name) 4029 if port is None: 4030 return None 4031 elif protocol: 4032 return "%(protocol)s://%(host)s:%(port)d" % { 4033 "protocol": ws._address.split("://")[0], 4034 "host": ws.host, 4035 "port": port, 4036 } 4037 else: 4038 return ws.host, port 4039 4040 async def start(self): 4041 """Clear out old state and restart all running coroutines""" 4042 await super().start() 4043 assert self.status != Status.running 4044 4045 enable_gc_diagnosis() 4046 4047 self.clear_task_state() 4048 4049 with suppress(AttributeError): 4050 for c in self._worker_coroutines: 4051 c.cancel() 4052 4053 for addr in self._start_address: 4054 await self.listen( 4055 addr, 4056 allow_offload=False, 4057 handshake_overrides={"pickle-protocol": 4, "compression": None}, 4058 **self.security.get_listen_args("scheduler"), 4059 ) 4060 self.ip = get_address_host(self.listen_address) 4061 listen_ip = self.ip 4062 4063 if listen_ip == "0.0.0.0": 4064 listen_ip = "" 4065 4066 if self.address.startswith("inproc://"): 4067 listen_ip = "localhost" 4068 4069 # Services listen on all addresses 4070 self.start_services(listen_ip) 4071 4072 for listener in self.listeners: 4073 logger.info(" Scheduler at: %25s", listener.contact_address) 4074 for k, v in self.services.items(): 4075 logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) 4076 4077 self.loop.add_callback(self.reevaluate_occupancy) 4078 4079 if self.scheduler_file: 4080 with open(self.scheduler_file, "w") as f: 4081 json.dump(self.identity(), f, indent=2) 4082 4083 fn = self.scheduler_file # remove file when we close the process 4084 4085 def del_scheduler_file(): 4086 if os.path.exists(fn): 4087 os.remove(fn) 4088 4089 weakref.finalize(self, del_scheduler_file) 4090 4091 for preload in self.preloads: 4092 await preload.start() 4093 4094 await asyncio.gather( 4095 *[plugin.start(self) for plugin in list(self.plugins.values())] 4096 ) 4097 4098 self.start_periodic_callbacks() 4099 4100 setproctitle(f"dask-scheduler [{self.address}]") 4101 return self 4102 4103 async def close(self, comm=None, fast=False, close_workers=False): 4104 """Send cleanup signal to all coroutines then wait until finished 4105 4106 See Also 4107 -------- 4108 Scheduler.cleanup 4109 """ 4110 parent: 
SchedulerState = cast(SchedulerState, self) 4111 if self.status in (Status.closing, Status.closed): 4112 await self.finished() 4113 return 4114 self.status = Status.closing 4115 4116 logger.info("Scheduler closing...") 4117 setproctitle("dask-scheduler [closing]") 4118 4119 for preload in self.preloads: 4120 await preload.teardown() 4121 4122 if close_workers: 4123 await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) 4124 for worker in parent._workers_dv: 4125 # Report would require the worker to unregister with the 4126 # currently closing scheduler. This is not necessary and might 4127 # delay shutdown of the worker unnecessarily 4128 self.worker_send(worker, {"op": "close", "report": False}) 4129 for i in range(20): # wait a second for send signals to clear 4130 if parent._workers_dv: 4131 await asyncio.sleep(0.05) 4132 else: 4133 break 4134 4135 await asyncio.gather( 4136 *[plugin.close() for plugin in list(self.plugins.values())] 4137 ) 4138 4139 for pc in self.periodic_callbacks.values(): 4140 pc.stop() 4141 self.periodic_callbacks.clear() 4142 4143 self.stop_services() 4144 4145 for ext in parent._extensions.values(): 4146 with suppress(AttributeError): 4147 ext.teardown() 4148 logger.info("Scheduler closing all comms") 4149 4150 futures = [] 4151 for w, comm in list(self.stream_comms.items()): 4152 if not comm.closed(): 4153 comm.send({"op": "close", "report": False}) 4154 comm.send({"op": "close-stream"}) 4155 with suppress(AttributeError): 4156 futures.append(comm.close()) 4157 4158 for future in futures: # TODO: do all at once 4159 await future 4160 4161 for comm in self.client_comms.values(): 4162 comm.abort() 4163 4164 await self.rpc.close() 4165 4166 self.status = Status.closed 4167 self.stop() 4168 await super().close() 4169 4170 setproctitle("dask-scheduler [closed]") 4171 disable_gc_diagnosis() 4172 4173 async def close_worker(self, comm=None, worker=None, safe=None): 4174 """Remove a worker from the cluster 4175 4176 This both removes the worker from our local state and also sends a 4177 signal to the worker to shut down. 
This works regardless of whether or 4178 not the worker has a nanny process restarting it 4179 """ 4180 logger.info("Closing worker %s", worker) 4181 with log_errors(): 4182 self.log_event(worker, {"action": "close-worker"}) 4183 # FIXME: This does not handle nannies 4184 self.worker_send(worker, {"op": "close", "report": False}) 4185 await self.remove_worker(address=worker, safe=safe) 4186 4187 ########### 4188 # Stimuli # 4189 ########### 4190 4191 def heartbeat_worker( 4192 self, 4193 comm=None, 4194 *, 4195 address, 4196 resolve_address: bool = True, 4197 now: float = None, 4198 resources: dict = None, 4199 host_info: dict = None, 4200 metrics: dict, 4201 executing: dict = None, 4202 ): 4203 parent: SchedulerState = cast(SchedulerState, self) 4204 address = self.coerce_address(address, resolve_address) 4205 address = normalize_address(address) 4206 ws: WorkerState = parent._workers_dv.get(address) # type: ignore 4207 if ws is None: 4208 return {"status": "missing"} 4209 4210 host = get_address_host(address) 4211 local_now = time() 4212 host_info = host_info or {} 4213 4214 dh: dict = parent._host_info.setdefault(host, {}) 4215 dh["last-seen"] = local_now 4216 4217 frac = 1 / len(parent._workers_dv) 4218 parent._bandwidth = ( 4219 parent._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac 4220 ) 4221 for other, (bw, count) in metrics["bandwidth"]["workers"].items(): 4222 if (address, other) not in self.bandwidth_workers: 4223 self.bandwidth_workers[address, other] = bw / count 4224 else: 4225 alpha = (1 - frac) ** count 4226 self.bandwidth_workers[address, other] = self.bandwidth_workers[ 4227 address, other 4228 ] * alpha + bw * (1 - alpha) 4229 for typ, (bw, count) in metrics["bandwidth"]["types"].items(): 4230 if typ not in self.bandwidth_types: 4231 self.bandwidth_types[typ] = bw / count 4232 else: 4233 alpha = (1 - frac) ** count 4234 self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( 4235 1 - alpha 4236 ) 4237 4238 ws._last_seen = local_now 4239 if executing is not None: 4240 ws._executing = { 4241 parent._tasks[key]: duration 4242 for key, duration in executing.items() 4243 if key in parent._tasks 4244 } 4245 4246 ws._metrics = metrics 4247 4248 # Calculate RSS - dask keys, separating "old" and "new" usage 4249 # See MemoryState for details 4250 max_memory_unmanaged_old_hist_age = local_now - parent.MEMORY_RECENT_TO_OLD_TIME 4251 memory_unmanaged_old = ws._memory_unmanaged_old 4252 while ws._memory_other_history: 4253 timestamp, size = ws._memory_other_history[0] 4254 if timestamp >= max_memory_unmanaged_old_hist_age: 4255 break 4256 ws._memory_other_history.popleft() 4257 if size == memory_unmanaged_old: 4258 memory_unmanaged_old = 0 # recalculate min() 4259 4260 # metrics["memory"] is None if the worker sent a heartbeat before its 4261 # SystemMonitor ever had a chance to run. 4262 # ws._nbytes is updated at a different time and sizeof() may not be accurate, 4263 # so size may be (temporarily) negative; floor it to zero. 4264 size = max(0, (metrics["memory"] or 0) - ws._nbytes + metrics["spilled_nbytes"]) 4265 4266 ws._memory_other_history.append((local_now, size)) 4267 if not memory_unmanaged_old: 4268 # The worker has just been started or the previous minimum has been expunged 4269 # because too old. 
4270 # Note: this algorithm is capped to 200 * MEMORY_RECENT_TO_OLD_TIME elements 4271 # cluster-wide by heartbeat_interval(), regardless of the number of workers 4272 ws._memory_unmanaged_old = min(map(second, ws._memory_other_history)) 4273 elif size < memory_unmanaged_old: 4274 ws._memory_unmanaged_old = size 4275 4276 if host_info: 4277 dh = parent._host_info.setdefault(host, {}) 4278 dh.update(host_info) 4279 4280 if now: 4281 ws._time_delay = local_now - now 4282 4283 if resources: 4284 self.add_resources(worker=address, resources=resources) 4285 4286 self.log_event(address, merge({"action": "heartbeat"}, metrics)) 4287 4288 return { 4289 "status": "OK", 4290 "time": local_now, 4291 "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)), 4292 } 4293 4294 async def add_worker( 4295 self, 4296 comm=None, 4297 *, 4298 address: str, 4299 status: str, 4300 keys=(), 4301 nthreads=None, 4302 name=None, 4303 resolve_address=True, 4304 nbytes=None, 4305 types=None, 4306 now=None, 4307 resources=None, 4308 host_info=None, 4309 memory_limit=None, 4310 metrics=None, 4311 pid=0, 4312 services=None, 4313 local_directory=None, 4314 versions=None, 4315 nanny=None, 4316 extra=None, 4317 ): 4318 """Add a new worker to the cluster""" 4319 parent: SchedulerState = cast(SchedulerState, self) 4320 with log_errors(): 4321 address = self.coerce_address(address, resolve_address) 4322 address = normalize_address(address) 4323 host = get_address_host(address) 4324 4325 if address in parent._workers_dv: 4326 raise ValueError("Worker already exists %s" % address) 4327 4328 if name in parent._aliases: 4329 logger.warning( 4330 "Worker tried to connect with a duplicate name: %s", name 4331 ) 4332 msg = { 4333 "status": "error", 4334 "message": "name taken, %s" % name, 4335 "time": time(), 4336 } 4337 if comm: 4338 await comm.write(msg) 4339 return 4340 4341 ws: WorkerState 4342 parent._workers[address] = ws = WorkerState( 4343 address=address, 4344 status=Status.lookup[status], # type: ignore 4345 pid=pid, 4346 nthreads=nthreads, 4347 memory_limit=memory_limit or 0, 4348 name=name, 4349 local_directory=local_directory, 4350 services=services, 4351 versions=versions, 4352 nanny=nanny, 4353 extra=extra, 4354 ) 4355 if ws._status == Status.running: 4356 parent._running.add(ws) 4357 4358 dh: dict = parent._host_info.get(host) # type: ignore 4359 if dh is None: 4360 parent._host_info[host] = dh = {} 4361 4362 dh_addresses: set = dh.get("addresses") # type: ignore 4363 if dh_addresses is None: 4364 dh["addresses"] = dh_addresses = set() 4365 dh["nthreads"] = 0 4366 4367 dh_addresses.add(address) 4368 dh["nthreads"] += nthreads 4369 4370 parent._total_nthreads += nthreads 4371 parent._aliases[name] = address 4372 4373 self.heartbeat_worker( 4374 address=address, 4375 resolve_address=resolve_address, 4376 now=now, 4377 resources=resources, 4378 host_info=host_info, 4379 metrics=metrics, 4380 ) 4381 4382 # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot 4383 # exist before this. 
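# Classify the brand-new worker as idle or saturated. Since nothing is
# processing on it yet, it will normally land in the idle set and immediately
# become a candidate for new work.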
4384 self.check_idle_saturated(ws) 4385 4386 # for key in keys: # TODO 4387 # self.mark_key_in_memory(key, [address]) 4388 4389 self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) 4390 4391 if ws._nthreads > len(ws._processing): 4392 parent._idle[ws._address] = ws 4393 4394 for plugin in list(self.plugins.values()): 4395 try: 4396 result = plugin.add_worker(scheduler=self, worker=address) 4397 if inspect.isawaitable(result): 4398 await result 4399 except Exception as e: 4400 logger.exception(e) 4401 4402 recommendations: dict = {} 4403 client_msgs: dict = {} 4404 worker_msgs: dict = {} 4405 if nbytes: 4406 assert isinstance(nbytes, dict) 4407 already_released_keys = [] 4408 for key in nbytes: 4409 ts: TaskState = parent._tasks.get(key) # type: ignore 4410 if ts is not None and ts.state != "released": 4411 if ts.state == "memory": 4412 self.add_keys(worker=address, keys=[key]) 4413 else: 4414 t: tuple = parent._transition( 4415 key, 4416 "memory", 4417 worker=address, 4418 nbytes=nbytes[key], 4419 typename=types[key], 4420 ) 4421 recommendations, client_msgs, worker_msgs = t 4422 parent._transitions( 4423 recommendations, client_msgs, worker_msgs 4424 ) 4425 recommendations = {} 4426 else: 4427 already_released_keys.append(key) 4428 if already_released_keys: 4429 if address not in worker_msgs: 4430 worker_msgs[address] = [] 4431 worker_msgs[address].append( 4432 { 4433 "op": "remove-replicas", 4434 "keys": already_released_keys, 4435 "stimulus_id": f"reconnect-already-released-{time()}", 4436 } 4437 ) 4438 4439 if ws._status == Status.running: 4440 for ts in parent._unrunnable: 4441 valid: set = self.valid_workers(ts) 4442 if valid is None or ws in valid: 4443 recommendations[ts._key] = "waiting" 4444 4445 if recommendations: 4446 parent._transitions(recommendations, client_msgs, worker_msgs) 4447 4448 self.send_all(client_msgs, worker_msgs) 4449 4450 self.log_event(address, {"action": "add-worker"}) 4451 self.log_event("all", {"action": "add-worker", "worker": address}) 4452 logger.info("Register worker %s", ws) 4453 4454 msg = { 4455 "status": "OK", 4456 "time": time(), 4457 "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)), 4458 "worker-plugins": self.worker_plugins, 4459 } 4460 4461 cs: ClientState 4462 version_warning = version_module.error_message( 4463 version_module.get_versions(), 4464 merge( 4465 {w: ws._versions for w, ws in parent._workers_dv.items()}, 4466 { 4467 c: cs._versions 4468 for c, cs in parent._clients.items() 4469 if cs._versions 4470 }, 4471 ), 4472 versions, 4473 client_name="This Worker", 4474 ) 4475 msg.update(version_warning) 4476 4477 if comm: 4478 await comm.write(msg) 4479 4480 await self.handle_worker(comm=comm, worker=address) 4481 4482 async def add_nanny(self, comm): 4483 msg = { 4484 "status": "OK", 4485 "nanny-plugins": self.nanny_plugins, 4486 } 4487 return msg 4488 4489 def update_graph_hlg( 4490 self, 4491 client=None, 4492 hlg=None, 4493 keys=None, 4494 dependencies=None, 4495 restrictions=None, 4496 priority=None, 4497 loose_restrictions=None, 4498 resources=None, 4499 submitting_task=None, 4500 retries=None, 4501 user_priority=0, 4502 actors=None, 4503 fifo_timeout=0, 4504 code=None, 4505 ): 4506 unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg) 4507 dsk = unpacked_graph["dsk"] 4508 dependencies = unpacked_graph["deps"] 4509 annotations = unpacked_graph["annotations"] 4510 4511 # Remove any self-dependencies (happens on test_publish_bag() and others) 4512 for k, v in dependencies.items(): 4513 
deps = set(v) 4514 if k in deps: 4515 deps.remove(k) 4516 dependencies[k] = deps 4517 4518 if priority is None: 4519 # Removing all non-local keys before calling order() 4520 dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys 4521 stripped_deps = { 4522 k: v.intersection(dsk_keys) 4523 for k, v in dependencies.items() 4524 if k in dsk_keys 4525 } 4526 priority = dask.order.order(dsk, dependencies=stripped_deps) 4527 4528 return self.update_graph( 4529 client, 4530 dsk, 4531 keys, 4532 dependencies, 4533 restrictions, 4534 priority, 4535 loose_restrictions, 4536 resources, 4537 submitting_task, 4538 retries, 4539 user_priority, 4540 actors, 4541 fifo_timeout, 4542 annotations, 4543 code=code, 4544 ) 4545 4546 def update_graph( 4547 self, 4548 client=None, 4549 tasks=None, 4550 keys=None, 4551 dependencies=None, 4552 restrictions=None, 4553 priority=None, 4554 loose_restrictions=None, 4555 resources=None, 4556 submitting_task=None, 4557 retries=None, 4558 user_priority=0, 4559 actors=None, 4560 fifo_timeout=0, 4561 annotations=None, 4562 code=None, 4563 ): 4564 """ 4565 Add new computations to the internal dask graph 4566 4567 This happens whenever the Client calls submit, map, get, or compute. 4568 """ 4569 parent: SchedulerState = cast(SchedulerState, self) 4570 start = time() 4571 fifo_timeout = parse_timedelta(fifo_timeout) 4572 keys = set(keys) 4573 if len(tasks) > 1: 4574 self.log_event( 4575 ["all", client], {"action": "update_graph", "count": len(tasks)} 4576 ) 4577 4578 # Remove aliases 4579 for k in list(tasks): 4580 if tasks[k] is k: 4581 del tasks[k] 4582 4583 dependencies = dependencies or {} 4584 4585 if parent._total_occupancy > 1e-9 and parent._computations: 4586 # Still working on something. Assign new tasks to same computation 4587 computation = cast(Computation, parent._computations[-1]) 4588 else: 4589 computation = Computation() 4590 parent._computations.append(computation) 4591 4592 if code and code not in computation._code: # add new code blocks 4593 computation._code.add(code) 4594 4595 n = 0 4596 while len(tasks) != n: # walk through new tasks, cancel any bad deps 4597 n = len(tasks) 4598 for k, deps in list(dependencies.items()): 4599 if any( 4600 dep not in parent._tasks and dep not in tasks for dep in deps 4601 ): # bad key 4602 logger.info("User asked for computation on lost data, %s", k) 4603 del tasks[k] 4604 del dependencies[k] 4605 if k in keys: 4606 keys.remove(k) 4607 self.report({"op": "cancelled-key", "key": k}, client=client) 4608 self.client_releases_keys(keys=[k], client=client) 4609 4610 # Avoid computation that is already finished 4611 ts: TaskState 4612 already_in_memory = set() # tasks that are already done 4613 for k, v in dependencies.items(): 4614 if v and k in parent._tasks: 4615 ts = parent._tasks[k] 4616 if ts._state in ("memory", "erred"): 4617 already_in_memory.add(k) 4618 4619 dts: TaskState 4620 if already_in_memory: 4621 dependents = dask.core.reverse_dict(dependencies) 4622 stack = list(already_in_memory) 4623 done = set(already_in_memory) 4624 while stack: # remove unnecessary dependencies 4625 key = stack.pop() 4626 ts = parent._tasks[key] 4627 try: 4628 deps = dependencies[key] 4629 except KeyError: 4630 deps = self.dependencies[key] 4631 for dep in deps: 4632 if dep in dependents: 4633 child_deps = dependents[dep] 4634 else: 4635 child_deps = self.dependencies[dep] 4636 if all(d in done for d in child_deps): 4637 if dep in parent._tasks and dep not in done: 4638 done.add(dep) 4639 stack.append(dep) 4640 4641 
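# At this point ``done`` holds the submitted keys that are already in memory
# (or erred) together with any dependencies that are only needed by keys in
# ``done``. Drop them from ``tasks`` and ``dependencies`` below so they are not
# recomputed and their existing results are reused instead.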
for d in done: 4642 tasks.pop(d, None) 4643 dependencies.pop(d, None) 4644 4645 # Get or create task states 4646 stack = list(keys) 4647 touched_keys = set() 4648 touched_tasks = [] 4649 while stack: 4650 k = stack.pop() 4651 if k in touched_keys: 4652 continue 4653 # XXX Have a method get_task_state(self, k) ? 4654 ts = parent._tasks.get(k) 4655 if ts is None: 4656 ts = parent.new_task( 4657 k, tasks.get(k), "released", computation=computation 4658 ) 4659 elif not ts._run_spec: 4660 ts._run_spec = tasks.get(k) 4661 4662 touched_keys.add(k) 4663 touched_tasks.append(ts) 4664 stack.extend(dependencies.get(k, ())) 4665 4666 self.client_desires_keys(keys=keys, client=client) 4667 4668 # Add dependencies 4669 for key, deps in dependencies.items(): 4670 ts = parent._tasks.get(key) 4671 if ts is None or ts._dependencies: 4672 continue 4673 for dep in deps: 4674 dts = parent._tasks[dep] 4675 ts.add_dependency(dts) 4676 4677 # Compute priorities 4678 if isinstance(user_priority, Number): 4679 user_priority = {k: user_priority for k in tasks} 4680 4681 annotations = annotations or {} 4682 restrictions = restrictions or {} 4683 loose_restrictions = loose_restrictions or [] 4684 resources = resources or {} 4685 retries = retries or {} 4686 4687 # Override existing taxonomy with per task annotations 4688 if annotations: 4689 if "priority" in annotations: 4690 user_priority.update(annotations["priority"]) 4691 4692 if "workers" in annotations: 4693 restrictions.update(annotations["workers"]) 4694 4695 if "allow_other_workers" in annotations: 4696 loose_restrictions.extend( 4697 k for k, v in annotations["allow_other_workers"].items() if v 4698 ) 4699 4700 if "retries" in annotations: 4701 retries.update(annotations["retries"]) 4702 4703 if "resources" in annotations: 4704 resources.update(annotations["resources"]) 4705 4706 for a, kv in annotations.items(): 4707 for k, v in kv.items(): 4708 # Tasks might have been culled, in which case 4709 # we have nothing to annotate. 
4710 ts = parent._tasks.get(k) 4711 if ts is not None: 4712 ts._annotations[a] = v 4713 4714 # Add actors 4715 if actors is True: 4716 actors = list(keys) 4717 for actor in actors or []: 4718 ts = parent._tasks[actor] 4719 ts._actor = True 4720 4721 priority = priority or dask.order.order( 4722 tasks 4723 ) # TODO: define order wrt old graph 4724 4725 if submitting_task: # sub-tasks get better priority than parent tasks 4726 ts = parent._tasks.get(submitting_task) 4727 if ts is not None: 4728 generation = ts._priority[0] - 0.01 4729 else: # super-task already cleaned up 4730 generation = self.generation 4731 elif self._last_time + fifo_timeout < start: 4732 self.generation += 1 # older graph generations take precedence 4733 generation = self.generation 4734 self._last_time = start 4735 else: 4736 generation = self.generation 4737 4738 for key in set(priority) & touched_keys: 4739 ts = parent._tasks[key] 4740 if ts._priority is None: 4741 ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) 4742 4743 # Ensure all runnables have a priority 4744 runnables = [ts for ts in touched_tasks if ts._run_spec] 4745 for ts in runnables: 4746 if ts._priority is None and ts._run_spec: 4747 ts._priority = (self.generation, 0) 4748 4749 if restrictions: 4750 # *restrictions* is a dict keying task ids to lists of 4751 # restriction specifications (either worker names or addresses) 4752 for k, v in restrictions.items(): 4753 if v is None: 4754 continue 4755 ts = parent._tasks.get(k) 4756 if ts is None: 4757 continue 4758 ts._host_restrictions = set() 4759 ts._worker_restrictions = set() 4760 # Make sure `v` is a collection and not a single worker name / address 4761 if not isinstance(v, (list, tuple, set)): 4762 v = [v] 4763 for w in v: 4764 try: 4765 w = self.coerce_address(w) 4766 except ValueError: 4767 # Not a valid address, but perhaps it's a hostname 4768 ts._host_restrictions.add(w) 4769 else: 4770 ts._worker_restrictions.add(w) 4771 4772 if loose_restrictions: 4773 for k in loose_restrictions: 4774 ts = parent._tasks[k] 4775 ts._loose_restrictions = True 4776 4777 if resources: 4778 for k, v in resources.items(): 4779 if v is None: 4780 continue 4781 assert isinstance(v, dict) 4782 ts = parent._tasks.get(k) 4783 if ts is None: 4784 continue 4785 ts._resource_restrictions = v 4786 4787 if retries: 4788 for k, v in retries.items(): 4789 assert isinstance(v, int) 4790 ts = parent._tasks.get(k) 4791 if ts is None: 4792 continue 4793 ts._retries = v 4794 4795 # Compute recommendations 4796 recommendations: dict = {} 4797 4798 for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): 4799 if ts._state == "released" and ts._run_spec: 4800 recommendations[ts._key] = "waiting" 4801 4802 for ts in touched_tasks: 4803 for dts in ts._dependencies: 4804 if dts._exception_blame: 4805 ts._exception_blame = dts._exception_blame 4806 recommendations[ts._key] = "erred" 4807 break 4808 4809 for plugin in list(self.plugins.values()): 4810 try: 4811 plugin.update_graph( 4812 self, 4813 client=client, 4814 tasks=tasks, 4815 keys=keys, 4816 restrictions=restrictions or {}, 4817 dependencies=dependencies, 4818 priority=priority, 4819 loose_restrictions=loose_restrictions, 4820 resources=resources, 4821 annotations=annotations, 4822 ) 4823 except Exception as e: 4824 logger.exception(e) 4825 4826 self.transitions(recommendations) 4827 4828 for ts in touched_tasks: 4829 if ts._state in ("memory", "erred"): 4830 self.report_on_key(ts=ts, client=client) 4831 4832 end = time() 4833 if 
self.digests is not None: 4834 self.digests["update-graph-duration"].add(end - start) 4835 4836 # TODO: balance workers 4837 4838 def stimulus_task_finished(self, key=None, worker=None, **kwargs): 4839 """Mark that a task has finished execution on a particular worker""" 4840 parent: SchedulerState = cast(SchedulerState, self) 4841 logger.debug("Stimulus task finished %s, %s", key, worker) 4842 4843 recommendations: dict = {} 4844 client_msgs: dict = {} 4845 worker_msgs: dict = {} 4846 4847 ws: WorkerState = parent._workers_dv[worker] 4848 ts: TaskState = parent._tasks.get(key) 4849 if ts is None or ts._state == "released": 4850 logger.debug( 4851 "Received already computed task, worker: %s, state: %s" 4852 ", key: %s, who_has: %s", 4853 worker, 4854 ts._state if ts else "forgotten", 4855 key, 4856 ts._who_has if ts else {}, 4857 ) 4858 worker_msgs[worker] = [ 4859 { 4860 "op": "free-keys", 4861 "keys": [key], 4862 "stimulus_id": f"already-released-or-forgotten-{time()}", 4863 } 4864 ] 4865 elif ts._state == "memory": 4866 self.add_keys(worker=worker, keys=[key]) 4867 else: 4868 ts._metadata.update(kwargs["metadata"]) 4869 r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) 4870 recommendations, client_msgs, worker_msgs = r 4871 4872 if ts._state == "memory": 4873 assert ws in ts._who_has 4874 return recommendations, client_msgs, worker_msgs 4875 4876 def stimulus_task_erred( 4877 self, key=None, worker=None, exception=None, traceback=None, **kwargs 4878 ): 4879 """Mark that a task has erred on a particular worker""" 4880 parent: SchedulerState = cast(SchedulerState, self) 4881 logger.debug("Stimulus task erred %s, %s", key, worker) 4882 4883 ts: TaskState = parent._tasks.get(key) 4884 if ts is None or ts._state != "processing": 4885 return {}, {}, {} 4886 4887 if ts._retries > 0: 4888 ts._retries -= 1 4889 return parent._transition(key, "waiting") 4890 else: 4891 return parent._transition( 4892 key, 4893 "erred", 4894 cause=key, 4895 exception=exception, 4896 traceback=traceback, 4897 worker=worker, 4898 **kwargs, 4899 ) 4900 4901 def stimulus_retry(self, comm=None, keys=None, client=None): 4902 parent: SchedulerState = cast(SchedulerState, self) 4903 logger.info("Client %s requests to retry %d keys", client, len(keys)) 4904 if client: 4905 self.log_event(client, {"action": "retry", "count": len(keys)}) 4906 4907 stack = list(keys) 4908 seen = set() 4909 roots = [] 4910 ts: TaskState 4911 dts: TaskState 4912 while stack: 4913 key = stack.pop() 4914 seen.add(key) 4915 ts = parent._tasks[key] 4916 erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] 4917 if erred_deps: 4918 stack.extend(erred_deps) 4919 else: 4920 roots.append(key) 4921 4922 recommendations: dict = {key: "waiting" for key in roots} 4923 self.transitions(recommendations) 4924 4925 if parent._validate: 4926 for key in seen: 4927 assert not parent._tasks[key].exception_blame 4928 4929 return tuple(seen) 4930 4931 async def remove_worker(self, comm=None, address=None, safe=False, close=True): 4932 """ 4933 Remove worker from cluster 4934 4935 We do this when a worker reports that it plans to leave or when it 4936 appears to be unresponsive. This may send its tasks back to a released 4937 state. 
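Tasks that were processing on the worker are released and, unless ``safe=True``, counted as suspicious; a task that has exceeded ``allowed_failures`` this way is marked erred with a ``KilledWorker`` exception. Tasks whose only replica lived on this worker are released if they can be recomputed, or forgotten if they were pure data. When ``close=True`` (the default), the worker is also told to shut itself down.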
4938 """ 4939 parent: SchedulerState = cast(SchedulerState, self) 4940 with log_errors(): 4941 if self.status == Status.closed: 4942 return 4943 4944 address = self.coerce_address(address) 4945 4946 if address not in parent._workers_dv: 4947 return "already-removed" 4948 4949 host = get_address_host(address) 4950 4951 ws: WorkerState = parent._workers_dv[address] 4952 4953 self.log_event( 4954 ["all", address], 4955 { 4956 "action": "remove-worker", 4957 "processing-tasks": dict(ws._processing), 4958 }, 4959 ) 4960 logger.info("Remove worker %s", ws) 4961 if close: 4962 with suppress(AttributeError, CommClosedError): 4963 self.stream_comms[address].send({"op": "close", "report": False}) 4964 4965 self.remove_resources(address) 4966 4967 dh: dict = parent._host_info.get(host) 4968 if dh is None: 4969 parent._host_info[host] = dh = {} 4970 4971 dh_addresses: set = dh["addresses"] 4972 dh_addresses.remove(address) 4973 dh["nthreads"] -= ws._nthreads 4974 parent._total_nthreads -= ws._nthreads 4975 4976 if not dh_addresses: 4977 dh = None 4978 dh_addresses = None 4979 del parent._host_info[host] 4980 4981 self.rpc.remove(address) 4982 del self.stream_comms[address] 4983 del parent._aliases[ws._name] 4984 parent._idle.pop(ws._address, None) 4985 parent._saturated.discard(ws) 4986 del parent._workers[address] 4987 ws.status = Status.closed 4988 parent._running.discard(ws) 4989 parent._total_occupancy -= ws._occupancy 4990 4991 recommendations: dict = {} 4992 4993 ts: TaskState 4994 for ts in list(ws._processing): 4995 k = ts._key 4996 recommendations[k] = "released" 4997 if not safe: 4998 ts._suspicious += 1 4999 ts._prefix._suspicious += 1 5000 if ts._suspicious > self.allowed_failures: 5001 del recommendations[k] 5002 e = pickle.dumps( 5003 KilledWorker(task=k, last_worker=ws.clean()), protocol=4 5004 ) 5005 r = self.transition(k, "erred", exception=e, cause=k) 5006 recommendations.update(r) 5007 logger.info( 5008 "Task %s marked as failed because %d workers died" 5009 " while trying to run it", 5010 ts._key, 5011 self.allowed_failures, 5012 ) 5013 5014 for ts in list(ws._has_what): 5015 parent.remove_replica(ts, ws) 5016 if not ts._who_has: 5017 if ts._run_spec: 5018 recommendations[ts._key] = "released" 5019 else: # pure data 5020 recommendations[ts._key] = "forgotten" 5021 5022 self.transitions(recommendations) 5023 5024 for plugin in list(self.plugins.values()): 5025 try: 5026 result = plugin.remove_worker(scheduler=self, worker=address) 5027 if inspect.isawaitable(result): 5028 await result 5029 except Exception as e: 5030 logger.exception(e) 5031 5032 if not parent._workers_dv: 5033 logger.info("Lost all workers") 5034 5035 for w in parent._workers_dv: 5036 self.bandwidth_workers.pop((address, w), None) 5037 self.bandwidth_workers.pop((w, address), None) 5038 5039 def remove_worker_from_events(): 5040 # If the worker isn't registered anymore after the delay, remove from events 5041 if address not in parent._workers_dv and address in self.events: 5042 del self.events[address] 5043 5044 cleanup_delay = parse_timedelta( 5045 dask.config.get("distributed.scheduler.events-cleanup-delay") 5046 ) 5047 self.loop.call_later(cleanup_delay, remove_worker_from_events) 5048 logger.debug("Removed worker %s", ws) 5049 5050 return "OK" 5051 5052 def stimulus_cancel(self, comm, keys=None, client=None, force=False): 5053 """Stop execution on a list of keys""" 5054 logger.info("Client %s requests to cancel %d keys", client, len(keys)) 5055 if client: 5056 self.log_event( 5057 client, {"action": "cancel", 
"count": len(keys), "force": force} 5058 ) 5059 for key in keys: 5060 self.cancel_key(key, client, force=force) 5061 5062 def cancel_key(self, key, client, retries=5, force=False): 5063 """Cancel a particular key and all dependents""" 5064 # TODO: this should be converted to use the transition mechanism 5065 parent: SchedulerState = cast(SchedulerState, self) 5066 ts: TaskState = parent._tasks.get(key) 5067 dts: TaskState 5068 try: 5069 cs: ClientState = parent._clients[client] 5070 except KeyError: 5071 return 5072 if ts is None or not ts._who_wants: # no key yet, lets try again in a moment 5073 if retries: 5074 self.loop.call_later( 5075 0.2, lambda: self.cancel_key(key, client, retries - 1) 5076 ) 5077 return 5078 if force or ts._who_wants == {cs}: # no one else wants this key 5079 for dts in list(ts._dependents): 5080 self.cancel_key(dts._key, client, force=force) 5081 logger.info("Scheduler cancels key %s. Force=%s", key, force) 5082 self.report({"op": "cancelled-key", "key": key}) 5083 clients = list(ts._who_wants) if force else [cs] 5084 for cs in clients: 5085 self.client_releases_keys(keys=[key], client=cs._client_key) 5086 5087 def client_desires_keys(self, keys=None, client=None): 5088 parent: SchedulerState = cast(SchedulerState, self) 5089 cs: ClientState = parent._clients.get(client) 5090 if cs is None: 5091 # For publish, queues etc. 5092 parent._clients[client] = cs = ClientState(client) 5093 ts: TaskState 5094 for k in keys: 5095 ts = parent._tasks.get(k) 5096 if ts is None: 5097 # For publish, queues etc. 5098 ts = parent.new_task(k, None, "released") 5099 ts._who_wants.add(cs) 5100 cs._wants_what.add(ts) 5101 5102 if ts._state in ("memory", "erred"): 5103 self.report_on_key(ts=ts, client=client) 5104 5105 def client_releases_keys(self, keys=None, client=None): 5106 """Remove keys from client desired list""" 5107 5108 parent: SchedulerState = cast(SchedulerState, self) 5109 if not isinstance(keys, list): 5110 keys = list(keys) 5111 cs: ClientState = parent._clients[client] 5112 recommendations: dict = {} 5113 5114 _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations) 5115 self.transitions(recommendations) 5116 5117 def client_heartbeat(self, client=None): 5118 """Handle heartbeats from Client""" 5119 parent: SchedulerState = cast(SchedulerState, self) 5120 cs: ClientState = parent._clients[client] 5121 cs._last_seen = time() 5122 5123 ################### 5124 # Task Validation # 5125 ################### 5126 5127 def validate_released(self, key): 5128 parent: SchedulerState = cast(SchedulerState, self) 5129 ts: TaskState = parent._tasks[key] 5130 dts: TaskState 5131 assert ts._state == "released" 5132 assert not ts._waiters 5133 assert not ts._waiting_on 5134 assert not ts._who_has 5135 assert not ts._processing_on 5136 assert not any([ts in dts._waiters for dts in ts._dependencies]) 5137 assert ts not in parent._unrunnable 5138 5139 def validate_waiting(self, key): 5140 parent: SchedulerState = cast(SchedulerState, self) 5141 ts: TaskState = parent._tasks[key] 5142 dts: TaskState 5143 assert ts._waiting_on 5144 assert not ts._who_has 5145 assert not ts._processing_on 5146 assert ts not in parent._unrunnable 5147 for dts in ts._dependencies: 5148 # We are waiting on a dependency iff it's not stored 5149 assert bool(dts._who_has) != (dts in ts._waiting_on) 5150 assert ts in dts._waiters # XXX even if dts._who_has? 
5151 5152 def validate_processing(self, key): 5153 parent: SchedulerState = cast(SchedulerState, self) 5154 ts: TaskState = parent._tasks[key] 5155 dts: TaskState 5156 assert not ts._waiting_on 5157 ws: WorkerState = ts._processing_on 5158 assert ws 5159 assert ts in ws._processing 5160 assert not ts._who_has 5161 for dts in ts._dependencies: 5162 assert dts._who_has 5163 assert ts in dts._waiters 5164 5165 def validate_memory(self, key): 5166 parent: SchedulerState = cast(SchedulerState, self) 5167 ts: TaskState = parent._tasks[key] 5168 dts: TaskState 5169 assert ts._who_has 5170 assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) 5171 assert not ts._processing_on 5172 assert not ts._waiting_on 5173 assert ts not in parent._unrunnable 5174 for dts in ts._dependents: 5175 assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) 5176 assert ts not in dts._waiting_on 5177 5178 def validate_no_worker(self, key): 5179 parent: SchedulerState = cast(SchedulerState, self) 5180 ts: TaskState = parent._tasks[key] 5181 dts: TaskState 5182 assert ts in parent._unrunnable 5183 assert not ts._waiting_on 5184 assert ts in parent._unrunnable 5185 assert not ts._processing_on 5186 assert not ts._who_has 5187 for dts in ts._dependencies: 5188 assert dts._who_has 5189 5190 def validate_erred(self, key): 5191 parent: SchedulerState = cast(SchedulerState, self) 5192 ts: TaskState = parent._tasks[key] 5193 assert ts._exception_blame 5194 assert not ts._who_has 5195 5196 def validate_key(self, key, ts: TaskState = None): 5197 parent: SchedulerState = cast(SchedulerState, self) 5198 try: 5199 if ts is None: 5200 ts = parent._tasks.get(key) 5201 if ts is None: 5202 logger.debug("Key lost: %s", key) 5203 else: 5204 ts.validate() 5205 try: 5206 func = getattr(self, "validate_" + ts._state.replace("-", "_")) 5207 except AttributeError: 5208 logger.error( 5209 "self.validate_%s not found", ts._state.replace("-", "_") 5210 ) 5211 else: 5212 func(key) 5213 except Exception as e: 5214 logger.exception(e) 5215 if LOG_PDB: 5216 import pdb 5217 5218 pdb.set_trace() 5219 raise 5220 5221 def validate_state(self, allow_overlap=False): 5222 parent: SchedulerState = cast(SchedulerState, self) 5223 validate_state(parent._tasks, parent._workers, parent._clients) 5224 5225 if not (set(parent._workers_dv) == set(self.stream_comms)): 5226 raise ValueError("Workers not the same in all collections") 5227 5228 ws: WorkerState 5229 for w, ws in parent._workers_dv.items(): 5230 assert isinstance(w, str), (type(w), w) 5231 assert isinstance(ws, WorkerState), (type(ws), ws) 5232 assert ws._address == w 5233 if not ws._processing: 5234 assert not ws._occupancy 5235 assert ws._address in parent._idle_dv 5236 assert (ws._status == Status.running) == (ws in parent._running) 5237 5238 for ws in parent._running: 5239 assert ws._status == Status.running 5240 assert ws._address in parent._workers_dv 5241 5242 ts: TaskState 5243 for k, ts in parent._tasks.items(): 5244 assert isinstance(ts, TaskState), (type(ts), ts) 5245 assert ts._key == k 5246 assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) 5247 self.validate_key(k, ts) 5248 5249 for ts in parent._replicated_tasks: 5250 assert ts._state == "memory" 5251 assert ts._key in parent._tasks 5252 5253 c: str 5254 cs: ClientState 5255 for c, cs in parent._clients.items(): 5256 # client=None is often used in tests... 
5257 assert c is None or type(c) == str, (type(c), c) 5258 assert type(cs) == ClientState, (type(cs), cs) 5259 assert cs._client_key == c 5260 5261 a = {w: ws._nbytes for w, ws in parent._workers_dv.items()} 5262 b = { 5263 w: sum(ts.get_nbytes() for ts in ws._has_what) 5264 for w, ws in parent._workers_dv.items() 5265 } 5266 assert a == b, (a, b) 5267 5268 actual_total_occupancy = 0 5269 for worker, ws in parent._workers_dv.items(): 5270 assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 5271 actual_total_occupancy += ws._occupancy 5272 5273 assert abs(actual_total_occupancy - parent._total_occupancy) < 1e-8, ( 5274 actual_total_occupancy, 5275 parent._total_occupancy, 5276 ) 5277 5278 ################### 5279 # Manage Messages # 5280 ################### 5281 5282 def report(self, msg: dict, ts: TaskState = None, client: str = None): 5283 """ 5284 Publish updates to all listening Queues and Comms 5285 5286 If the message contains a key then we only send the message to those 5287 comms that care about the key. 5288 """ 5289 parent: SchedulerState = cast(SchedulerState, self) 5290 if ts is None: 5291 msg_key = msg.get("key") 5292 if msg_key is not None: 5293 tasks: dict = parent._tasks 5294 ts = tasks.get(msg_key) 5295 5296 cs: ClientState 5297 client_comms: dict = self.client_comms 5298 client_keys: list 5299 if ts is None: 5300 # Notify all clients 5301 client_keys = list(client_comms) 5302 elif client is None: 5303 # Notify clients interested in key 5304 client_keys = [cs._client_key for cs in ts._who_wants] 5305 else: 5306 # Notify clients interested in key (including `client`) 5307 client_keys = [ 5308 cs._client_key for cs in ts._who_wants if cs._client_key != client 5309 ] 5310 client_keys.append(client) 5311 5312 k: str 5313 for k in client_keys: 5314 c = client_comms.get(k) 5315 if c is None: 5316 continue 5317 try: 5318 c.send(msg) 5319 # logger.debug("Scheduler sends message to client %s", msg) 5320 except CommClosedError: 5321 if self.status == Status.running: 5322 logger.critical( 5323 "Closed comm %r while trying to write %s", c, msg, exc_info=True 5324 ) 5325 5326 async def add_client(self, comm, client=None, versions=None): 5327 """Add client to network 5328 5329 We listen to all future messages from this Comm. 
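A ``ClientState`` entry is created for the client, plugins are notified through their ``add_client`` hooks, and replies are streamed back over a ``BatchedSend`` wrapping this comm; once the stream closes, the client is removed again via ``remove_client``.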
5330 """ 5331 parent: SchedulerState = cast(SchedulerState, self) 5332 assert client is not None 5333 comm.name = "Scheduler->Client" 5334 logger.info("Receive client connection: %s", client) 5335 self.log_event(["all", client], {"action": "add-client", "client": client}) 5336 parent._clients[client] = ClientState(client, versions=versions) 5337 5338 for plugin in list(self.plugins.values()): 5339 try: 5340 plugin.add_client(scheduler=self, client=client) 5341 except Exception as e: 5342 logger.exception(e) 5343 5344 try: 5345 bcomm = BatchedSend(interval="2ms", loop=self.loop) 5346 bcomm.start(comm) 5347 self.client_comms[client] = bcomm 5348 msg = {"op": "stream-start"} 5349 ws: WorkerState 5350 version_warning = version_module.error_message( 5351 version_module.get_versions(), 5352 {w: ws._versions for w, ws in parent._workers_dv.items()}, 5353 versions, 5354 ) 5355 msg.update(version_warning) 5356 bcomm.send(msg) 5357 5358 try: 5359 await self.handle_stream(comm=comm, extra={"client": client}) 5360 finally: 5361 self.remove_client(client=client) 5362 logger.debug("Finished handling client %s", client) 5363 finally: 5364 if not comm.closed(): 5365 self.client_comms[client].send({"op": "stream-closed"}) 5366 try: 5367 if not sys.is_finalizing(): 5368 await self.client_comms[client].close() 5369 del self.client_comms[client] 5370 if self.status == Status.running: 5371 logger.info("Close client connection: %s", client) 5372 except TypeError: # comm becomes None during GC 5373 pass 5374 5375 def remove_client(self, client=None): 5376 """Remove client from network""" 5377 parent: SchedulerState = cast(SchedulerState, self) 5378 if self.status == Status.running: 5379 logger.info("Remove client %s", client) 5380 self.log_event(["all", client], {"action": "remove-client", "client": client}) 5381 try: 5382 cs: ClientState = parent._clients[client] 5383 except KeyError: 5384 # XXX is this a legitimate condition? 
5385 pass 5386 else: 5387 ts: TaskState 5388 self.client_releases_keys( 5389 keys=[ts._key for ts in cs._wants_what], client=cs._client_key 5390 ) 5391 del parent._clients[client] 5392 5393 for plugin in list(self.plugins.values()): 5394 try: 5395 plugin.remove_client(scheduler=self, client=client) 5396 except Exception as e: 5397 logger.exception(e) 5398 5399 def remove_client_from_events(): 5400 # If the client isn't registered anymore after the delay, remove from events 5401 if client not in parent._clients and client in self.events: 5402 del self.events[client] 5403 5404 cleanup_delay = parse_timedelta( 5405 dask.config.get("distributed.scheduler.events-cleanup-delay") 5406 ) 5407 self.loop.call_later(cleanup_delay, remove_client_from_events) 5408 5409 def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): 5410 """Send a single computational task to a worker""" 5411 parent: SchedulerState = cast(SchedulerState, self) 5412 try: 5413 msg: dict = _task_to_msg(parent, ts, duration) 5414 self.worker_send(worker, msg) 5415 except Exception as e: 5416 logger.exception(e) 5417 if LOG_PDB: 5418 import pdb 5419 5420 pdb.set_trace() 5421 raise 5422 5423 def handle_uncaught_error(self, **msg): 5424 logger.exception(clean_exception(**msg)[1]) 5425 5426 def handle_task_finished(self, key=None, worker=None, **msg): 5427 parent: SchedulerState = cast(SchedulerState, self) 5428 if worker not in parent._workers_dv: 5429 return 5430 validate_key(key) 5431 5432 recommendations: dict 5433 client_msgs: dict 5434 worker_msgs: dict 5435 5436 r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) 5437 recommendations, client_msgs, worker_msgs = r 5438 parent._transitions(recommendations, client_msgs, worker_msgs) 5439 5440 self.send_all(client_msgs, worker_msgs) 5441 5442 def handle_task_erred(self, key=None, **msg): 5443 parent: SchedulerState = cast(SchedulerState, self) 5444 recommendations: dict 5445 client_msgs: dict 5446 worker_msgs: dict 5447 r: tuple = self.stimulus_task_erred(key=key, **msg) 5448 recommendations, client_msgs, worker_msgs = r 5449 parent._transitions(recommendations, client_msgs, worker_msgs) 5450 5451 self.send_all(client_msgs, worker_msgs) 5452 5453 def handle_missing_data(self, key=None, errant_worker=None, **kwargs): 5454 parent: SchedulerState = cast(SchedulerState, self) 5455 logger.debug("handle missing data key=%s worker=%s", key, errant_worker) 5456 self.log.append(("missing", key, errant_worker)) 5457 5458 ts: TaskState = parent._tasks.get(key) 5459 if ts is None: 5460 return 5461 ws: WorkerState = parent._workers_dv.get(errant_worker) 5462 if ws is not None and ws in ts._who_has: 5463 parent.remove_replica(ts, ws) 5464 if not ts._who_has: 5465 if ts._run_spec: 5466 self.transitions({key: "released"}) 5467 else: 5468 self.transitions({key: "forgotten"}) 5469 5470 def release_worker_data(self, comm=None, key=None, worker=None): 5471 parent: SchedulerState = cast(SchedulerState, self) 5472 ws: WorkerState = parent._workers_dv.get(worker) 5473 ts: TaskState = parent._tasks.get(key) 5474 if not ws or not ts: 5475 return 5476 recommendations: dict = {} 5477 if ws in ts._who_has: 5478 parent.remove_replica(ts, ws) 5479 if not ts._who_has: 5480 recommendations[ts._key] = "released" 5481 if recommendations: 5482 self.transitions(recommendations) 5483 5484 def handle_long_running(self, key=None, worker=None, compute_duration=None): 5485 """A task has seceded from the thread pool 5486 5487 We stop the task from being stolen in the future, and change 
task 5488 duration accounting as if the task has stopped. 5489 """ 5490 parent: SchedulerState = cast(SchedulerState, self) 5491 if key not in parent._tasks: 5492 logger.debug("Skipping long_running since key %s was already released", key) 5493 return 5494 ts: TaskState = parent._tasks[key] 5495 steal = parent._extensions.get("stealing") 5496 if steal is not None: 5497 steal.remove_key_from_stealable(ts) 5498 5499 ws: WorkerState = ts._processing_on 5500 if ws is None: 5501 logger.debug("Received long-running signal from duplicate task. Ignoring.") 5502 return 5503 5504 if compute_duration: 5505 old_duration: double = ts._prefix._duration_average 5506 new_duration: double = compute_duration 5507 avg_duration: double 5508 if old_duration < 0: 5509 avg_duration = new_duration 5510 else: 5511 avg_duration = 0.5 * old_duration + 0.5 * new_duration 5512 5513 ts._prefix._duration_average = avg_duration 5514 5515 occ: double = ws._processing[ts] 5516 ws._occupancy -= occ 5517 parent._total_occupancy -= occ 5518 ws._processing[ts] = 0 5519 self.check_idle_saturated(ws) 5520 5521 def handle_worker_status_change(self, status: str, worker: str) -> None: 5522 parent: SchedulerState = cast(SchedulerState, self) 5523 ws: WorkerState = parent._workers_dv.get(worker) # type: ignore 5524 if not ws: 5525 return 5526 prev_status = ws._status 5527 ws._status = Status.lookup[status] # type: ignore 5528 if ws._status == prev_status: 5529 return 5530 5531 self.log_event( 5532 ws._address, 5533 { 5534 "action": "worker-status-change", 5535 "prev-status": prev_status.name, 5536 "status": status, 5537 }, 5538 ) 5539 5540 if ws._status == Status.running: 5541 parent._running.add(ws) 5542 5543 recs = {} 5544 ts: TaskState 5545 for ts in parent._unrunnable: 5546 valid: set = self.valid_workers(ts) 5547 if valid is None or ws in valid: 5548 recs[ts._key] = "waiting" 5549 if recs: 5550 client_msgs: dict = {} 5551 worker_msgs: dict = {} 5552 parent._transitions(recs, client_msgs, worker_msgs) 5553 self.send_all(client_msgs, worker_msgs) 5554 5555 else: 5556 parent._running.discard(ws) 5557 5558 async def handle_worker(self, comm=None, worker=None): 5559 """ 5560 Listen to responses from a single worker 5561 5562 This is the main loop for scheduler-worker interaction 5563 5564 See Also 5565 -------- 5566 Scheduler.handle_client: Equivalent coroutine for clients 5567 """ 5568 comm.name = "Scheduler connection to worker" 5569 worker_comm = self.stream_comms[worker] 5570 worker_comm.start(comm) 5571 logger.info("Starting worker compute stream, %s", worker) 5572 try: 5573 await self.handle_stream(comm=comm, extra={"worker": worker}) 5574 finally: 5575 if worker in self.stream_comms: 5576 worker_comm.abort() 5577 await self.remove_worker(address=worker) 5578 5579 def add_plugin( 5580 self, 5581 plugin: SchedulerPlugin, 5582 *, 5583 idempotent: bool = False, 5584 name: "str | None" = None, 5585 **kwargs, 5586 ): 5587 """Add external plugin to scheduler. 5588 5589 See https://distributed.readthedocs.io/en/latest/plugins.html 5590 5591 Parameters 5592 ---------- 5593 plugin : SchedulerPlugin 5594 SchedulerPlugin instance to add 5595 idempotent : bool 5596 If true, the plugin is assumed to already exist and no 5597 action is taken. 5598 name : str 5599 A name for the plugin, if None, the name attribute is 5600 checked on the Plugin instance and generated if not 5601 discovered. 
5602 **kwargs 5603 Deprecated; additional arguments passed to the `plugin` class if it is 5604 not already an instance 5605 """ 5606 if isinstance(plugin, type): 5607 warnings.warn( 5608 "Adding plugins by class is deprecated and will be disabled in a " 5609 "future release. Please add plugins by instance instead.", 5610 category=FutureWarning, 5611 ) 5612 plugin = plugin(self, **kwargs) # type: ignore 5613 elif kwargs: 5614 raise ValueError("kwargs provided but plugin is already an instance") 5615 5616 if name is None: 5617 name = _get_plugin_name(plugin) 5618 5619 if name in self.plugins: 5620 if idempotent: 5621 return 5622 warnings.warn( 5623 f"Scheduler already contains a plugin with name {name}; overwriting.", 5624 category=UserWarning, 5625 ) 5626 5627 self.plugins[name] = plugin 5628 5629 def remove_plugin( 5630 self, 5631 name: "str | None" = None, 5632 plugin: "SchedulerPlugin | None" = None, 5633 ) -> None: 5634 """Remove external plugin from scheduler 5635 5636 Parameters 5637 ---------- 5638 name : str 5639 Name of the plugin to remove 5640 plugin : SchedulerPlugin 5641 Deprecated; use `name` argument instead. Instance of a 5642 SchedulerPlugin class to remove; 5643 """ 5644 # TODO: Remove this block of code once removing plugins by value is disabled 5645 if bool(name) == bool(plugin): 5646 raise ValueError("Must provide plugin or name (mutually exclusive)") 5647 if isinstance(name, SchedulerPlugin): 5648 # Backwards compatibility - the sig used to be (plugin, name) 5649 plugin = name 5650 name = None 5651 if plugin is not None: 5652 warnings.warn( 5653 "Removing scheduler plugins by value is deprecated and will be disabled " 5654 "in a future release. Please remove scheduler plugins by name instead.", 5655 category=FutureWarning, 5656 ) 5657 if hasattr(plugin, "name"): 5658 name = plugin.name # type: ignore 5659 else: 5660 names = [k for k, v in self.plugins.items() if v is plugin] 5661 if not names: 5662 raise ValueError( 5663 f"Could not find {plugin} among the current scheduler plugins" 5664 ) 5665 if len(names) > 1: 5666 raise ValueError( 5667 f"Multiple instances of {plugin} were found in the current " 5668 "scheduler plugins; we cannot remove this plugin." 5669 ) 5670 name = names[0] 5671 assert name is not None 5672 # End deprecated code 5673 5674 try: 5675 del self.plugins[name] 5676 except KeyError: 5677 raise ValueError( 5678 f"Could not find plugin {name!r} among the current scheduler plugins" 5679 ) 5680 5681 async def register_scheduler_plugin(self, comm=None, plugin=None, name=None): 5682 """Register a plugin on the scheduler.""" 5683 if not dask.config.get("distributed.scheduler.pickle"): 5684 raise ValueError( 5685 "Cannot register a scheduler plugin as the scheduler " 5686 "has been explicitly disallowed from deserializing " 5687 "arbitrary bytestrings using pickle via the " 5688 "'distributed.scheduler.pickle' configuration setting." 5689 ) 5690 plugin = loads(plugin) 5691 5692 if hasattr(plugin, "start"): 5693 result = plugin.start(self) 5694 if inspect.isawaitable(result): 5695 await result 5696 5697 self.add_plugin(plugin, name=name) 5698 5699 def worker_send(self, worker, msg): 5700 """Send message to worker 5701 5702 This also handles connection failures by adding a callback to remove 5703 the worker on the next cycle. 
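Messages are plain dicts carrying an ``op`` field, e.g.::

            self.worker_send(worker, {"op": "close", "report": False})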
5704 """ 5705 stream_comms: dict = self.stream_comms 5706 try: 5707 stream_comms[worker].send(msg) 5708 except (CommClosedError, AttributeError): 5709 self.loop.add_callback(self.remove_worker, address=worker) 5710 5711 def client_send(self, client, msg): 5712 """Send message to client""" 5713 client_comms: dict = self.client_comms 5714 c = client_comms.get(client) 5715 if c is None: 5716 return 5717 try: 5718 c.send(msg) 5719 except CommClosedError: 5720 if self.status == Status.running: 5721 logger.critical( 5722 "Closed comm %r while trying to write %s", c, msg, exc_info=True 5723 ) 5724 5725 def send_all(self, client_msgs: dict, worker_msgs: dict): 5726 """Send messages to client and workers""" 5727 client_comms: dict = self.client_comms 5728 stream_comms: dict = self.stream_comms 5729 msgs: list 5730 5731 for client, msgs in client_msgs.items(): 5732 c = client_comms.get(client) 5733 if c is None: 5734 continue 5735 try: 5736 c.send(*msgs) 5737 except CommClosedError: 5738 if self.status == Status.running: 5739 logger.critical( 5740 "Closed comm %r while trying to write %s", 5741 c, 5742 msgs, 5743 exc_info=True, 5744 ) 5745 5746 for worker, msgs in worker_msgs.items(): 5747 try: 5748 w = stream_comms[worker] 5749 w.send(*msgs) 5750 except KeyError: 5751 # worker already gone 5752 pass 5753 except (CommClosedError, AttributeError): 5754 self.loop.add_callback(self.remove_worker, address=worker) 5755 5756 ############################ 5757 # Less common interactions # 5758 ############################ 5759 5760 async def scatter( 5761 self, 5762 comm=None, 5763 data=None, 5764 workers=None, 5765 client=None, 5766 broadcast=False, 5767 timeout=2, 5768 ): 5769 """Send data out to workers 5770 5771 See also 5772 -------- 5773 Scheduler.broadcast: 5774 """ 5775 parent: SchedulerState = cast(SchedulerState, self) 5776 ws: WorkerState 5777 5778 start = time() 5779 while True: 5780 if workers is None: 5781 wss = parent._running 5782 else: 5783 workers = [self.coerce_address(w) for w in workers] 5784 wss = {parent._workers_dv[w] for w in workers} 5785 wss = {ws for ws in wss if ws._status == Status.running} 5786 5787 if wss: 5788 break 5789 if time() > start + timeout: 5790 raise TimeoutError("No valid workers found") 5791 await asyncio.sleep(0.1) 5792 5793 nthreads = {ws._address: ws.nthreads for ws in wss} 5794 5795 assert isinstance(data, dict) 5796 5797 keys, who_has, nbytes = await scatter_to_workers( 5798 nthreads, data, rpc=self.rpc, report=False 5799 ) 5800 5801 self.update_data(who_has=who_has, nbytes=nbytes, client=client) 5802 5803 if broadcast: 5804 n = len(nthreads) if broadcast is True else broadcast 5805 await self.replicate(keys=keys, workers=workers, n=n) 5806 5807 self.log_event( 5808 [client, "all"], {"action": "scatter", "client": client, "count": len(data)} 5809 ) 5810 return keys 5811 5812 async def gather(self, comm=None, keys=None, serializers=None): 5813 """Collect data from workers to the scheduler""" 5814 parent: SchedulerState = cast(SchedulerState, self) 5815 ws: WorkerState 5816 keys = list(keys) 5817 who_has = {} 5818 for key in keys: 5819 ts: TaskState = parent._tasks.get(key) 5820 if ts is not None: 5821 who_has[key] = [ws._address for ws in ts._who_has] 5822 else: 5823 who_has[key] = [] 5824 5825 data, missing_keys, missing_workers = await gather_from_workers( 5826 who_has, rpc=self.rpc, close=False, serializers=serializers 5827 ) 5828 if not missing_keys: 5829 result = {"status": "OK", "data": data} 5830 else: 5831 missing_states = [ 5832 
(parent._tasks[key].state if key in parent._tasks else None) 5833 for key in missing_keys 5834 ] 5835 logger.exception( 5836 "Couldn't gather keys %s state: %s workers: %s", 5837 missing_keys, 5838 missing_states, 5839 missing_workers, 5840 ) 5841 result = {"status": "error", "keys": missing_keys} 5842 with log_errors(): 5843 # Remove suspicious workers from the scheduler but allow them to 5844 # reconnect. 5845 await asyncio.gather( 5846 *( 5847 self.remove_worker(address=worker, close=False) 5848 for worker in missing_workers 5849 ) 5850 ) 5851 recommendations: dict 5852 client_msgs: dict = {} 5853 worker_msgs: dict = {} 5854 for key, workers in missing_keys.items(): 5855 # Task may already be gone if it was held by a 5856 # `missing_worker` 5857 ts: TaskState = parent._tasks.get(key) 5858 logger.exception( 5859 "Workers don't have promised key: %s, %s", 5860 str(workers), 5861 str(key), 5862 ) 5863 if not workers or ts is None: 5864 continue 5865 recommendations: dict = {key: "released"} 5866 for worker in workers: 5867 ws = parent._workers_dv.get(worker) 5868 if ws is not None and ws in ts._who_has: 5869 parent.remove_replica(ts, ws) 5870 parent._transitions( 5871 recommendations, client_msgs, worker_msgs 5872 ) 5873 self.send_all(client_msgs, worker_msgs) 5874 5875 self.log_event("all", {"action": "gather", "count": len(keys)}) 5876 return result 5877 5878 def clear_task_state(self): 5879 # XXX what about nested state such as ClientState.wants_what 5880 # (see also fire-and-forget...) 5881 logger.info("Clear task state") 5882 for collection in self._task_state_collections: 5883 collection.clear() 5884 5885 async def restart(self, client=None, timeout=30): 5886 """Restart all workers. Reset local state.""" 5887 parent: SchedulerState = cast(SchedulerState, self) 5888 with log_errors(): 5889 5890 n_workers = len(parent._workers_dv) 5891 5892 logger.info("Send lost future signal to clients") 5893 cs: ClientState 5894 ts: TaskState 5895 for cs in parent._clients.values(): 5896 self.client_releases_keys( 5897 keys=[ts._key for ts in cs._wants_what], client=cs._client_key 5898 ) 5899 5900 ws: WorkerState 5901 nannies = {addr: ws._nanny for addr, ws in parent._workers_dv.items()} 5902 5903 for addr in list(parent._workers_dv): 5904 try: 5905 # Ask the worker to close if it doesn't have a nanny, 5906 # otherwise the nanny will kill it anyway 5907 await self.remove_worker(address=addr, close=addr not in nannies) 5908 except Exception: 5909 logger.info( 5910 "Exception while restarting. This is normal", exc_info=True 5911 ) 5912 5913 self.clear_task_state() 5914 5915 for plugin in list(self.plugins.values()): 5916 try: 5917 plugin.restart(self) 5918 except Exception as e: 5919 logger.exception(e) 5920 5921 logger.debug("Send kill signal to nannies: %s", nannies) 5922 5923 nannies = [ 5924 rpc(nanny_address, connection_args=self.connection_args) 5925 for nanny_address in nannies.values() 5926 if nanny_address is not None 5927 ] 5928 5929 resps = All( 5930 [ 5931 nanny.restart( 5932 close=True, timeout=timeout * 0.8, executor_wait=False 5933 ) 5934 for nanny in nannies 5935 ] 5936 ) 5937 try: 5938 resps = await asyncio.wait_for(resps, timeout) 5939 except TimeoutError: 5940 logger.error( 5941 "Nannies didn't report back restarted within " 5942 "timeout. 
Continuuing with restart process" 5943 ) 5944 else: 5945 if not all(resp == "OK" for resp in resps): 5946 logger.error( 5947 "Not all workers responded positively: %s", resps, exc_info=True 5948 ) 5949 finally: 5950 await asyncio.gather(*[nanny.close_rpc() for nanny in nannies]) 5951 5952 self.clear_task_state() 5953 5954 with suppress(AttributeError): 5955 for c in self._worker_coroutines: 5956 c.cancel() 5957 5958 self.log_event([client, "all"], {"action": "restart", "client": client}) 5959 start = time() 5960 while time() < start + 10 and len(parent._workers_dv) < n_workers: 5961 await asyncio.sleep(0.01) 5962 5963 self.report({"op": "restart"}) 5964 5965 async def broadcast( 5966 self, 5967 comm=None, 5968 msg=None, 5969 workers=None, 5970 hosts=None, 5971 nanny=False, 5972 serializers=None, 5973 ): 5974 """Broadcast message to workers, return all results""" 5975 parent: SchedulerState = cast(SchedulerState, self) 5976 if workers is None or workers is True: 5977 if hosts is None: 5978 workers = list(parent._workers_dv) 5979 else: 5980 workers = [] 5981 if hosts is not None: 5982 for host in hosts: 5983 dh: dict = parent._host_info.get(host) 5984 if dh is not None: 5985 workers.extend(dh["addresses"]) 5986 # TODO replace with worker_list 5987 5988 if nanny: 5989 addresses = [parent._workers_dv[w].nanny for w in workers] 5990 else: 5991 addresses = workers 5992 5993 async def send_message(addr): 5994 comm = await self.rpc.connect(addr) 5995 comm.name = "Scheduler Broadcast" 5996 try: 5997 resp = await send_recv(comm, close=True, serializers=serializers, **msg) 5998 finally: 5999 self.rpc.reuse(addr, comm) 6000 return resp 6001 6002 results = await All( 6003 [send_message(address) for address in addresses if address is not None] 6004 ) 6005 6006 return dict(zip(workers, results)) 6007 6008 async def proxy(self, comm=None, msg=None, worker=None, serializers=None): 6009 """Proxy a communication through the scheduler to some other worker""" 6010 d = await self.broadcast( 6011 comm=comm, msg=msg, workers=[worker], serializers=serializers 6012 ) 6013 return d[worker] 6014 6015 async def gather_on_worker( 6016 self, worker_address: str, who_has: "dict[str, list[str]]" 6017 ) -> set: 6018 """Peer-to-peer copy of keys from multiple workers to a single worker 6019 6020 Parameters 6021 ---------- 6022 worker_address: str 6023 Recipient worker address to copy keys to 6024 who_has: dict[Hashable, list[str]] 6025 {key: [sender address, sender address, ...], key: ...} 6026 6027 Returns 6028 ------- 6029 returns: 6030 set of keys that failed to be copied 6031 """ 6032 try: 6033 result = await retry_operation( 6034 self.rpc(addr=worker_address).gather, who_has=who_has 6035 ) 6036 except OSError as e: 6037 # This can happen e.g. 
if the worker is going through controlled shutdown; 6038 # it doesn't necessarily mean that it went unexpectedly missing 6039 logger.warning( 6040 f"Communication with worker {worker_address} failed during " 6041 f"replication: {e.__class__.__name__}: {e}" 6042 ) 6043 return set(who_has) 6044 6045 parent: SchedulerState = cast(SchedulerState, self) 6046 ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore 6047 6048 if ws is None: 6049 logger.warning(f"Worker {worker_address} lost during replication") 6050 return set(who_has) 6051 elif result["status"] == "OK": 6052 keys_failed = set() 6053 keys_ok: Set = who_has.keys() 6054 elif result["status"] == "partial-fail": 6055 keys_failed = set(result["keys"]) 6056 keys_ok = who_has.keys() - keys_failed 6057 logger.warning( 6058 f"Worker {worker_address} failed to acquire keys: {result['keys']}" 6059 ) 6060 else: # pragma: nocover 6061 raise ValueError(f"Unexpected message from {worker_address}: {result}") 6062 6063 for key in keys_ok: 6064 ts: TaskState = parent._tasks.get(key) # type: ignore 6065 if ts is None or ts._state != "memory": 6066 logger.warning(f"Key lost during replication: {key}") 6067 continue 6068 if ws not in ts._who_has: 6069 parent.add_replica(ts, ws) 6070 6071 return keys_failed 6072 6073 async def delete_worker_data( 6074 self, worker_address: str, keys: "Collection[str]" 6075 ) -> None: 6076 """Delete data from a worker and update the corresponding worker/task states 6077 6078 Parameters 6079 ---------- 6080 worker_address: str 6081 Worker address to delete keys from 6082 keys: list[str] 6083 List of keys to delete on the specified worker 6084 """ 6085 parent: SchedulerState = cast(SchedulerState, self) 6086 6087 try: 6088 await retry_operation( 6089 self.rpc(addr=worker_address).free_keys, 6090 keys=list(keys), 6091 stimulus_id=f"delete-data-{time()}", 6092 ) 6093 except OSError as e: 6094 # This can happen e.g. if the worker is going through controlled shutdown; 6095 # it doesn't necessarily mean that it went unexpectedly missing 6096 logger.warning( 6097 f"Communication with worker {worker_address} failed during " 6098 f"replication: {e.__class__.__name__}: {e}" 6099 ) 6100 return 6101 6102 ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore 6103 if ws is None: 6104 return 6105 6106 for key in keys: 6107 ts: TaskState = parent._tasks.get(key) # type: ignore 6108 if ts is not None and ws in ts._who_has: 6109 assert ts._state == "memory" 6110 parent.remove_replica(ts, ws) 6111 if not ts._who_has: 6112 # Last copy deleted 6113 self.transitions({key: "released"}) 6114 6115 self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) 6116 6117 async def rebalance( 6118 self, 6119 comm=None, 6120 keys: "Iterable[Hashable]" = None, 6121 workers: "Iterable[str]" = None, 6122 ) -> dict: 6123 """Rebalance keys so that each worker ends up with roughly the same process 6124 memory (managed+unmanaged). 6125 6126 .. warning:: 6127 This operation is generally not well tested against normal operation of the 6128 scheduler. It is not recommended to use it while waiting on computations. 6129 6130 **Algorithm** 6131 6132 #. Find the mean occupancy of the cluster, defined as data managed by dask + 6133 unmanaged process memory that has been there for at least 30 seconds 6134 (``distributed.worker.memory.recent-to-old-time``). 6135 This lets us ignore temporary spikes caused by task heap usage. 
6136 6137 Alternatively, you may change how memory is measured both for the individual 6138 workers as well as to calculate the mean through 6139 ``distributed.worker.memory.rebalance.measure``. Namely, this can be useful 6140 to disregard inaccurate OS memory measurements. 6141 6142 #. Discard workers whose occupancy is within 5% of the mean cluster occupancy 6143 (``distributed.worker.memory.rebalance.sender-recipient-gap`` / 2). 6144 This helps avoid data bouncing around the cluster repeatedly. 6145 #. Workers above the mean are senders; those below are recipients. 6146 #. Discard senders whose absolute occupancy is below 30% 6147 (``distributed.worker.memory.rebalance.sender-min``). In other words, no data 6148 is moved, regardless of the imbalance, as long as all workers are below 30%. 6149 #. Discard recipients whose absolute occupancy is above 60% 6150 (``distributed.worker.memory.rebalance.recipient-max``). 6151 Note that this threshold by default is the same as 6152 ``distributed.worker.memory.target`` to prevent workers from accepting data 6153 and immediately spilling it out to disk. 6154 #. Iteratively pick the sender and recipient that are farthest from the mean and 6155 move the *least recently inserted* key between the two, until either all 6156 senders or all recipients fall within 5% of the mean. 6157 6158 A recipient will be skipped if it already has a copy of the data. In other 6159 words, this method does not degrade replication. 6160 A key will be skipped if there are no recipients available with enough memory 6161 to accept the key and that don't already hold a copy. 6162 6163 The least recently inserted (LRI) policy is a greedy choice with the advantage of 6164 being O(1), trivial to implement (it relies on python dict insertion ordering) 6165 and hopefully good enough in most cases. Discarded alternative policies were: 6166 6167 - Largest first. O(n*log(n)), requires non-trivial additional data structures, and 6168 risks causing the largest chunks of data to repeatedly move around the 6169 cluster like pinballs. 6170 - Least recently used (LRU). This information is currently available on the 6171 workers only and not trivial to replicate on the scheduler; transmitting it 6172 over the network would be very expensive. Also, note that dask will go out of 6173 its way to minimise the amount of time intermediate keys are held in memory, 6174 so in such a case LRI is a close approximation of LRU. 6175 6176 Parameters 6177 ---------- 6178 keys: optional 6179 whitelist of dask keys that should be considered for moving. All other keys 6180 will be ignored. Note that this offers no guarantee that a key will actually 6181 be moved (e.g. because it is unnecessary or because there are no viable 6182 recipient workers for it). 6183 workers: optional 6184 whitelist of worker addresses to be considered as senders or recipients. 6185 All other workers will be ignored. The mean cluster occupancy will be 6186 calculated only using the whitelisted workers.
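        The sketch below is purely illustrative and is not the code the scheduler
        runs (see ``_rebalance_find_msgs``); it only shows how the thresholds
        documented above combine to classify a single worker, assuming the default
        configuration values and ignoring the check that a sender actually holds
        data::

            def classify(memory: int, mean: int, limit: int) -> str:
                # memory/mean are process-memory readings in bytes; limit is the
                # worker's memory limit in bytes. Fractions are the documented
                # defaults: 10% sender-recipient gap, 30% sender-min, 60%
                # recipient-max.
                half_gap = 0.05 * limit
                sender_min = 0.30 * limit
                recipient_max = 0.60 * limit
                if memory >= mean + half_gap and memory >= sender_min:
                    return "sender"
                if memory < mean - half_gap and memory < recipient_max:
                    return "recipient"
                return "ignored"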
6187 """ 6188 parent: SchedulerState = cast(SchedulerState, self) 6189 6190 with log_errors(): 6191 wss: "Collection[WorkerState]" 6192 if workers is not None: 6193 wss = [parent._workers_dv[w] for w in workers] 6194 else: 6195 wss = parent._workers_dv.values() 6196 if not wss: 6197 return {"status": "OK"} 6198 6199 if keys is not None: 6200 if not isinstance(keys, Set): 6201 keys = set(keys) # unless already a set-like 6202 if not keys: 6203 return {"status": "OK"} 6204 missing_data = [ 6205 k 6206 for k in keys 6207 if k not in parent._tasks or not parent._tasks[k].who_has 6208 ] 6209 if missing_data: 6210 return {"status": "partial-fail", "keys": missing_data} 6211 6212 msgs = self._rebalance_find_msgs(keys, wss) 6213 if not msgs: 6214 return {"status": "OK"} 6215 6216 async with self._lock: 6217 result = await self._rebalance_move_data(msgs) 6218 if result["status"] == "partial-fail" and keys is None: 6219 # Only return failed keys if the client explicitly asked for them 6220 result = {"status": "OK"} 6221 return result 6222 6223 def _rebalance_find_msgs( 6224 self, 6225 keys: "Set[Hashable] | None", 6226 workers: "Iterable[WorkerState]", 6227 ) -> "list[tuple[WorkerState, WorkerState, TaskState]]": 6228 """Identify workers that need to lose keys and those that can receive them, 6229 together with how many bytes each needs to lose/receive. Then, pair a sender 6230 worker with a recipient worker for each key, until the cluster is rebalanced. 6231 6232 This method only defines the work to be performed; it does not start any network 6233 transfers itself. 6234 6235 The big-O complexity is O(wt + ke*log(we)), where 6236 6237 - wt is the total number of workers on the cluster (or the number of whitelisted 6238 workers, if explicitly stated by the user) 6239 - we is the number of workers that are eligible to be senders or recipients 6240 - kt is the total number of keys on the cluster (or on the whitelisted workers) 6241 - ke is the number of keys that need to be moved in order to achieve a balanced 6242 cluster 6243 6244 There is a degenerate edge case O(wt + kt*log(we)) when kt is much greater than 6245 the number of whitelisted keys, or when most keys are replicated or cannot be 6246 moved for some other reason. 6247 6248 Returns list of tuples to feed into _rebalance_move_data: 6249 6250 - sender worker 6251 - recipient worker 6252 - task to be transferred 6253 """ 6254 parent: SchedulerState = cast(SchedulerState, self) 6255 ts: TaskState 6256 ws: WorkerState 6257 6258 # Heaps of workers, managed by the heapq module, that need to send/receive data, 6259 # with how many bytes each needs to send/receive. 6260 # 6261 # Each element of the heap is a tuple constructed as follows: 6262 # - snd_bytes_max/rec_bytes_max: maximum number of bytes to send or receive. 6263 # This number is negative, so that the workers farthest from the cluster mean 6264 # are at the top of the smallest-first heaps. 6265 # - snd_bytes_min/rec_bytes_min: minimum number of bytes after sending/receiving 6266 # which the worker should not be considered anymore. This is also negative. 6267 # - arbitrary unique number, there just to to make sure that WorkerState objects 6268 # are never used for sorting in the unlikely event that two processes have 6269 # exactly the same number of bytes allocated. 6270 # - WorkerState 6271 # - iterator of all tasks in memory on the worker (senders only), insertion 6272 # sorted (least recently inserted first). 6273 # Note that this iterator will typically *not* be exhausted. 
It will only be 6274 # exhausted if, after moving away from the worker all keys that can be moved, 6275 # is insufficient to drop snd_bytes_min above 0. 6276 senders: "list[tuple[int, int, int, WorkerState, Iterator[TaskState]]]" = [] 6277 recipients: "list[tuple[int, int, int, WorkerState]]" = [] 6278 6279 # Output: [(sender, recipient, task), ...] 6280 msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" = [] 6281 6282 # By default, this is the optimistic memory, meaning total process memory minus 6283 # unmanaged memory that appeared over the last 30 seconds 6284 # (distributed.worker.memory.recent-to-old-time). 6285 # This lets us ignore temporary spikes caused by task heap usage. 6286 memory_by_worker = [ 6287 (ws, getattr(ws.memory, parent.MEMORY_REBALANCE_MEASURE)) for ws in workers 6288 ] 6289 mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) 6290 6291 for ws, ws_memory in memory_by_worker: 6292 if ws.memory_limit: 6293 half_gap = int(parent.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) 6294 sender_min = parent.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit 6295 recipient_max = parent.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit 6296 else: 6297 half_gap = 0 6298 sender_min = 0.0 6299 recipient_max = math.inf 6300 6301 if ( 6302 ws._has_what 6303 and ws_memory >= mean_memory + half_gap 6304 and ws_memory >= sender_min 6305 ): 6306 # This may send the worker below sender_min (by design) 6307 snd_bytes_max = mean_memory - ws_memory # negative 6308 snd_bytes_min = snd_bytes_max + half_gap # negative 6309 # See definition of senders above 6310 senders.append( 6311 (snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what)) 6312 ) 6313 elif ws_memory < mean_memory - half_gap and ws_memory < recipient_max: 6314 # This may send the worker above recipient_max (by design) 6315 rec_bytes_max = ws_memory - mean_memory # negative 6316 rec_bytes_min = rec_bytes_max + half_gap # negative 6317 # See definition of recipients above 6318 recipients.append((rec_bytes_max, rec_bytes_min, id(ws), ws)) 6319 6320 # Fast exit in case no transfers are necessary or possible 6321 if not senders or not recipients: 6322 self.log_event( 6323 "all", 6324 { 6325 "action": "rebalance", 6326 "senders": len(senders), 6327 "recipients": len(recipients), 6328 "moved_keys": 0, 6329 }, 6330 ) 6331 return [] 6332 6333 heapq.heapify(senders) 6334 heapq.heapify(recipients) 6335 6336 snd_ws: WorkerState 6337 rec_ws: WorkerState 6338 6339 while senders and recipients: 6340 snd_bytes_max, snd_bytes_min, _, snd_ws, ts_iter = senders[0] 6341 6342 # Iterate through tasks in memory, least recently inserted first 6343 for ts in ts_iter: 6344 if keys is not None and ts.key not in keys: 6345 continue 6346 nbytes = ts.nbytes 6347 if nbytes + snd_bytes_max > 0: 6348 # Moving this task would cause the sender to go below mean and 6349 # potentially risk becoming a recipient, which would cause tasks to 6350 # bounce around. Move on to the next task of the same sender. 6351 continue 6352 6353 # Find the recipient, farthest from the mean, which 6354 # 1. has enough available RAM for this task, and 6355 # 2. doesn't hold a copy of this task already 6356 # There may not be any that satisfies these conditions; in this case 6357 # this task won't be moved. 6358 skipped_recipients = [] 6359 use_recipient = False 6360 while recipients and not use_recipient: 6361 rec_bytes_max, rec_bytes_min, _, rec_ws = recipients[0] 6362 if nbytes + rec_bytes_max > 0: 6363 # recipients are sorted by rec_bytes_max. 
6364 # The next ones will be worse; no reason to continue iterating 6365 break 6366 use_recipient = ts not in rec_ws._has_what 6367 if not use_recipient: 6368 skipped_recipients.append(heapq.heappop(recipients)) 6369 6370 for recipient in skipped_recipients: 6371 heapq.heappush(recipients, recipient) 6372 6373 if not use_recipient: 6374 # This task has no recipients available. Leave it on the sender and 6375 # move on to the next task of the same sender. 6376 continue 6377 6378 # Schedule task for transfer from sender to recipient 6379 msgs.append((snd_ws, rec_ws, ts)) 6380 6381 # *_bytes_max/min are all negative for heap sorting 6382 snd_bytes_max += nbytes 6383 snd_bytes_min += nbytes 6384 rec_bytes_max += nbytes 6385 rec_bytes_min += nbytes 6386 6387 # Stop iterating on the tasks of this sender for now and, if it still 6388 # has bytes to lose, push it back into the senders heap; it may or may 6389 # not come back on top again. 6390 if snd_bytes_min < 0: 6391 # See definition of senders above 6392 heapq.heapreplace( 6393 senders, 6394 (snd_bytes_max, snd_bytes_min, id(snd_ws), snd_ws, ts_iter), 6395 ) 6396 else: 6397 heapq.heappop(senders) 6398 6399 # If recipient still has bytes to gain, push it back into the recipients 6400 # heap; it may or may not come back on top again. 6401 if rec_bytes_min < 0: 6402 # See definition of recipients above 6403 heapq.heapreplace( 6404 recipients, 6405 (rec_bytes_max, rec_bytes_min, id(rec_ws), rec_ws), 6406 ) 6407 else: 6408 heapq.heappop(recipients) 6409 6410 # Move to next sender with the most data to lose. 6411 # It may or may not be the same sender again. 6412 break 6413 6414 else: # for ts in ts_iter 6415 # Exhausted tasks on this sender 6416 heapq.heappop(senders) 6417 6418 return msgs 6419 6420 async def _rebalance_move_data( 6421 self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" 6422 ) -> dict: 6423 """Perform the actual transfer of data across the network in rebalance(). 6424 Takes in input the output of _rebalance_find_msgs(), that is a list of tuples: 6425 6426 - sender worker 6427 - recipient worker 6428 - task to be transferred 6429 6430 FIXME this method is not robust when the cluster is not idle. 
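        As a schematic illustration only (addresses and keys below are made up,
        and the real tuples contain ``WorkerState``/``TaskState`` objects rather
        than strings), the input is first grouped by recipient into a ``who_has``
        mapping for ``gather_on_worker``; the keys that succeed are then grouped
        by sender for ``delete_worker_data``::

            from collections import defaultdict

            msgs = [("tcp://s1", "tcp://r1", "x"), ("tcp://s2", "tcp://r1", "y")]
            to_recipients = defaultdict(lambda: defaultdict(list))
            for sender, recipient, key in msgs:
                to_recipients[recipient][key].append(sender)
            # -> {"tcp://r1": {"x": ["tcp://s1"], "y": ["tcp://s2"]}}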
6431 """ 6432 snd_ws: WorkerState 6433 rec_ws: WorkerState 6434 ts: TaskState 6435 6436 to_recipients: "defaultdict[str, defaultdict[str, list[str]]]" = defaultdict( 6437 lambda: defaultdict(list) 6438 ) 6439 for snd_ws, rec_ws, ts in msgs: 6440 to_recipients[rec_ws.address][ts._key].append(snd_ws.address) 6441 failed_keys_by_recipient = dict( 6442 zip( 6443 to_recipients, 6444 await asyncio.gather( 6445 *( 6446 # Note: this never raises exceptions 6447 self.gather_on_worker(w, who_has) 6448 for w, who_has in to_recipients.items() 6449 ) 6450 ), 6451 ) 6452 ) 6453 6454 to_senders = defaultdict(list) 6455 for snd_ws, rec_ws, ts in msgs: 6456 if ts._key not in failed_keys_by_recipient[rec_ws.address]: 6457 to_senders[snd_ws.address].append(ts._key) 6458 6459 # Note: this never raises exceptions 6460 await asyncio.gather( 6461 *(self.delete_worker_data(r, v) for r, v in to_senders.items()) 6462 ) 6463 6464 for r, v in to_recipients.items(): 6465 self.log_event(r, {"action": "rebalance", "who_has": v}) 6466 self.log_event( 6467 "all", 6468 { 6469 "action": "rebalance", 6470 "senders": valmap(len, to_senders), 6471 "recipients": valmap(len, to_recipients), 6472 "moved_keys": len(msgs), 6473 }, 6474 ) 6475 6476 missing_keys = {k for r in failed_keys_by_recipient.values() for k in r} 6477 if missing_keys: 6478 return {"status": "partial-fail", "keys": list(missing_keys)} 6479 else: 6480 return {"status": "OK"} 6481 6482 async def replicate( 6483 self, 6484 comm=None, 6485 keys=None, 6486 n=None, 6487 workers=None, 6488 branching_factor=2, 6489 delete=True, 6490 lock=True, 6491 ): 6492 """Replicate data throughout cluster 6493 6494 This performs a tree copy of the data throughout the network 6495 individually on each piece of data. 6496 6497 Parameters 6498 ---------- 6499 keys: Iterable 6500 list of keys to replicate 6501 n: int 6502 Number of replications we expect to see within the cluster 6503 branching_factor: int, optional 6504 The number of workers that can copy data in each generation. 6505 The larger the branching factor, the more data we copy in 6506 a single step, but the more a given worker risks being 6507 swamped by data requests. 
6508 6509 See also 6510 -------- 6511 Scheduler.rebalance 6512 """ 6513 parent: SchedulerState = cast(SchedulerState, self) 6514 ws: WorkerState 6515 wws: WorkerState 6516 ts: TaskState 6517 6518 assert branching_factor > 0 6519 async with self._lock if lock else empty_context: 6520 if workers is not None: 6521 workers = {parent._workers_dv[w] for w in self.workers_list(workers)} 6522 workers = {ws for ws in workers if ws._status == Status.running} 6523 else: 6524 workers = parent._running 6525 6526 if n is None: 6527 n = len(workers) 6528 else: 6529 n = min(n, len(workers)) 6530 if n == 0: 6531 raise ValueError("Can not use replicate to delete data") 6532 6533 tasks = {parent._tasks[k] for k in keys} 6534 missing_data = [ts._key for ts in tasks if not ts._who_has] 6535 if missing_data: 6536 return {"status": "partial-fail", "keys": missing_data} 6537 6538 # Delete extraneous data 6539 if delete: 6540 del_worker_tasks = defaultdict(set) 6541 for ts in tasks: 6542 del_candidates = tuple(ts._who_has & workers) 6543 if len(del_candidates) > n: 6544 for ws in random.sample( 6545 del_candidates, len(del_candidates) - n 6546 ): 6547 del_worker_tasks[ws].add(ts) 6548 6549 # Note: this never raises exceptions 6550 await asyncio.gather( 6551 *[ 6552 self.delete_worker_data(ws._address, [t.key for t in tasks]) 6553 for ws, tasks in del_worker_tasks.items() 6554 ] 6555 ) 6556 6557 # Copy not-yet-filled data 6558 while tasks: 6559 gathers = defaultdict(dict) 6560 for ts in list(tasks): 6561 if ts._state == "forgotten": 6562 # task is no longer needed by any client or dependant task 6563 tasks.remove(ts) 6564 continue 6565 n_missing = n - len(ts._who_has & workers) 6566 if n_missing <= 0: 6567 # Already replicated enough 6568 tasks.remove(ts) 6569 continue 6570 6571 count = min(n_missing, branching_factor * len(ts._who_has)) 6572 assert count > 0 6573 6574 for ws in random.sample(tuple(workers - ts._who_has), count): 6575 gathers[ws._address][ts._key] = [ 6576 wws._address for wws in ts._who_has 6577 ] 6578 6579 await asyncio.gather( 6580 *( 6581 # Note: this never raises exceptions 6582 self.gather_on_worker(w, who_has) 6583 for w, who_has in gathers.items() 6584 ) 6585 ) 6586 for r, v in gathers.items(): 6587 self.log_event(r, {"action": "replicate-add", "who_has": v}) 6588 6589 self.log_event( 6590 "all", 6591 { 6592 "action": "replicate", 6593 "workers": list(workers), 6594 "key-count": len(keys), 6595 "branching-factor": branching_factor, 6596 }, 6597 ) 6598 6599 def workers_to_close( 6600 self, 6601 comm=None, 6602 memory_ratio: "int | float | None" = None, 6603 n: "int | None" = None, 6604 key: "Callable[[WorkerState], Hashable] | None" = None, 6605 minimum: "int | None" = None, 6606 target: "int | None" = None, 6607 attribute: str = "address", 6608 ) -> "list[str]": 6609 """ 6610 Find workers that we can close with low cost 6611 6612 This returns a list of workers that are good candidates to retire. 6613 These workers are not running anything and are storing 6614 relatively little data relative to their peers. If all workers are 6615 idle then we still maintain enough workers to have enough RAM to store 6616 our data, with a comfortable buffer. 6617 6618 This is for use with systems like ``distributed.deploy.adaptive``. 6619 6620 Parameters 6621 ---------- 6622 memory_ratio : Number 6623 Amount of extra space we want to have for our stored data. 6624 Defaults to 2, or that we want to have twice as much memory as we 6625 currently have data. 
6626 n : int 6627 Number of workers to close 6628 minimum : int 6629 Minimum number of workers to keep around 6630 key : Callable(WorkerState) 6631 An optional callable mapping a WorkerState object to a group 6632 affiliation. Groups will be closed together. This is useful when 6633 closing workers must be done collectively, such as by hostname. 6634 target : int 6635 Target number of workers to have after we close 6636 attribute : str 6637 The attribute of the WorkerState object to return, like "address" 6638 or "name". Defaults to "address". 6639 6640 Examples 6641 -------- 6642 >>> scheduler.workers_to_close() 6643 ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234'] 6644 6645 Group workers by hostname prior to closing 6646 6647 >>> scheduler.workers_to_close(key=lambda ws: ws.host) 6648 ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567'] 6649 6650 Remove two workers 6651 6652 >>> scheduler.workers_to_close(n=2) 6653 6654 Keep enough workers to have twice as much memory as we we need. 6655 6656 >>> scheduler.workers_to_close(memory_ratio=2) 6657 6658 Returns 6659 ------- 6660 to_close: list of worker addresses that are OK to close 6661 6662 See Also 6663 -------- 6664 Scheduler.retire_workers 6665 """ 6666 parent: SchedulerState = cast(SchedulerState, self) 6667 if target is not None and n is None: 6668 n = len(parent._workers_dv) - target 6669 if n is not None: 6670 if n < 0: 6671 n = 0 6672 target = len(parent._workers_dv) - n 6673 6674 if n is None and memory_ratio is None: 6675 memory_ratio = 2 6676 6677 ws: WorkerState 6678 with log_errors(): 6679 if not n and all([ws._processing for ws in parent._workers_dv.values()]): 6680 return [] 6681 6682 if key is None: 6683 key = operator.attrgetter("address") 6684 if isinstance(key, bytes) and dask.config.get( 6685 "distributed.scheduler.pickle" 6686 ): 6687 key = pickle.loads(key) 6688 6689 groups = groupby(key, parent._workers.values()) 6690 6691 limit_bytes = { 6692 k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() 6693 } 6694 group_bytes = {k: sum([ws._nbytes for ws in v]) for k, v in groups.items()} 6695 6696 limit = sum(limit_bytes.values()) 6697 total = sum(group_bytes.values()) 6698 6699 def _key(group): 6700 wws: WorkerState 6701 is_idle = not any([wws._processing for wws in groups[group]]) 6702 bytes = -group_bytes[group] 6703 return (is_idle, bytes) 6704 6705 idle = sorted(groups, key=_key) 6706 6707 to_close = [] 6708 n_remain = len(parent._workers_dv) 6709 6710 while idle: 6711 group = idle.pop() 6712 if n is None and any([ws._processing for ws in groups[group]]): 6713 break 6714 6715 if minimum and n_remain - len(groups[group]) < minimum: 6716 break 6717 6718 limit -= limit_bytes[group] 6719 6720 if ( 6721 n is not None and n_remain - len(groups[group]) >= cast(int, target) 6722 ) or (memory_ratio is not None and limit >= memory_ratio * total): 6723 to_close.append(group) 6724 n_remain -= len(groups[group]) 6725 6726 else: 6727 break 6728 6729 result = [getattr(ws, attribute) for g in to_close for ws in groups[g]] 6730 if result: 6731 logger.debug("Suggest closing workers: %s", result) 6732 6733 return result 6734 6735 async def retire_workers( 6736 self, 6737 comm=None, 6738 workers=None, 6739 remove=True, 6740 close_workers=False, 6741 names=None, 6742 lock=True, 6743 **kwargs, 6744 ) -> dict: 6745 """Gracefully retire workers from cluster 6746 6747 Parameters 6748 ---------- 6749 workers: list (optional) 6750 List of worker addresses to retire. 
6751 If not provided we call ``workers_to_close`` which finds a good set 6752 names: list (optional) 6753 List of worker names to retire. 6754 remove: bool (defaults to True) 6755 Whether or not to remove the worker metadata immediately or else 6756 wait for the worker to contact us 6757 close_workers: bool (defaults to False) 6758 Whether or not to actually close the worker explicitly from here. 6759 Otherwise we expect some external job scheduler to finish off the 6760 worker. 6761 **kwargs: dict 6762 Extra options to pass to workers_to_close to determine which 6763 workers we should drop 6764 6765 Returns 6766 ------- 6767 Dictionary mapping worker ID/address to dictionary of information about 6768 that worker for each retired worker. 6769 6770 See Also 6771 -------- 6772 Scheduler.workers_to_close 6773 """ 6774 parent: SchedulerState = cast(SchedulerState, self) 6775 ws: WorkerState 6776 ts: TaskState 6777 with log_errors(): 6778 async with self._lock if lock else empty_context: 6779 if names is not None: 6780 if workers is not None: 6781 raise TypeError("names and workers are mutually exclusive") 6782 if names: 6783 logger.info("Retire worker names %s", names) 6784 names = set(map(str, names)) 6785 workers = { 6786 ws._address 6787 for ws in parent._workers_dv.values() 6788 if str(ws._name) in names 6789 } 6790 elif workers is None: 6791 while True: 6792 try: 6793 workers = self.workers_to_close(**kwargs) 6794 if not workers: 6795 return {} 6796 return await self.retire_workers( 6797 workers=workers, 6798 remove=remove, 6799 close_workers=close_workers, 6800 lock=False, 6801 ) 6802 except KeyError: # keys left during replicate 6803 pass 6804 6805 workers = { 6806 parent._workers_dv[w] for w in workers if w in parent._workers_dv 6807 } 6808 if not workers: 6809 return {} 6810 logger.info("Retire workers %s", workers) 6811 6812 # Keys orphaned by retiring those workers 6813 keys = {k for w in workers for k in w.has_what} 6814 keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} 6815 6816 if keys: 6817 other_workers = set(parent._workers_dv.values()) - workers 6818 if not other_workers: 6819 return {} 6820 logger.info("Moving %d keys to other workers", len(keys)) 6821 await self.replicate( 6822 keys=keys, 6823 workers=[ws._address for ws in other_workers], 6824 n=1, 6825 delete=False, 6826 lock=False, 6827 ) 6828 6829 worker_keys = {ws._address: ws.identity() for ws in workers} 6830 if close_workers: 6831 await asyncio.gather( 6832 *[self.close_worker(worker=w, safe=True) for w in worker_keys] 6833 ) 6834 if remove: 6835 await asyncio.gather( 6836 *[self.remove_worker(address=w, safe=True) for w in worker_keys] 6837 ) 6838 6839 self.log_event( 6840 "all", 6841 { 6842 "action": "retire-workers", 6843 "workers": worker_keys, 6844 "moved-keys": len(keys), 6845 }, 6846 ) 6847 self.log_event(list(worker_keys), {"action": "retired"}) 6848 6849 return worker_keys 6850 6851 def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None): 6852 """ 6853 Learn that a worker has certain keys 6854 6855 This should not be used in practice and is mostly here for legacy 6856 reasons. However, it is sent by workers from time to time. 
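        For reference, when a reported key is not (or no longer) in the ``memory``
        state on the scheduler, the worker is asked to drop its copy with a message
        shaped roughly like the one below (mirroring the body of this method; the
        keys and ``stimulus_id`` are illustrative)::

            msg = {
                "op": "remove-replicas",
                "keys": ["redundant-key-1", "redundant-key-2"],
                "stimulus_id": "redundant-replicas-1234567890.0",
            }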
6857 """ 6858 parent: SchedulerState = cast(SchedulerState, self) 6859 if worker not in parent._workers_dv: 6860 return "not found" 6861 ws: WorkerState = parent._workers_dv[worker] 6862 redundant_replicas = [] 6863 for key in keys: 6864 ts: TaskState = parent._tasks.get(key) 6865 if ts is not None and ts._state == "memory": 6866 if ws not in ts._who_has: 6867 parent.add_replica(ts, ws) 6868 else: 6869 redundant_replicas.append(key) 6870 6871 if redundant_replicas: 6872 if not stimulus_id: 6873 stimulus_id = f"redundant-replicas-{time()}" 6874 self.worker_send( 6875 worker, 6876 { 6877 "op": "remove-replicas", 6878 "keys": redundant_replicas, 6879 "stimulus_id": stimulus_id, 6880 }, 6881 ) 6882 6883 return "OK" 6884 6885 def update_data( 6886 self, 6887 comm=None, 6888 *, 6889 who_has: dict, 6890 nbytes: dict, 6891 client=None, 6892 serializers=None, 6893 ): 6894 """ 6895 Learn that new data has entered the network from an external source 6896 6897 See Also 6898 -------- 6899 Scheduler.mark_key_in_memory 6900 """ 6901 parent: SchedulerState = cast(SchedulerState, self) 6902 with log_errors(): 6903 who_has = { 6904 k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() 6905 } 6906 logger.debug("Update data %s", who_has) 6907 6908 for key, workers in who_has.items(): 6909 ts: TaskState = parent._tasks.get(key) # type: ignore 6910 if ts is None: 6911 ts = parent.new_task(key, None, "memory") 6912 ts.state = "memory" 6913 ts_nbytes = nbytes.get(key, -1) 6914 if ts_nbytes >= 0: 6915 ts.set_nbytes(ts_nbytes) 6916 6917 for w in workers: 6918 ws: WorkerState = parent._workers_dv[w] 6919 if ws not in ts._who_has: 6920 parent.add_replica(ts, ws) 6921 self.report( 6922 {"op": "key-in-memory", "key": key, "workers": list(workers)} 6923 ) 6924 6925 if client: 6926 self.client_desires_keys(keys=list(who_has), client=client) 6927 6928 def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): 6929 parent: SchedulerState = cast(SchedulerState, self) 6930 if ts is None: 6931 ts = parent._tasks.get(key) 6932 elif key is None: 6933 key = ts._key 6934 else: 6935 assert False, (key, ts) 6936 return 6937 6938 report_msg: dict 6939 if ts is None: 6940 report_msg = {"op": "cancelled-key", "key": key} 6941 else: 6942 report_msg = _task_to_report_msg(parent, ts) 6943 if report_msg is not None: 6944 self.report(report_msg, ts=ts, client=client) 6945 6946 async def feed( 6947 self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs 6948 ): 6949 """ 6950 Provides a data Comm to external requester 6951 6952 Caution: this runs arbitrary Python code on the scheduler. This should 6953 eventually be phased out. It is mostly used by diagnostics. 6954 """ 6955 if not dask.config.get("distributed.scheduler.pickle"): 6956 logger.warn( 6957 "Tried to call 'feed' route with custom functions, but " 6958 "pickle is disallowed. 
Set the 'distributed.scheduler.pickle'" 6959 "config value to True to use the 'feed' route (this is mostly " 6960 "commonly used with progress bars)" 6961 ) 6962 return 6963 6964 interval = parse_timedelta(interval) 6965 with log_errors(): 6966 if function: 6967 function = pickle.loads(function) 6968 if setup: 6969 setup = pickle.loads(setup) 6970 if teardown: 6971 teardown = pickle.loads(teardown) 6972 state = setup(self) if setup else None 6973 if inspect.isawaitable(state): 6974 state = await state 6975 try: 6976 while self.status == Status.running: 6977 if state is None: 6978 response = function(self) 6979 else: 6980 response = function(self, state) 6981 await comm.write(response) 6982 await asyncio.sleep(interval) 6983 except OSError: 6984 pass 6985 finally: 6986 if teardown: 6987 teardown(self, state) 6988 6989 def log_worker_event(self, worker=None, topic=None, msg=None): 6990 self.log_event(topic, msg) 6991 6992 def subscribe_worker_status(self, comm=None): 6993 WorkerStatusPlugin(self, comm) 6994 ident = self.identity() 6995 for v in ident["workers"].values(): 6996 del v["metrics"] 6997 del v["last_seen"] 6998 return ident 6999 7000 def get_processing(self, comm=None, workers=None): 7001 parent: SchedulerState = cast(SchedulerState, self) 7002 ws: WorkerState 7003 ts: TaskState 7004 if workers is not None: 7005 workers = set(map(self.coerce_address, workers)) 7006 return { 7007 w: [ts._key for ts in parent._workers_dv[w].processing] for w in workers 7008 } 7009 else: 7010 return { 7011 w: [ts._key for ts in ws._processing] 7012 for w, ws in parent._workers_dv.items() 7013 } 7014 7015 def get_who_has(self, comm=None, keys=None): 7016 parent: SchedulerState = cast(SchedulerState, self) 7017 ws: WorkerState 7018 ts: TaskState 7019 if keys is not None: 7020 return { 7021 k: [ws._address for ws in parent._tasks[k].who_has] 7022 if k in parent._tasks 7023 else [] 7024 for k in keys 7025 } 7026 else: 7027 return { 7028 key: [ws._address for ws in ts._who_has] 7029 for key, ts in parent._tasks.items() 7030 } 7031 7032 def get_has_what(self, comm=None, workers=None): 7033 parent: SchedulerState = cast(SchedulerState, self) 7034 ws: WorkerState 7035 ts: TaskState 7036 if workers is not None: 7037 workers = map(self.coerce_address, workers) 7038 return { 7039 w: [ts._key for ts in parent._workers_dv[w].has_what] 7040 if w in parent._workers_dv 7041 else [] 7042 for w in workers 7043 } 7044 else: 7045 return { 7046 w: [ts._key for ts in ws.has_what] 7047 for w, ws in parent._workers_dv.items() 7048 } 7049 7050 def get_ncores(self, comm=None, workers=None): 7051 parent: SchedulerState = cast(SchedulerState, self) 7052 ws: WorkerState 7053 if workers is not None: 7054 workers = map(self.coerce_address, workers) 7055 return { 7056 w: parent._workers_dv[w].nthreads 7057 for w in workers 7058 if w in parent._workers_dv 7059 } 7060 else: 7061 return {w: ws._nthreads for w, ws in parent._workers_dv.items()} 7062 7063 def get_ncores_running(self, comm=None, workers=None): 7064 parent: SchedulerState = cast(SchedulerState, self) 7065 ncores = self.get_ncores(workers=workers) 7066 return { 7067 w: n 7068 for w, n in ncores.items() 7069 if parent._workers_dv[w].status == Status.running 7070 } 7071 7072 async def get_call_stack(self, comm=None, keys=None): 7073 parent: SchedulerState = cast(SchedulerState, self) 7074 ts: TaskState 7075 dts: TaskState 7076 if keys is not None: 7077 stack = list(keys) 7078 processing = set() 7079 while stack: 7080 key = stack.pop() 7081 ts = parent._tasks[key] 7082 if 
ts._state == "waiting": 7083 stack.extend([dts._key for dts in ts._dependencies]) 7084 elif ts._state == "processing": 7085 processing.add(ts) 7086 7087 workers = defaultdict(list) 7088 for ts in processing: 7089 if ts._processing_on: 7090 workers[ts._processing_on.address].append(ts._key) 7091 else: 7092 workers = {w: None for w in parent._workers_dv} 7093 7094 if not workers: 7095 return {} 7096 7097 results = await asyncio.gather( 7098 *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) 7099 ) 7100 response = {w: r for w, r in zip(workers, results) if r} 7101 return response 7102 7103 def get_nbytes(self, comm=None, keys=None, summary=True): 7104 parent: SchedulerState = cast(SchedulerState, self) 7105 ts: TaskState 7106 with log_errors(): 7107 if keys is not None: 7108 result = {k: parent._tasks[k].nbytes for k in keys} 7109 else: 7110 result = { 7111 k: ts._nbytes for k, ts in parent._tasks.items() if ts._nbytes >= 0 7112 } 7113 7114 if summary: 7115 out = defaultdict(lambda: 0) 7116 for k, v in result.items(): 7117 out[key_split(k)] += v 7118 result = dict(out) 7119 7120 return result 7121 7122 def run_function(self, stream, function, args=(), kwargs={}, wait=True): 7123 """Run a function within this process 7124 7125 See Also 7126 -------- 7127 Client.run_on_scheduler 7128 """ 7129 from .worker import run 7130 7131 if not dask.config.get("distributed.scheduler.pickle"): 7132 raise ValueError( 7133 "Cannot run function as the scheduler has been explicitly disallowed from " 7134 "deserializing arbitrary bytestrings using pickle via the " 7135 "'distributed.scheduler.pickle' configuration setting." 7136 ) 7137 7138 self.log_event("all", {"action": "run-function", "function": function}) 7139 return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) 7140 7141 def set_metadata(self, comm=None, keys=None, value=None): 7142 parent: SchedulerState = cast(SchedulerState, self) 7143 metadata = parent._task_metadata 7144 for key in keys[:-1]: 7145 if key not in metadata or not isinstance(metadata[key], (dict, list)): 7146 metadata[key] = {} 7147 metadata = metadata[key] 7148 metadata[keys[-1]] = value 7149 7150 def get_metadata(self, comm=None, keys=None, default=no_default): 7151 parent: SchedulerState = cast(SchedulerState, self) 7152 metadata = parent._task_metadata 7153 for key in keys[:-1]: 7154 metadata = metadata[key] 7155 try: 7156 return metadata[keys[-1]] 7157 except KeyError: 7158 if default != no_default: 7159 return default 7160 else: 7161 raise 7162 7163 def set_restrictions(self, comm=None, worker=None): 7164 ts: TaskState 7165 for key, restrictions in worker.items(): 7166 ts = self.tasks[key] 7167 if isinstance(restrictions, str): 7168 restrictions = {restrictions} 7169 ts._worker_restrictions = set(restrictions) 7170 7171 def get_task_status(self, comm=None, keys=None): 7172 parent: SchedulerState = cast(SchedulerState, self) 7173 return { 7174 key: (parent._tasks[key].state if key in parent._tasks else None) 7175 for key in keys 7176 } 7177 7178 def get_task_stream(self, comm=None, start=None, stop=None, count=None): 7179 from distributed.diagnostics.task_stream import TaskStreamPlugin 7180 7181 if TaskStreamPlugin.name not in self.plugins: 7182 self.add_plugin(TaskStreamPlugin(self)) 7183 7184 plugin = self.plugins[TaskStreamPlugin.name] 7185 7186 return plugin.collect(start=start, stop=stop, count=count) 7187 7188 def start_task_metadata(self, comm=None, name=None): 7189 plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name) 7190 
self.add_plugin(plugin) 7191 7192 def stop_task_metadata(self, comm=None, name=None): 7193 plugins = [ 7194 p 7195 for p in list(self.plugins.values()) 7196 if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name 7197 ] 7198 if len(plugins) != 1: 7199 raise ValueError( 7200 "Expected to find exactly one CollectTaskMetaDataPlugin " 7201 f"with name {name} but found {len(plugins)}." 7202 ) 7203 7204 plugin = plugins[0] 7205 self.remove_plugin(name=plugin.name) 7206 return {"metadata": plugin.metadata, "state": plugin.state} 7207 7208 async def register_worker_plugin(self, comm, plugin, name=None): 7209 """Registers a setup function and calls it on every worker""" 7210 self.worker_plugins[name] = plugin 7211 7212 responses = await self.broadcast( 7213 msg=dict(op="plugin-add", plugin=plugin, name=name) 7214 ) 7215 return responses 7216 7217 async def unregister_worker_plugin(self, comm, name): 7218 """Unregisters a worker plugin""" 7219 try: 7220 self.worker_plugins.pop(name) 7221 except KeyError: 7222 raise ValueError(f"The worker plugin {name} does not exist") 7223 7224 responses = await self.broadcast(msg=dict(op="plugin-remove", name=name)) 7225 return responses 7226 7227 async def register_nanny_plugin(self, comm, plugin, name=None): 7228 """Registers a setup function and calls it on every nanny""" 7229 self.nanny_plugins[name] = plugin 7230 7231 responses = await self.broadcast( 7232 msg=dict(op="plugin_add", plugin=plugin, name=name), 7233 nanny=True, 7234 ) 7235 return responses 7236 7237 async def unregister_nanny_plugin(self, comm, name): 7238 """Unregisters a nanny plugin""" 7239 try: 7240 self.nanny_plugins.pop(name) 7241 except KeyError: 7242 raise ValueError(f"The nanny plugin {name} does not exist") 7243 7244 responses = await self.broadcast( 7245 msg=dict(op="plugin_remove", name=name), nanny=True 7246 ) 7247 return responses 7248 7249 def transition(self, key, finish: str, *args, **kwargs): 7250 """Transition a key from its current state to the finish state 7251 7252 Examples 7253 -------- 7254 >>> self.transition('x', 'waiting') 7255 {'x': 'processing'} 7256 7257 Returns 7258 ------- 7259 Dictionary of recommendations for future transitions 7260 7261 See Also 7262 -------- 7263 Scheduler.transitions: transitive version of this function 7264 """ 7265 parent: SchedulerState = cast(SchedulerState, self) 7266 recommendations: dict 7267 worker_msgs: dict 7268 client_msgs: dict 7269 a: tuple = parent._transition(key, finish, *args, **kwargs) 7270 recommendations, client_msgs, worker_msgs = a 7271 self.send_all(client_msgs, worker_msgs) 7272 return recommendations 7273 7274 def transitions(self, recommendations: dict): 7275 """Process transitions until none are left 7276 7277 This includes feedback from previous transitions and continues until we 7278 reach a steady state 7279 """ 7280 parent: SchedulerState = cast(SchedulerState, self) 7281 client_msgs: dict = {} 7282 worker_msgs: dict = {} 7283 parent._transitions(recommendations, client_msgs, worker_msgs) 7284 self.send_all(client_msgs, worker_msgs) 7285 7286 def story(self, *keys): 7287 """Get all transitions that touch one of the input keys""" 7288 keys = {key.key if isinstance(key, TaskState) else key for key in keys} 7289 return [ 7290 t for t in self.transition_log if t[0] in keys or keys.intersection(t[3]) 7291 ] 7292 7293 transition_story = story 7294 7295 def reschedule(self, key=None, worker=None): 7296 """Reschedule a task 7297 7298 Things may have shifted and this task may now be better suited to run 7299
elsewhere 7300 """ 7301 parent: SchedulerState = cast(SchedulerState, self) 7302 ts: TaskState 7303 try: 7304 ts = parent._tasks[key] 7305 except KeyError: 7306 logger.warning( 7307 "Attempting to reschedule task {}, which was not " 7308 "found on the scheduler. Aborting reschedule.".format(key) 7309 ) 7310 return 7311 if ts._state != "processing": 7312 return 7313 if worker and ts._processing_on.address != worker: 7314 return 7315 self.transitions({key: "released"}) 7316 7317 ##################### 7318 # Utility functions # 7319 ##################### 7320 7321 def add_resources(self, comm=None, worker=None, resources=None): 7322 parent: SchedulerState = cast(SchedulerState, self) 7323 ws: WorkerState = parent._workers_dv[worker] 7324 if resources: 7325 ws._resources.update(resources) 7326 ws._used_resources = {} 7327 for resource, quantity in ws._resources.items(): 7328 ws._used_resources[resource] = 0 7329 dr: dict = parent._resources.get(resource, None) 7330 if dr is None: 7331 parent._resources[resource] = dr = {} 7332 dr[worker] = quantity 7333 return "OK" 7334 7335 def remove_resources(self, worker): 7336 parent: SchedulerState = cast(SchedulerState, self) 7337 ws: WorkerState = parent._workers_dv[worker] 7338 for resource, quantity in ws._resources.items(): 7339 dr: dict = parent._resources.get(resource, None) 7340 if dr is None: 7341 parent._resources[resource] = dr = {} 7342 del dr[worker] 7343 7344 def coerce_address(self, addr, resolve=True): 7345 """ 7346 Coerce possible input addresses to canonical form. 7347 *resolve* can be disabled for testing with fake hostnames. 7348 7349 Handles strings, tuples, or aliases. 7350 """ 7351 # XXX how many address-parsing routines do we have? 7352 parent: SchedulerState = cast(SchedulerState, self) 7353 if addr in parent._aliases: 7354 addr = parent._aliases[addr] 7355 if isinstance(addr, tuple): 7356 addr = unparse_host_port(*addr) 7357 if not isinstance(addr, str): 7358 raise TypeError(f"addresses should be strings or tuples, got {addr!r}") 7359 7360 if resolve: 7361 addr = resolve_address(addr) 7362 else: 7363 addr = normalize_address(addr) 7364 7365 return addr 7366 7367 def workers_list(self, workers): 7368 """ 7369 List of qualifying workers 7370 7371 Takes a list of worker addresses or hostnames. 7372 Returns a list of all worker addresses that match 7373 """ 7374 parent: SchedulerState = cast(SchedulerState, self) 7375 if workers is None: 7376 return list(parent._workers) 7377 7378 out = set() 7379 for w in workers: 7380 if ":" in w: 7381 out.add(w) 7382 else: 7383 out.update({ww for ww in parent._workers if w in ww}) # TODO: quadratic 7384 return list(out) 7385 7386 def start_ipython(self, comm=None): 7387 """Start an IPython kernel 7388 7389 Returns Jupyter connection info dictionary. 
7390 """ 7391 from ._ipython_utils import start_ipython 7392 7393 if self._ipython_kernel is None: 7394 self._ipython_kernel = start_ipython( 7395 ip=self.ip, ns={"scheduler": self}, log=logger 7396 ) 7397 return self._ipython_kernel.get_connection_info() 7398 7399 async def get_profile( 7400 self, 7401 comm=None, 7402 workers=None, 7403 scheduler=False, 7404 server=False, 7405 merge_workers=True, 7406 start=None, 7407 stop=None, 7408 key=None, 7409 ): 7410 parent: SchedulerState = cast(SchedulerState, self) 7411 if workers is None: 7412 workers = parent._workers_dv 7413 else: 7414 workers = set(parent._workers_dv) & set(workers) 7415 7416 if scheduler: 7417 return profile.get_profile(self.io_loop.profile, start=start, stop=stop) 7418 7419 results = await asyncio.gather( 7420 *( 7421 self.rpc(w).profile(start=start, stop=stop, key=key, server=server) 7422 for w in workers 7423 ), 7424 return_exceptions=True, 7425 ) 7426 7427 results = [r for r in results if not isinstance(r, Exception)] 7428 7429 if merge_workers: 7430 response = profile.merge(*results) 7431 else: 7432 response = dict(zip(workers, results)) 7433 return response 7434 7435 async def get_profile_metadata( 7436 self, 7437 comm=None, 7438 workers=None, 7439 merge_workers=True, 7440 start=None, 7441 stop=None, 7442 profile_cycle_interval=None, 7443 ): 7444 parent: SchedulerState = cast(SchedulerState, self) 7445 dt = profile_cycle_interval or dask.config.get( 7446 "distributed.worker.profile.cycle" 7447 ) 7448 dt = parse_timedelta(dt, default="ms") 7449 7450 if workers is None: 7451 workers = parent._workers_dv 7452 else: 7453 workers = set(parent._workers_dv) & set(workers) 7454 results = await asyncio.gather( 7455 *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), 7456 return_exceptions=True, 7457 ) 7458 7459 results = [r for r in results if not isinstance(r, Exception)] 7460 counts = [v["counts"] for v in results] 7461 counts = itertools.groupby(merge_sorted(*counts), lambda t: t[0] // dt * dt) 7462 counts = [(time, sum(pluck(1, group))) for time, group in counts] 7463 7464 keys = set() 7465 for v in results: 7466 for t, d in v["keys"]: 7467 for k in d: 7468 keys.add(k) 7469 keys = {k: [] for k in keys} 7470 7471 groups1 = [v["keys"] for v in results] 7472 groups2 = list(merge_sorted(*groups1, key=first)) 7473 7474 last = 0 7475 for t, d in groups2: 7476 tt = t // dt * dt 7477 if tt > last: 7478 last = tt 7479 for k, v in keys.items(): 7480 v.append([tt, 0]) 7481 for k, v in d.items(): 7482 keys[k][-1][1] += v 7483 7484 return {"counts": counts, "keys": keys} 7485 7486 async def performance_report( 7487 self, comm=None, start=None, last_count=None, code="", mode=None 7488 ): 7489 parent: SchedulerState = cast(SchedulerState, self) 7490 stop = time() 7491 # Profiles 7492 compute, scheduler, workers = await asyncio.gather( 7493 *[ 7494 self.get_profile(start=start), 7495 self.get_profile(scheduler=True, start=start), 7496 self.get_profile(server=True, start=start), 7497 ] 7498 ) 7499 from . 
import profile 7500 7501 def profile_to_figure(state): 7502 data = profile.plot_data(state) 7503 figure, source = profile.plot_figure(data, sizing_mode="stretch_both") 7504 return figure 7505 7506 compute, scheduler, workers = map( 7507 profile_to_figure, (compute, scheduler, workers) 7508 ) 7509 7510 # Task stream 7511 task_stream = self.get_task_stream(start=start) 7512 total_tasks = len(task_stream) 7513 timespent = defaultdict(int) 7514 for d in task_stream: 7515 for x in d.get("startstops", []): 7516 timespent[x["action"]] += x["stop"] - x["start"] 7517 tasks_timings = "" 7518 for k in sorted(timespent.keys()): 7519 tasks_timings += f"\n<li> {k} time: {format_time(timespent[k])} </li>" 7520 7521 from .dashboard.components.scheduler import task_stream_figure 7522 from .diagnostics.task_stream import rectangles 7523 7524 rects = rectangles(task_stream) 7525 source, task_stream = task_stream_figure(sizing_mode="stretch_both") 7526 source.data.update(rects) 7527 7528 # Bandwidth 7529 from distributed.dashboard.components.scheduler import ( 7530 BandwidthTypes, 7531 BandwidthWorkers, 7532 ) 7533 7534 bandwidth_workers = BandwidthWorkers(self, sizing_mode="stretch_both") 7535 bandwidth_workers.update() 7536 bandwidth_types = BandwidthTypes(self, sizing_mode="stretch_both") 7537 bandwidth_types.update() 7538 7539 # System monitor 7540 from distributed.dashboard.components.shared import SystemMonitor 7541 7542 sysmon = SystemMonitor(self, last_count=last_count, sizing_mode="stretch_both") 7543 sysmon.update() 7544 7545 # Scheduler logs 7546 from distributed.dashboard.components.scheduler import SchedulerLogs 7547 7548 logs = SchedulerLogs(self) 7549 7550 from bokeh.models import Div, Panel, Tabs 7551 7552 import distributed 7553 7554 # HTML 7555 ws: WorkerState 7556 html = """ 7557 <h1> Dask Performance Report </h1> 7558 7559 <i> Select different tabs on the top for additional information </i> 7560 7561 <h2> Duration: {time} </h2> 7562 <h2> Tasks Information </h2> 7563 <ul> 7564 <li> number of tasks: {ntasks} </li> 7565 {tasks_timings} 7566 </ul> 7567 7568 <h2> Scheduler Information </h2> 7569 <ul> 7570 <li> Address: {address} </li> 7571 <li> Workers: {nworkers} </li> 7572 <li> Threads: {threads} </li> 7573 <li> Memory: {memory} </li> 7574 <li> Dask Version: {dask_version} </li> 7575 <li> Dask.Distributed Version: {distributed_version} </li> 7576 </ul> 7577 7578 <h2> Calling Code </h2> 7579 <pre> 7580{code} 7581 </pre> 7582 """.format( 7583 time=format_time(stop - start), 7584 ntasks=total_tasks, 7585 tasks_timings=tasks_timings, 7586 address=self.address, 7587 nworkers=len(parent._workers_dv), 7588 threads=sum([ws._nthreads for ws in parent._workers_dv.values()]), 7589 memory=format_bytes( 7590 sum([ws._memory_limit for ws in parent._workers_dv.values()]) 7591 ), 7592 code=code, 7593 dask_version=dask.__version__, 7594 distributed_version=distributed.__version__, 7595 ) 7596 html = Div( 7597 text=html, 7598 style={ 7599 "width": "100%", 7600 "height": "100%", 7601 "max-width": "1920px", 7602 "max-height": "1080px", 7603 "padding": "12px", 7604 "border": "1px solid lightgray", 7605 "box-shadow": "inset 1px 0 8px 0 lightgray", 7606 "overflow": "auto", 7607 }, 7608 ) 7609 7610 html = Panel(child=html, title="Summary") 7611 compute = Panel(child=compute, title="Worker Profile (compute)") 7612 workers = Panel(child=workers, title="Worker Profile (administrative)") 7613 scheduler = Panel(child=scheduler, title="Scheduler Profile (administrative)") 7614 task_stream = Panel(child=task_stream, 
title="Task Stream") 7615 bandwidth_workers = Panel( 7616 child=bandwidth_workers.root, title="Bandwidth (Workers)" 7617 ) 7618 bandwidth_types = Panel(child=bandwidth_types.root, title="Bandwidth (Types)") 7619 system = Panel(child=sysmon.root, title="System") 7620 logs = Panel(child=logs.root, title="Scheduler Logs") 7621 7622 tabs = Tabs( 7623 tabs=[ 7624 html, 7625 task_stream, 7626 system, 7627 logs, 7628 compute, 7629 workers, 7630 scheduler, 7631 bandwidth_workers, 7632 bandwidth_types, 7633 ] 7634 ) 7635 7636 from bokeh.core.templates import get_env 7637 from bokeh.plotting import output_file, save 7638 7639 with tmpfile(extension=".html") as fn: 7640 output_file(filename=fn, title="Dask Performance Report", mode=mode) 7641 template_directory = os.path.join( 7642 os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates" 7643 ) 7644 template_environment = get_env() 7645 template_environment.loader.searchpath.append(template_directory) 7646 template = template_environment.get_template("performance_report.html") 7647 save(tabs, filename=fn, template=template) 7648 7649 with open(fn) as f: 7650 data = f.read() 7651 7652 return data 7653 7654 async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False): 7655 results = await self.broadcast( 7656 msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny 7657 ) 7658 return results 7659 7660 def log_event(self, name, msg): 7661 event = (time(), msg) 7662 if isinstance(name, (list, tuple)): 7663 for n in name: 7664 self.events[n].append(event) 7665 self.event_counts[n] += 1 7666 self._report_event(n, event) 7667 else: 7668 self.events[name].append(event) 7669 self.event_counts[name] += 1 7670 self._report_event(name, event) 7671 7672 def _report_event(self, name, event): 7673 for client in self.event_subscriber[name]: 7674 self.report( 7675 { 7676 "op": "event", 7677 "topic": name, 7678 "event": event, 7679 }, 7680 client=client, 7681 ) 7682 7683 def subscribe_topic(self, topic, client): 7684 self.event_subscriber[topic].add(client) 7685 7686 def unsubscribe_topic(self, topic, client): 7687 self.event_subscriber[topic].discard(client) 7688 7689 def get_events(self, comm=None, topic=None): 7690 if topic is not None: 7691 return tuple(self.events[topic]) 7692 else: 7693 return valmap(tuple, self.events) 7694 7695 async def get_worker_monitor_info(self, recent=False, starts=None): 7696 parent: SchedulerState = cast(SchedulerState, self) 7697 if starts is None: 7698 starts = {} 7699 results = await asyncio.gather( 7700 *( 7701 self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) 7702 for w in parent._workers_dv 7703 ) 7704 ) 7705 return dict(zip(parent._workers_dv, results)) 7706 7707 ########### 7708 # Cleanup # 7709 ########### 7710 7711 def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): 7712 """Periodically reassess task duration time 7713 7714 The expected duration of a task can change over time. Unfortunately we 7715 don't have a good constant-time way to propagate the effects of these 7716 changes out to the summaries that they affect, like the total expected 7717 runtime of each of the workers, or what tasks are stealable. 7718 7719 In this coroutine we walk through all of the workers and re-align their 7720 estimates with the current state of tasks. We do this periodically 7721 rather than at every transition, and we only do it if the scheduler 7722 process isn't under load (using psutil.Process.cpu_percent()). 
This 7723 lets us avoid this fringe optimization when we have better things to 7724 think about. 7725 """ 7726 parent: SchedulerState = cast(SchedulerState, self) 7727 try: 7728 if self.status == Status.closed: 7729 return 7730 last = time() 7731 next_time = timedelta(seconds=0.1) 7732 7733 if self.proc.cpu_percent() < 50: 7734 workers: list = list(parent._workers.values()) 7735 nworkers: Py_ssize_t = len(workers) 7736 i: Py_ssize_t 7737 for i in range(nworkers): 7738 ws: WorkerState = workers[worker_index % nworkers] 7739 worker_index += 1 7740 try: 7741 if ws is None or not ws._processing: 7742 continue 7743 _reevaluate_occupancy_worker(parent, ws) 7744 finally: 7745 del ws # lose ref 7746 7747 duration = time() - last 7748 if duration > 0.005: # 5ms since last release 7749 next_time = timedelta(seconds=duration * 5) # 25ms gap 7750 break 7751 7752 self.loop.add_timeout( 7753 next_time, self.reevaluate_occupancy, worker_index=worker_index 7754 ) 7755 7756 except Exception: 7757 logger.error("Error in reevaluate occupancy", exc_info=True) 7758 raise 7759 7760 async def check_worker_ttl(self): 7761 parent: SchedulerState = cast(SchedulerState, self) 7762 ws: WorkerState 7763 now = time() 7764 for ws in parent._workers_dv.values(): 7765 if (ws._last_seen < now - self.worker_ttl) and ( 7766 ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers_dv)) 7767 ): 7768 logger.warning( 7769 "Worker failed to heartbeat within %s seconds. Closing: %s", 7770 self.worker_ttl, 7771 ws, 7772 ) 7773 await self.remove_worker(address=ws._address) 7774 7775 def check_idle(self): 7776 parent: SchedulerState = cast(SchedulerState, self) 7777 ws: WorkerState 7778 if ( 7779 any([ws._processing for ws in parent._workers_dv.values()]) 7780 or parent._unrunnable 7781 ): 7782 self.idle_since = None 7783 return 7784 elif not self.idle_since: 7785 self.idle_since = time() 7786 7787 if time() > self.idle_since + self.idle_timeout: 7788 logger.info( 7789 "Scheduler closing after being idle for %s", 7790 format_time(self.idle_timeout), 7791 ) 7792 self.loop.add_callback(self.close) 7793 7794 def adaptive_target(self, comm=None, target_duration=None): 7795 """Desired number of workers based on the current workload 7796 7797 This looks at the current running tasks and memory use, and returns a 7798 number of desired workers. This is often used by adaptive scheduling. 7799 7800 Parameters 7801 ---------- 7802 target_duration : str 7803 A desired duration of time for computations to take. This affects 7804 how rapidly the scheduler will ask to scale. 
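        The arithmetic below is a simplified, illustrative restatement of the body
        of this method (the real code also caps the CPU term by the number of
        tasks currently processing, and all input values here are made up)::

            import math

            total_occupancy = 120.0   # seconds of pending work, cluster-wide
            target_duration = 5.0     # seconds we want the backlog to take
            cpu = math.ceil(total_occupancy / target_duration)             # -> 24

            n_workers, used, limit = 10, 7e9, 10e9                         # bytes
            memory = 2 * n_workers if limit and used > 0.6 * limit else 0  # -> 20

            target = max(cpu, memory)                                      # -> 24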

    async def check_worker_ttl(self):
        parent: SchedulerState = cast(SchedulerState, self)
        ws: WorkerState
        now = time()
        for ws in parent._workers_dv.values():
            if (ws._last_seen < now - self.worker_ttl) and (
                ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers_dv))
            ):
                logger.warning(
                    "Worker failed to heartbeat within %s seconds. Closing: %s",
                    self.worker_ttl,
                    ws,
                )
                await self.remove_worker(address=ws._address)

    def check_idle(self):
        parent: SchedulerState = cast(SchedulerState, self)
        ws: WorkerState
        if (
            any([ws._processing for ws in parent._workers_dv.values()])
            or parent._unrunnable
        ):
            self.idle_since = None
            return
        elif not self.idle_since:
            self.idle_since = time()

        if time() > self.idle_since + self.idle_timeout:
            logger.info(
                "Scheduler closing after being idle for %s",
                format_time(self.idle_timeout),
            )
            self.loop.add_callback(self.close)

    def adaptive_target(self, comm=None, target_duration=None):
        """Desired number of workers based on the current workload

        This looks at the current running tasks and memory use, and returns a
        number of desired workers. This is often used by adaptive scheduling.

        Parameters
        ----------
        target_duration : str
            A desired duration of time for computations to take. This affects
            how rapidly the scheduler will ask to scale.

        See Also
        --------
        distributed.deploy.Adaptive
        """
        parent: SchedulerState = cast(SchedulerState, self)
        if target_duration is None:
            target_duration = dask.config.get("distributed.adaptive.target-duration")
        target_duration = parse_timedelta(target_duration)

        # CPU
        cpu = math.ceil(
            parent._total_occupancy / target_duration
        )  # TODO: threads per worker

        # Avoid a few long tasks asking for many cores
        ws: WorkerState
        tasks_processing = 0
        for ws in parent._workers_dv.values():
            tasks_processing += len(ws._processing)

            if tasks_processing > cpu:
                break
        else:
            cpu = min(tasks_processing, cpu)

        if parent._unrunnable and not parent._workers_dv:
            cpu = max(1, cpu)

        # add more workers if more than 60% of memory is used
        limit = sum([ws._memory_limit for ws in parent._workers_dv.values()])
        used = sum([ws._nbytes for ws in parent._workers_dv.values()])
        memory = 0
        if used > 0.6 * limit and limit > 0:
            memory = 2 * len(parent._workers_dv)

        target = max(memory, cpu)
        if target >= len(parent._workers_dv):
            return target
        else:  # Scale down?
            to_close = self.workers_to_close()
            return len(parent._workers_dv) - len(to_close)
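
    # Worked example for ``adaptive_target`` (illustrative numbers, not
    # measurements): with 1200 s of expected work queued (total occupancy) and a
    # target duration of 5 s, the CPU term asks for ceil(1200 / 5) = 240 workers;
    # if only 30 tasks are actually processing, the for/else clause caps that at
    # 30.  If more than 60% of the workers' combined memory limit is in use, the
    # memory term instead requests twice the current worker count, and the
    # method returns the larger of the two (or, when scaling down, the current
    # count minus however many workers ``workers_to_close()`` reports as safe to
    # retire).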
7893 """ 7894 if state._validate: 7895 assert ts not in ws._has_what 7896 7897 state.add_replica(ts, ws) 7898 7899 deps: list = list(ts._dependents) 7900 if len(deps) > 1: 7901 deps.sort(key=operator.attrgetter("priority"), reverse=True) 7902 7903 dts: TaskState 7904 s: set 7905 for dts in deps: 7906 s = dts._waiting_on 7907 if ts in s: 7908 s.discard(ts) 7909 if not s: # new task ready to run 7910 recommendations[dts._key] = "processing" 7911 7912 for dts in ts._dependencies: 7913 s = dts._waiters 7914 s.discard(ts) 7915 if not s and not dts._who_wants: 7916 recommendations[dts._key] = "released" 7917 7918 report_msg: dict = {} 7919 cs: ClientState 7920 if not ts._waiters and not ts._who_wants: 7921 recommendations[ts._key] = "released" 7922 else: 7923 report_msg["op"] = "key-in-memory" 7924 report_msg["key"] = ts._key 7925 if type is not None: 7926 report_msg["type"] = type 7927 7928 for cs in ts._who_wants: 7929 client_msgs[cs._client_key] = [report_msg] 7930 7931 ts.state = "memory" 7932 ts._type = typename # type: ignore 7933 ts._group._types.add(typename) 7934 7935 cs = state._clients["fire-and-forget"] 7936 if ts in cs._wants_what: 7937 _client_releases_keys( 7938 state, 7939 cs=cs, 7940 keys=[ts._key], 7941 recommendations=recommendations, 7942 ) 7943 7944 7945@cfunc 7946@exceptval(check=False) 7947def _propagate_forgotten( 7948 state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict 7949): 7950 ts.state = "forgotten" 7951 key: str = ts._key 7952 dts: TaskState 7953 for dts in ts._dependents: 7954 dts._has_lost_dependencies = True 7955 dts._dependencies.remove(ts) 7956 dts._waiting_on.discard(ts) 7957 if dts._state not in ("memory", "erred"): 7958 # Cannot compute task anymore 7959 recommendations[dts._key] = "forgotten" 7960 ts._dependents.clear() 7961 ts._waiters.clear() 7962 7963 for dts in ts._dependencies: 7964 dts._dependents.remove(ts) 7965 dts._waiters.discard(ts) 7966 if not dts._dependents and not dts._who_wants: 7967 # Task not needed anymore 7968 assert dts is not ts 7969 recommendations[dts._key] = "forgotten" 7970 ts._dependencies.clear() 7971 ts._waiting_on.clear() 7972 7973 ws: WorkerState 7974 for ws in ts._who_has: 7975 w: str = ws._address 7976 if w in state._workers_dv: # in case worker has died 7977 worker_msgs[w] = [ 7978 { 7979 "op": "free-keys", 7980 "keys": [key], 7981 "stimulus_id": f"propagate-forgotten-{time()}", 7982 } 7983 ] 7984 state.remove_all_replicas(ts) 7985 7986 7987@cfunc 7988@exceptval(check=False) 7989def _client_releases_keys( 7990 state: SchedulerState, keys: list, cs: ClientState, recommendations: dict 7991): 7992 """Remove keys from client desired list""" 7993 logger.debug("Client %s releases keys: %s", cs._client_key, keys) 7994 ts: TaskState 7995 for key in keys: 7996 ts = state._tasks.get(key) # type: ignore 7997 if ts is not None and ts in cs._wants_what: 7998 cs._wants_what.remove(ts) 7999 ts._who_wants.remove(cs) 8000 if not ts._who_wants: 8001 if not ts._dependents: 8002 # No live dependents, can forget 8003 recommendations[ts._key] = "forgotten" 8004 elif ts._state != "erred" and not ts._waiters: 8005 recommendations[ts._key] = "released" 8006 8007 8008@cfunc 8009@exceptval(check=False) 8010def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> dict: 8011 """Convert a single computational task to a message""" 8012 ws: WorkerState 8013 dts: TaskState 8014 8015 # FIXME: The duration attribute is not used on worker. 


@cfunc
@exceptval(check=False)
def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> dict:
    """Convert a single computational task to a message"""
    ws: WorkerState
    dts: TaskState

    # FIXME: The duration attribute is not used on the worker. We could save
    # ourselves the time to compute and submit this.
    if duration < 0:
        duration = state.get_task_duration(ts)

    msg: dict = {
        "op": "compute-task",
        "key": ts._key,
        "priority": ts._priority,
        "duration": duration,
        "stimulus_id": f"compute-task-{time()}",
        "who_has": {},
    }
    if ts._resource_restrictions:
        msg["resource_restrictions"] = ts._resource_restrictions
    if ts._actor:
        msg["actor"] = True

    deps: set = ts._dependencies
    if deps:
        msg["who_has"] = {
            dts._key: [ws._address for ws in dts._who_has] for dts in deps
        }
        msg["nbytes"] = {dts._key: dts._nbytes for dts in deps}

        if state._validate:
            assert all(msg["who_has"].values())

    task = ts._run_spec
    if type(task) is dict:
        msg.update(task)
    else:
        msg["task"] = task

    if ts._annotations:
        msg["annotations"] = ts._annotations

    return msg
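
# Illustrative shape of the message built by ``_task_to_msg`` (key names,
# addresses and sizes below are invented; real values come from the TaskState):
#
#     {
#         "op": "compute-task",
#         "key": "inc-abc123",
#         "priority": (0, 1, 3),
#         "duration": 0.5,
#         "stimulus_id": "compute-task-1650000000.0",
#         "who_has": {"dep-xyz": ["tcp://127.0.0.1:40331"]},
#         "nbytes": {"dep-xyz": 8},
#         "task": <serialized run spec>,
#     }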
8121 """ 8122 ws: WorkerState = None # type: ignore 8123 wws: WorkerState 8124 dts: TaskState 8125 deps: set = ts._dependencies 8126 candidates: set 8127 assert all([dts._who_has for dts in deps]) 8128 if ts._actor: 8129 candidates = set(all_workers) 8130 else: 8131 candidates = {wws for dts in deps for wws in dts._who_has} 8132 if valid_workers is None: 8133 if not candidates: 8134 candidates = set(all_workers) 8135 else: 8136 candidates &= valid_workers 8137 if not candidates: 8138 candidates = valid_workers 8139 if not candidates: 8140 if ts._loose_restrictions: 8141 ws = decide_worker(ts, all_workers, None, objective) 8142 return ws 8143 8144 ncandidates: Py_ssize_t = len(candidates) 8145 if ncandidates == 0: 8146 pass 8147 elif ncandidates == 1: 8148 for ws in candidates: 8149 break 8150 else: 8151 ws = min(candidates, key=objective) 8152 return ws 8153 8154 8155def validate_task_state(ts: TaskState): 8156 """ 8157 Validate the given TaskState. 8158 """ 8159 ws: WorkerState 8160 dts: TaskState 8161 8162 assert ts._state in ALL_TASK_STATES or ts._state == "forgotten", ts 8163 8164 if ts._waiting_on: 8165 assert ts._waiting_on.issubset(ts._dependencies), ( 8166 "waiting not subset of dependencies", 8167 str(ts._waiting_on), 8168 str(ts._dependencies), 8169 ) 8170 if ts._waiters: 8171 assert ts._waiters.issubset(ts._dependents), ( 8172 "waiters not subset of dependents", 8173 str(ts._waiters), 8174 str(ts._dependents), 8175 ) 8176 8177 for dts in ts._waiting_on: 8178 assert not dts._who_has, ("waiting on in-memory dep", str(ts), str(dts)) 8179 assert dts._state != "released", ("waiting on released dep", str(ts), str(dts)) 8180 for dts in ts._dependencies: 8181 assert ts in dts._dependents, ( 8182 "not in dependency's dependents", 8183 str(ts), 8184 str(dts), 8185 str(dts._dependents), 8186 ) 8187 if ts._state in ("waiting", "processing"): 8188 assert dts in ts._waiting_on or dts._who_has, ( 8189 "dep missing", 8190 str(ts), 8191 str(dts), 8192 ) 8193 assert dts._state != "forgotten" 8194 8195 for dts in ts._waiters: 8196 assert dts._state in ("waiting", "processing"), ( 8197 "waiter not in play", 8198 str(ts), 8199 str(dts), 8200 ) 8201 for dts in ts._dependents: 8202 assert ts in dts._dependencies, ( 8203 "not in dependent's dependencies", 8204 str(ts), 8205 str(dts), 8206 str(dts._dependencies), 8207 ) 8208 assert dts._state != "forgotten" 8209 8210 assert (ts._processing_on is not None) == (ts._state == "processing") 8211 assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state) 8212 8213 if ts._state == "processing": 8214 assert all([dts._who_has for dts in ts._dependencies]), ( 8215 "task processing without all deps", 8216 str(ts), 8217 str(ts._dependencies), 8218 ) 8219 assert not ts._waiting_on 8220 8221 if ts._who_has: 8222 assert ts._waiters or ts._who_wants, ( 8223 "unneeded task in memory", 8224 str(ts), 8225 str(ts._who_has), 8226 ) 8227 if ts._run_spec: # was computed 8228 assert ts._type 8229 assert isinstance(ts._type, str) 8230 assert not any([ts in dts._waiting_on for dts in ts._dependents]) 8231 for ws in ts._who_has: 8232 assert ts in ws._has_what, ( 8233 "not in who_has' has_what", 8234 str(ts), 8235 str(ws), 8236 str(ws._has_what), 8237 ) 8238 8239 if ts._who_wants: 8240 cs: ClientState 8241 for cs in ts._who_wants: 8242 assert ts in cs._wants_what, ( 8243 "not in who_wants' wants_what", 8244 str(ts), 8245 str(cs), 8246 str(cs._wants_what), 8247 ) 8248 8249 if ts._actor: 8250 if ts._state == "memory": 8251 assert sum([ts in ws._actors for ws 


def validate_task_state(ts: TaskState):
    """
    Validate the given TaskState.
    """
    ws: WorkerState
    dts: TaskState

    assert ts._state in ALL_TASK_STATES or ts._state == "forgotten", ts

    if ts._waiting_on:
        assert ts._waiting_on.issubset(ts._dependencies), (
            "waiting not subset of dependencies",
            str(ts._waiting_on),
            str(ts._dependencies),
        )
    if ts._waiters:
        assert ts._waiters.issubset(ts._dependents), (
            "waiters not subset of dependents",
            str(ts._waiters),
            str(ts._dependents),
        )

    for dts in ts._waiting_on:
        assert not dts._who_has, ("waiting on in-memory dep", str(ts), str(dts))
        assert dts._state != "released", ("waiting on released dep", str(ts), str(dts))
    for dts in ts._dependencies:
        assert ts in dts._dependents, (
            "not in dependency's dependents",
            str(ts),
            str(dts),
            str(dts._dependents),
        )
        if ts._state in ("waiting", "processing"):
            assert dts in ts._waiting_on or dts._who_has, (
                "dep missing",
                str(ts),
                str(dts),
            )
        assert dts._state != "forgotten"

    for dts in ts._waiters:
        assert dts._state in ("waiting", "processing"), (
            "waiter not in play",
            str(ts),
            str(dts),
        )
    for dts in ts._dependents:
        assert ts in dts._dependencies, (
            "not in dependent's dependencies",
            str(ts),
            str(dts),
            str(dts._dependencies),
        )
        assert dts._state != "forgotten"

    assert (ts._processing_on is not None) == (ts._state == "processing")
    assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state)

    if ts._state == "processing":
        assert all([dts._who_has for dts in ts._dependencies]), (
            "task processing without all deps",
            str(ts),
            str(ts._dependencies),
        )
        assert not ts._waiting_on

    if ts._who_has:
        assert ts._waiters or ts._who_wants, (
            "unneeded task in memory",
            str(ts),
            str(ts._who_has),
        )
        if ts._run_spec:  # was computed
            assert ts._type
            assert isinstance(ts._type, str)
        assert not any([ts in dts._waiting_on for dts in ts._dependents])
        for ws in ts._who_has:
            assert ts in ws._has_what, (
                "not in who_has' has_what",
                str(ts),
                str(ws),
                str(ws._has_what),
            )

    if ts._who_wants:
        cs: ClientState
        for cs in ts._who_wants:
            assert ts in cs._wants_what, (
                "not in who_wants' wants_what",
                str(ts),
                str(cs),
                str(cs._wants_what),
            )

    if ts._actor:
        if ts._state == "memory":
            assert sum([ts in ws._actors for ws in ts._who_has]) == 1
        if ts._state == "processing":
            assert ts in ts._processing_on.actors


def validate_worker_state(ws: WorkerState):
    ts: TaskState
    for ts in ws._has_what:
        assert ws in ts._who_has, (
            "not in has_what' who_has",
            str(ws),
            str(ts),
            str(ts._who_has),
        )

    for ts in ws._actors:
        assert ts._state in ("memory", "processing")


def validate_state(tasks, workers, clients):
    """
    Validate a current runtime state

    This performs a sequence of checks on the entire graph, running in about
    linear time. This raises assert errors if anything doesn't check out.
    """
    ts: TaskState
    for ts in tasks.values():
        validate_task_state(ts)

    ws: WorkerState
    for ws in workers.values():
        validate_worker_state(ws)

    cs: ClientState
    for cs in clients.values():
        for ts in cs._wants_what:
            assert cs in ts._who_wants, (
                "not in wants_what' who_wants",
                str(cs),
                str(ts),
                str(ts._who_wants),
            )


_round_robin = [0]


def heartbeat_interval(n):
    """
    Interval in seconds at which we desire heartbeats, based on the number of workers
    """
    if n <= 10:
        return 0.5
    elif n < 50:
        return 1
    elif n < 200:
        return 2
    else:
        # no more than 200 heartbeats a second scaled by workers
        return n / 200 + 1


class KilledWorker(Exception):
    def __init__(self, task, last_worker):
        super().__init__(task, last_worker)
        self.task = task
        self.last_worker = last_worker


class WorkerStatusPlugin(SchedulerPlugin):
    """
    A plugin to share worker status with a remote observer

    This is used by cluster managers to stay up to date about the status of the
    scheduler.
    """

    name = "worker-status"

    def __init__(self, scheduler, comm):
        self.bcomm = BatchedSend(interval="5ms")
        self.bcomm.start(comm)

        self.scheduler = scheduler
        self.scheduler.add_plugin(self)

    def add_worker(self, worker=None, **kwargs):
        ident = self.scheduler.workers[worker].identity()
        del ident["metrics"]
        del ident["last_seen"]
        try:
            self.bcomm.send(["add", {"workers": {worker: ident}}])
        except CommClosedError:
            self.scheduler.remove_plugin(name=self.name)

    def remove_worker(self, worker=None, **kwargs):
        try:
            self.bcomm.send(["remove", worker])
        except CommClosedError:
            self.scheduler.remove_plugin(name=self.name)

    def teardown(self):
        self.bcomm.close()


class CollectTaskMetaDataPlugin(SchedulerPlugin):
    def __init__(self, scheduler, name):
        self.scheduler = scheduler
        self.name = name
        self.keys = set()
        self.metadata = {}
        self.state = {}

    def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs):
        self.keys.update(keys)

    def transition(self, key, start, finish, *args, **kwargs):
        if finish == "memory" or finish == "erred":
            ts: TaskState = self.scheduler.tasks.get(key)
            if ts is not None and ts._key in self.keys:
                self.metadata[key] = ts._metadata
                self.state[key] = finish
                self.keys.discard(key)
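

# Usage sketch for the plugins above (hedged; ``scheduler`` stands for a live
# Scheduler instance, which this example does not construct):
#
#     plugin = CollectTaskMetaDataPlugin(scheduler, name="collect-metadata")
#     scheduler.add_plugin(plugin)
#     # ... run some work, then inspect plugin.metadata / plugin.state,
#     # which map each tracked key to its metadata and final state.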