1from __future__ import annotations 2 3import logging 4from collections import defaultdict, deque 5from math import log2 6from time import time 7from typing import Any, Container 8 9from tlz import topk 10from tornado.ioloop import PeriodicCallback 11 12import dask 13from dask.utils import parse_timedelta 14 15from .comm.addressing import get_address_host 16from .core import CommClosedError 17from .diagnostics.plugin import SchedulerPlugin 18from .utils import log_errors, recursive_to_dict 19 20# Stealing requires multiple network bounces and if successful also task 21# submission which may include code serialization. Therefore, be very 22# conservative in the latency estimation to suppress too aggressive stealing 23# of small tasks 24LATENCY = 0.1 25 26logger = logging.getLogger(__name__) 27 28 29LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") 30 31_WORKER_STATE_CONFIRM = { 32 "ready", 33 "constrained", 34 "waiting", 35} 36 37_WORKER_STATE_REJECT = { 38 "memory", 39 "executing", 40 "long-running", 41 "cancelled", 42 "resumed", 43} 44_WORKER_STATE_UNDEFINED = { 45 "released", 46 None, 47} 48 49 50class WorkStealing(SchedulerPlugin): 51 def __init__(self, scheduler): 52 self.scheduler = scheduler 53 # { level: { task states } } 54 self.stealable_all = [set() for i in range(15)] 55 # { worker: { level: { task states } } } 56 self.stealable = dict() 57 # { task state: (worker, level) } 58 self.key_stealable = dict() 59 60 self.cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)] 61 self.cost_multipliers[0] = 1 62 63 for worker in scheduler.workers: 64 self.add_worker(worker=worker) 65 66 callback_time = parse_timedelta( 67 dask.config.get("distributed.scheduler.work-stealing-interval"), 68 default="ms", 69 ) 70 # `callback_time` is in milliseconds 71 pc = PeriodicCallback(callback=self.balance, callback_time=callback_time * 1000) 72 self._pc = pc 73 self.scheduler.periodic_callbacks["stealing"] = pc 74 self.scheduler.add_plugin(self) 75 self.scheduler.extensions["stealing"] = self 76 self.scheduler.events["stealing"] = deque(maxlen=100000) 77 self.count = 0 78 # { task state: <stealing info dict> } 79 self.in_flight = dict() 80 # { worker state: occupancy } 81 self.in_flight_occupancy = defaultdict(lambda: 0) 82 83 self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm 84 85 def _to_dict(self, *, exclude: Container[str] = None) -> dict[str, Any]: 86 """ 87 A very verbose dictionary representation for debugging purposes. 88 Not type stable and not inteded for roundtrips. 89 90 Parameters 91 ---------- 92 comm: 93 exclude: 94 A list of attributes which must not be present in the output. 95 96 See also 97 -------- 98 Client.dump_cluster_state 99 """ 100 return recursive_to_dict( 101 { 102 "stealable_all": self.stealable_all, 103 "stealable": self.stealable, 104 "key_stealable": self.key_stealable, 105 "in_flight": self.in_flight, 106 "in_flight_occupancy": self.in_flight_occupancy, 107 }, 108 exclude=exclude, 109 ) 110 111 def log(self, msg): 112 return self.scheduler.log_event("stealing", msg) 113 114 def add_worker(self, scheduler=None, worker=None): 115 self.stealable[worker] = [set() for i in range(15)] 116 117 def remove_worker(self, scheduler=None, worker=None): 118 del self.stealable[worker] 119 120 def teardown(self): 121 self._pc.stop() 122 123 def transition( 124 self, key, start, finish, compute_start=None, compute_stop=None, *args, **kwargs 125 ): 126 if finish == "processing": 127 ts = self.scheduler.tasks[key] 128 self.put_key_in_stealable(ts) 129 elif start == "processing": 130 ts = self.scheduler.tasks[key] 131 self.remove_key_from_stealable(ts) 132 d = self.in_flight.pop(ts, None) 133 if d: 134 thief = d["thief"] 135 victim = d["victim"] 136 self.in_flight_occupancy[thief] -= d["thief_duration"] 137 self.in_flight_occupancy[victim] += d["victim_duration"] 138 if not self.in_flight: 139 self.in_flight_occupancy.clear() 140 141 def recalculate_cost(self, ts): 142 if ts not in self.in_flight: 143 self.remove_key_from_stealable(ts) 144 self.put_key_in_stealable(ts) 145 146 def put_key_in_stealable(self, ts): 147 cost_multiplier, level = self.steal_time_ratio(ts) 148 if cost_multiplier is not None: 149 ws = ts.processing_on 150 worker = ws.address 151 self.stealable_all[level].add(ts) 152 self.stealable[worker][level].add(ts) 153 self.key_stealable[ts] = (worker, level) 154 155 def remove_key_from_stealable(self, ts): 156 result = self.key_stealable.pop(ts, None) 157 if result is None: 158 return 159 160 worker, level = result 161 try: 162 self.stealable[worker][level].remove(ts) 163 except KeyError: 164 pass 165 try: 166 self.stealable_all[level].remove(ts) 167 except KeyError: 168 pass 169 170 def steal_time_ratio(self, ts): 171 """The compute to communication time ratio of a key 172 173 Returns 174 ------- 175 cost_multiplier: The increased cost from moving this task as a factor. 176 For example a result of zero implies a task without dependencies. 177 level: The location within a stealable list to place this value 178 """ 179 split = ts.prefix.name 180 if split in fast_tasks or split in self.scheduler.unknown_durations: 181 return None, None 182 183 if not ts.dependencies: # no dependencies fast path 184 return 0, 0 185 186 ws = ts.processing_on 187 compute_time = ws.processing[ts] 188 if compute_time < 0.005: # 5ms, just give up 189 return None, None 190 191 nbytes = ts.get_nbytes_deps() 192 transfer_time = nbytes / self.scheduler.bandwidth + LATENCY 193 cost_multiplier = transfer_time / compute_time 194 if cost_multiplier > 100: 195 return None, None 196 197 level = int(round(log2(cost_multiplier) + 6)) 198 if level < 1: 199 level = 1 200 201 return cost_multiplier, level 202 203 def move_task_request(self, ts, victim, thief) -> str: 204 try: 205 if ts in self.in_flight: 206 return "in-flight" 207 stimulus_id = f"steal-{time()}" 208 209 key = ts.key 210 self.remove_key_from_stealable(ts) 211 logger.debug( 212 "Request move %s, %s: %2f -> %s: %2f", 213 key, 214 victim, 215 victim.occupancy, 216 thief, 217 thief.occupancy, 218 ) 219 220 victim_duration = victim.processing[ts] 221 222 thief_duration = self.scheduler.get_task_duration( 223 ts 224 ) + self.scheduler.get_comm_cost(ts, thief) 225 226 self.scheduler.stream_comms[victim.address].send( 227 {"op": "steal-request", "key": key, "stimulus_id": stimulus_id} 228 ) 229 self.in_flight[ts] = { 230 "victim": victim, # guaranteed to be processing_on 231 "thief": thief, 232 "victim_duration": victim_duration, 233 "thief_duration": thief_duration, 234 "stimulus_id": stimulus_id, 235 } 236 237 self.in_flight_occupancy[victim] -= victim_duration 238 self.in_flight_occupancy[thief] += thief_duration 239 return stimulus_id 240 except CommClosedError: 241 logger.info("Worker comm %r closed while stealing: %r", victim, ts) 242 return "comm-closed" 243 except Exception as e: 244 logger.exception(e) 245 if LOG_PDB: 246 import pdb 247 248 pdb.set_trace() 249 raise 250 251 async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): 252 try: 253 ts = self.scheduler.tasks[key] 254 except KeyError: 255 logger.debug("Key released between request and confirm: %s", key) 256 return 257 try: 258 d = self.in_flight.pop(ts) 259 if d["stimulus_id"] != stimulus_id: 260 self.log(("stale-response", key, state, worker, stimulus_id)) 261 self.in_flight[ts] = d 262 return 263 except KeyError: 264 self.log(("already-aborted", key, state, stimulus_id)) 265 return 266 267 thief = d["thief"] 268 victim = d["victim"] 269 270 logger.debug("Confirm move %s, %s -> %s. State: %s", key, victim, thief, state) 271 272 self.in_flight_occupancy[thief] -= d["thief_duration"] 273 self.in_flight_occupancy[victim] += d["victim_duration"] 274 275 if not self.in_flight: 276 self.in_flight_occupancy.clear() 277 278 if self.scheduler.validate: 279 assert ts.processing_on == victim 280 281 try: 282 _log_msg = [key, state, victim.address, thief.address, stimulus_id] 283 284 if ts.state != "processing": 285 self.log(("not-processing", *_log_msg)) 286 old_thief = thief.occupancy 287 new_thief = sum(thief.processing.values()) 288 old_victim = victim.occupancy 289 new_victim = sum(victim.processing.values()) 290 thief.occupancy = new_thief 291 victim.occupancy = new_victim 292 self.scheduler.total_occupancy += ( 293 new_thief - old_thief + new_victim - old_victim 294 ) 295 elif ( 296 state in _WORKER_STATE_UNDEFINED 297 or state in _WORKER_STATE_CONFIRM 298 and thief.address not in self.scheduler.workers 299 ): 300 self.log( 301 ( 302 "reschedule", 303 thief.address not in self.scheduler.workers, 304 *_log_msg, 305 ) 306 ) 307 self.scheduler.reschedule(key) 308 # Victim had already started execution 309 elif state in _WORKER_STATE_REJECT: 310 self.log(("already-computing", *_log_msg)) 311 # Victim was waiting, has given up task, enact steal 312 elif state in _WORKER_STATE_CONFIRM: 313 self.remove_key_from_stealable(ts) 314 ts.processing_on = thief 315 duration = victim.processing.pop(ts) 316 victim.occupancy -= duration 317 self.scheduler.total_occupancy -= duration 318 if not victim.processing: 319 self.scheduler.total_occupancy -= victim.occupancy 320 victim.occupancy = 0 321 thief.processing[ts] = d["thief_duration"] 322 thief.occupancy += d["thief_duration"] 323 self.scheduler.total_occupancy += d["thief_duration"] 324 self.put_key_in_stealable(ts) 325 326 self.scheduler.send_task_to_worker(thief.address, ts) 327 self.log(("confirm", *_log_msg)) 328 else: 329 raise ValueError(f"Unexpected task state: {state}") 330 except Exception as e: 331 logger.exception(e) 332 if LOG_PDB: 333 import pdb 334 335 pdb.set_trace() 336 raise 337 finally: 338 self.scheduler.check_idle_saturated(thief) 339 self.scheduler.check_idle_saturated(victim) 340 341 def balance(self): 342 s = self.scheduler 343 344 def combined_occupancy(ws): 345 return ws.occupancy + self.in_flight_occupancy[ws] 346 347 def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): 348 occ_idl = combined_occupancy(idl) 349 occ_sat = combined_occupancy(sat) 350 351 if occ_idl + cost_multiplier * duration <= occ_sat - duration / 2: 352 self.move_task_request(ts, sat, idl) 353 log.append( 354 ( 355 start, 356 level, 357 ts.key, 358 duration, 359 sat.address, 360 occ_sat, 361 idl.address, 362 occ_idl, 363 ) 364 ) 365 s.check_idle_saturated(sat, occ=occ_sat) 366 s.check_idle_saturated(idl, occ=occ_idl) 367 368 with log_errors(): 369 i = 0 370 idle = s.idle.values() 371 saturated = s.saturated 372 if not idle or len(idle) == len(s.workers): 373 return 374 375 log = [] 376 start = time() 377 378 if not s.saturated: 379 saturated = topk(10, s.workers.values(), key=combined_occupancy) 380 saturated = [ 381 ws 382 for ws in saturated 383 if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads 384 ] 385 elif len(s.saturated) < 20: 386 saturated = sorted(saturated, key=combined_occupancy, reverse=True) 387 if len(idle) < 20: 388 idle = sorted(idle, key=combined_occupancy) 389 390 for level, cost_multiplier in enumerate(self.cost_multipliers): 391 if not idle: 392 break 393 for sat in list(saturated): 394 stealable = self.stealable[sat.address][level] 395 if not stealable or not idle: 396 continue 397 398 for ts in list(stealable): 399 if ts not in self.key_stealable or ts.processing_on is not sat: 400 stealable.discard(ts) 401 continue 402 i += 1 403 if not idle: 404 break 405 406 if _has_restrictions(ts): 407 thieves = [ws for ws in idle if _can_steal(ws, ts, sat)] 408 else: 409 thieves = idle 410 if not thieves: 411 break 412 thief = thieves[i % len(thieves)] 413 414 duration = sat.processing.get(ts) 415 if duration is None: 416 stealable.discard(ts) 417 continue 418 419 maybe_move_task( 420 level, ts, sat, thief, duration, cost_multiplier 421 ) 422 423 if self.cost_multipliers[level] < 20: # don't steal from public at cost 424 stealable = self.stealable_all[level] 425 for ts in list(stealable): 426 if not idle: 427 break 428 if ts not in self.key_stealable: 429 stealable.discard(ts) 430 continue 431 432 sat = ts.processing_on 433 if sat is None: 434 stealable.discard(ts) 435 continue 436 if combined_occupancy(sat) < 0.2: 437 continue 438 if len(sat.processing) <= sat.nthreads: 439 continue 440 441 i += 1 442 if _has_restrictions(ts): 443 thieves = [ws for ws in idle if _can_steal(ws, ts, sat)] 444 else: 445 thieves = idle 446 if not thieves: 447 continue 448 thief = thieves[i % len(thieves)] 449 duration = sat.processing[ts] 450 451 maybe_move_task( 452 level, ts, sat, thief, duration, cost_multiplier 453 ) 454 455 if log: 456 self.log(log) 457 self.count += 1 458 stop = time() 459 if s.digests: 460 s.digests["steal-duration"].add(stop - start) 461 462 def restart(self, scheduler): 463 for stealable in self.stealable.values(): 464 for s in stealable: 465 s.clear() 466 467 for s in self.stealable_all: 468 s.clear() 469 self.key_stealable.clear() 470 471 def story(self, *keys): 472 keys = {key.key if not isinstance(key, str) else key for key in keys} 473 out = [] 474 for _, L in self.scheduler.get_events(topic="stealing"): 475 if not isinstance(L, list): 476 L = [L] 477 for t in L: 478 if any(x in keys for x in t): 479 out.append(t) 480 return out 481 482 483def _has_restrictions(ts): 484 """Determine whether the given task has restrictions and whether these 485 restrictions are strict. 486 """ 487 return not ts.loose_restrictions and ( 488 ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions 489 ) 490 491 492def _can_steal(thief, ts, victim): 493 """Determine whether worker ``thief`` can steal task ``ts`` from worker 494 ``victim``. 495 496 Assumes that `ts` has some restrictions. 497 """ 498 if ( 499 ts.host_restrictions 500 and get_address_host(thief.address) not in ts.host_restrictions 501 ): 502 return False 503 elif ts.worker_restrictions and thief.address not in ts.worker_restrictions: 504 return False 505 506 if victim.resources is None: 507 return True 508 509 for resource, value in victim.resources.items(): 510 try: 511 supplied = thief.resources[resource] 512 except KeyError: 513 return False 514 else: 515 if supplied < value: 516 return False 517 return True 518 519 520fast_tasks = {"split-shuffle"} 521