1from __future__ import annotations
2
3import logging
4from collections import defaultdict, deque
5from math import log2
6from time import time
7from typing import Any, Container
8
9from tlz import topk
10from tornado.ioloop import PeriodicCallback
11
12import dask
13from dask.utils import parse_timedelta
14
15from .comm.addressing import get_address_host
16from .core import CommClosedError
17from .diagnostics.plugin import SchedulerPlugin
18from .utils import log_errors, recursive_to_dict
19
20# Stealing requires multiple network bounces and if successful also task
21# submission which may include code serialization. Therefore, be very
22# conservative in the latency estimation to suppress too aggressive stealing
23# of small tasks
24LATENCY = 0.1
25
26logger = logging.getLogger(__name__)
27
28
29LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
30
31_WORKER_STATE_CONFIRM = {
32    "ready",
33    "constrained",
34    "waiting",
35}
36
37_WORKER_STATE_REJECT = {
38    "memory",
39    "executing",
40    "long-running",
41    "cancelled",
42    "resumed",
43}
44_WORKER_STATE_UNDEFINED = {
45    "released",
46    None,
47}
48
49
50class WorkStealing(SchedulerPlugin):
51    def __init__(self, scheduler):
52        self.scheduler = scheduler
53        # { level: { task states } }
54        self.stealable_all = [set() for i in range(15)]
55        # { worker: { level: { task states } } }
56        self.stealable = dict()
57        # { task state: (worker, level) }
58        self.key_stealable = dict()
59
60        self.cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)]
61        self.cost_multipliers[0] = 1
62
63        for worker in scheduler.workers:
64            self.add_worker(worker=worker)
65
66        callback_time = parse_timedelta(
67            dask.config.get("distributed.scheduler.work-stealing-interval"),
68            default="ms",
69        )
70        # `callback_time` is in milliseconds
71        pc = PeriodicCallback(callback=self.balance, callback_time=callback_time * 1000)
72        self._pc = pc
73        self.scheduler.periodic_callbacks["stealing"] = pc
74        self.scheduler.add_plugin(self)
75        self.scheduler.extensions["stealing"] = self
76        self.scheduler.events["stealing"] = deque(maxlen=100000)
77        self.count = 0
78        # { task state: <stealing info dict> }
79        self.in_flight = dict()
80        # { worker state: occupancy }
81        self.in_flight_occupancy = defaultdict(lambda: 0)
82
83        self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
84
85    def _to_dict(self, *, exclude: Container[str] = None) -> dict[str, Any]:
86        """
87        A very verbose dictionary representation for debugging purposes.
88        Not type stable and not inteded for roundtrips.
89
90        Parameters
91        ----------
92        comm:
93        exclude:
94            A list of attributes which must not be present in the output.
95
96        See also
97        --------
98        Client.dump_cluster_state
99        """
100        return recursive_to_dict(
101            {
102                "stealable_all": self.stealable_all,
103                "stealable": self.stealable,
104                "key_stealable": self.key_stealable,
105                "in_flight": self.in_flight,
106                "in_flight_occupancy": self.in_flight_occupancy,
107            },
108            exclude=exclude,
109        )
110
111    def log(self, msg):
112        return self.scheduler.log_event("stealing", msg)
113
114    def add_worker(self, scheduler=None, worker=None):
115        self.stealable[worker] = [set() for i in range(15)]
116
117    def remove_worker(self, scheduler=None, worker=None):
118        del self.stealable[worker]
119
120    def teardown(self):
121        self._pc.stop()
122
123    def transition(
124        self, key, start, finish, compute_start=None, compute_stop=None, *args, **kwargs
125    ):
126        if finish == "processing":
127            ts = self.scheduler.tasks[key]
128            self.put_key_in_stealable(ts)
129        elif start == "processing":
130            ts = self.scheduler.tasks[key]
131            self.remove_key_from_stealable(ts)
132            d = self.in_flight.pop(ts, None)
133            if d:
134                thief = d["thief"]
135                victim = d["victim"]
136                self.in_flight_occupancy[thief] -= d["thief_duration"]
137                self.in_flight_occupancy[victim] += d["victim_duration"]
138                if not self.in_flight:
139                    self.in_flight_occupancy.clear()
140
141    def recalculate_cost(self, ts):
142        if ts not in self.in_flight:
143            self.remove_key_from_stealable(ts)
144            self.put_key_in_stealable(ts)
145
146    def put_key_in_stealable(self, ts):
147        cost_multiplier, level = self.steal_time_ratio(ts)
148        if cost_multiplier is not None:
149            ws = ts.processing_on
150            worker = ws.address
151            self.stealable_all[level].add(ts)
152            self.stealable[worker][level].add(ts)
153            self.key_stealable[ts] = (worker, level)
154
155    def remove_key_from_stealable(self, ts):
156        result = self.key_stealable.pop(ts, None)
157        if result is None:
158            return
159
160        worker, level = result
161        try:
162            self.stealable[worker][level].remove(ts)
163        except KeyError:
164            pass
165        try:
166            self.stealable_all[level].remove(ts)
167        except KeyError:
168            pass
169
170    def steal_time_ratio(self, ts):
171        """The compute to communication time ratio of a key
172
173        Returns
174        -------
175        cost_multiplier: The increased cost from moving this task as a factor.
176        For example a result of zero implies a task without dependencies.
177        level: The location within a stealable list to place this value
178        """
179        split = ts.prefix.name
180        if split in fast_tasks or split in self.scheduler.unknown_durations:
181            return None, None
182
183        if not ts.dependencies:  # no dependencies fast path
184            return 0, 0
185
186        ws = ts.processing_on
187        compute_time = ws.processing[ts]
188        if compute_time < 0.005:  # 5ms, just give up
189            return None, None
190
191        nbytes = ts.get_nbytes_deps()
192        transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
193        cost_multiplier = transfer_time / compute_time
194        if cost_multiplier > 100:
195            return None, None
196
197        level = int(round(log2(cost_multiplier) + 6))
198        if level < 1:
199            level = 1
200
201        return cost_multiplier, level
202
203    def move_task_request(self, ts, victim, thief) -> str:
204        try:
205            if ts in self.in_flight:
206                return "in-flight"
207            stimulus_id = f"steal-{time()}"
208
209            key = ts.key
210            self.remove_key_from_stealable(ts)
211            logger.debug(
212                "Request move %s, %s: %2f -> %s: %2f",
213                key,
214                victim,
215                victim.occupancy,
216                thief,
217                thief.occupancy,
218            )
219
220            victim_duration = victim.processing[ts]
221
222            thief_duration = self.scheduler.get_task_duration(
223                ts
224            ) + self.scheduler.get_comm_cost(ts, thief)
225
226            self.scheduler.stream_comms[victim.address].send(
227                {"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
228            )
229            self.in_flight[ts] = {
230                "victim": victim,  # guaranteed to be processing_on
231                "thief": thief,
232                "victim_duration": victim_duration,
233                "thief_duration": thief_duration,
234                "stimulus_id": stimulus_id,
235            }
236
237            self.in_flight_occupancy[victim] -= victim_duration
238            self.in_flight_occupancy[thief] += thief_duration
239            return stimulus_id
240        except CommClosedError:
241            logger.info("Worker comm %r closed while stealing: %r", victim, ts)
242            return "comm-closed"
243        except Exception as e:
244            logger.exception(e)
245            if LOG_PDB:
246                import pdb
247
248                pdb.set_trace()
249            raise
250
251    async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
252        try:
253            ts = self.scheduler.tasks[key]
254        except KeyError:
255            logger.debug("Key released between request and confirm: %s", key)
256            return
257        try:
258            d = self.in_flight.pop(ts)
259            if d["stimulus_id"] != stimulus_id:
260                self.log(("stale-response", key, state, worker, stimulus_id))
261                self.in_flight[ts] = d
262                return
263        except KeyError:
264            self.log(("already-aborted", key, state, stimulus_id))
265            return
266
267        thief = d["thief"]
268        victim = d["victim"]
269
270        logger.debug("Confirm move %s, %s -> %s.  State: %s", key, victim, thief, state)
271
272        self.in_flight_occupancy[thief] -= d["thief_duration"]
273        self.in_flight_occupancy[victim] += d["victim_duration"]
274
275        if not self.in_flight:
276            self.in_flight_occupancy.clear()
277
278        if self.scheduler.validate:
279            assert ts.processing_on == victim
280
281        try:
282            _log_msg = [key, state, victim.address, thief.address, stimulus_id]
283
284            if ts.state != "processing":
285                self.log(("not-processing", *_log_msg))
286                old_thief = thief.occupancy
287                new_thief = sum(thief.processing.values())
288                old_victim = victim.occupancy
289                new_victim = sum(victim.processing.values())
290                thief.occupancy = new_thief
291                victim.occupancy = new_victim
292                self.scheduler.total_occupancy += (
293                    new_thief - old_thief + new_victim - old_victim
294                )
295            elif (
296                state in _WORKER_STATE_UNDEFINED
297                or state in _WORKER_STATE_CONFIRM
298                and thief.address not in self.scheduler.workers
299            ):
300                self.log(
301                    (
302                        "reschedule",
303                        thief.address not in self.scheduler.workers,
304                        *_log_msg,
305                    )
306                )
307                self.scheduler.reschedule(key)
308            # Victim had already started execution
309            elif state in _WORKER_STATE_REJECT:
310                self.log(("already-computing", *_log_msg))
311            # Victim was waiting, has given up task, enact steal
312            elif state in _WORKER_STATE_CONFIRM:
313                self.remove_key_from_stealable(ts)
314                ts.processing_on = thief
315                duration = victim.processing.pop(ts)
316                victim.occupancy -= duration
317                self.scheduler.total_occupancy -= duration
318                if not victim.processing:
319                    self.scheduler.total_occupancy -= victim.occupancy
320                    victim.occupancy = 0
321                thief.processing[ts] = d["thief_duration"]
322                thief.occupancy += d["thief_duration"]
323                self.scheduler.total_occupancy += d["thief_duration"]
324                self.put_key_in_stealable(ts)
325
326                self.scheduler.send_task_to_worker(thief.address, ts)
327                self.log(("confirm", *_log_msg))
328            else:
329                raise ValueError(f"Unexpected task state: {state}")
330        except Exception as e:
331            logger.exception(e)
332            if LOG_PDB:
333                import pdb
334
335                pdb.set_trace()
336            raise
337        finally:
338            self.scheduler.check_idle_saturated(thief)
339            self.scheduler.check_idle_saturated(victim)
340
341    def balance(self):
342        s = self.scheduler
343
344        def combined_occupancy(ws):
345            return ws.occupancy + self.in_flight_occupancy[ws]
346
347        def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
348            occ_idl = combined_occupancy(idl)
349            occ_sat = combined_occupancy(sat)
350
351            if occ_idl + cost_multiplier * duration <= occ_sat - duration / 2:
352                self.move_task_request(ts, sat, idl)
353                log.append(
354                    (
355                        start,
356                        level,
357                        ts.key,
358                        duration,
359                        sat.address,
360                        occ_sat,
361                        idl.address,
362                        occ_idl,
363                    )
364                )
365                s.check_idle_saturated(sat, occ=occ_sat)
366                s.check_idle_saturated(idl, occ=occ_idl)
367
368        with log_errors():
369            i = 0
370            idle = s.idle.values()
371            saturated = s.saturated
372            if not idle or len(idle) == len(s.workers):
373                return
374
375            log = []
376            start = time()
377
378            if not s.saturated:
379                saturated = topk(10, s.workers.values(), key=combined_occupancy)
380                saturated = [
381                    ws
382                    for ws in saturated
383                    if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads
384                ]
385            elif len(s.saturated) < 20:
386                saturated = sorted(saturated, key=combined_occupancy, reverse=True)
387            if len(idle) < 20:
388                idle = sorted(idle, key=combined_occupancy)
389
390            for level, cost_multiplier in enumerate(self.cost_multipliers):
391                if not idle:
392                    break
393                for sat in list(saturated):
394                    stealable = self.stealable[sat.address][level]
395                    if not stealable or not idle:
396                        continue
397
398                    for ts in list(stealable):
399                        if ts not in self.key_stealable or ts.processing_on is not sat:
400                            stealable.discard(ts)
401                            continue
402                        i += 1
403                        if not idle:
404                            break
405
406                        if _has_restrictions(ts):
407                            thieves = [ws for ws in idle if _can_steal(ws, ts, sat)]
408                        else:
409                            thieves = idle
410                        if not thieves:
411                            break
412                        thief = thieves[i % len(thieves)]
413
414                        duration = sat.processing.get(ts)
415                        if duration is None:
416                            stealable.discard(ts)
417                            continue
418
419                        maybe_move_task(
420                            level, ts, sat, thief, duration, cost_multiplier
421                        )
422
423                if self.cost_multipliers[level] < 20:  # don't steal from public at cost
424                    stealable = self.stealable_all[level]
425                    for ts in list(stealable):
426                        if not idle:
427                            break
428                        if ts not in self.key_stealable:
429                            stealable.discard(ts)
430                            continue
431
432                        sat = ts.processing_on
433                        if sat is None:
434                            stealable.discard(ts)
435                            continue
436                        if combined_occupancy(sat) < 0.2:
437                            continue
438                        if len(sat.processing) <= sat.nthreads:
439                            continue
440
441                        i += 1
442                        if _has_restrictions(ts):
443                            thieves = [ws for ws in idle if _can_steal(ws, ts, sat)]
444                        else:
445                            thieves = idle
446                        if not thieves:
447                            continue
448                        thief = thieves[i % len(thieves)]
449                        duration = sat.processing[ts]
450
451                        maybe_move_task(
452                            level, ts, sat, thief, duration, cost_multiplier
453                        )
454
455            if log:
456                self.log(log)
457                self.count += 1
458            stop = time()
459            if s.digests:
460                s.digests["steal-duration"].add(stop - start)
461
462    def restart(self, scheduler):
463        for stealable in self.stealable.values():
464            for s in stealable:
465                s.clear()
466
467        for s in self.stealable_all:
468            s.clear()
469        self.key_stealable.clear()
470
471    def story(self, *keys):
472        keys = {key.key if not isinstance(key, str) else key for key in keys}
473        out = []
474        for _, L in self.scheduler.get_events(topic="stealing"):
475            if not isinstance(L, list):
476                L = [L]
477            for t in L:
478                if any(x in keys for x in t):
479                    out.append(t)
480        return out
481
482
483def _has_restrictions(ts):
484    """Determine whether the given task has restrictions and whether these
485    restrictions are strict.
486    """
487    return not ts.loose_restrictions and (
488        ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
489    )
490
491
492def _can_steal(thief, ts, victim):
493    """Determine whether worker ``thief`` can steal task ``ts`` from worker
494    ``victim``.
495
496    Assumes that `ts` has some restrictions.
497    """
498    if (
499        ts.host_restrictions
500        and get_address_host(thief.address) not in ts.host_restrictions
501    ):
502        return False
503    elif ts.worker_restrictions and thief.address not in ts.worker_restrictions:
504        return False
505
506    if victim.resources is None:
507        return True
508
509    for resource, value in victim.resources.items():
510        try:
511            supplied = thief.resources[resource]
512        except KeyError:
513            return False
514        else:
515            if supplied < value:
516                return False
517    return True
518
519
# Task prefixes that ``WorkStealing.steal_time_ratio`` never considers for
# stealing.
fast_tasks = {"split-shuffle"}
521