1import logging
2import math
3import os
4
5from bokeh.core.properties import without_property_validation
6from bokeh.layouts import column, row
7from bokeh.models import (
8    BoxZoomTool,
9    ColumnDataSource,
10    DataRange1d,
11    HoverTool,
12    NumeralTickFormatter,
13    PanTool,
14    ResetTool,
15    Select,
16    WheelZoomTool,
17)
18from bokeh.models.widgets import DataTable, TableColumn
19from bokeh.palettes import RdBu
20from bokeh.plotting import figure
21from bokeh.themes import Theme
22from tlz import merge, partition_all
23
24from dask.utils import format_bytes, format_time
25
26from distributed.dashboard.components import add_periodic_callback
27from distributed.dashboard.components.shared import (
28    DashboardComponent,
29    ProfileServer,
30    ProfileTimePlot,
31    SystemMonitor,
32)
33from distributed.dashboard.utils import transpose, update
34from distributed.diagnostics.progress_stream import color_of
35from distributed.metrics import time
36from distributed.utils import key_split, log_errors
37
# Module-level logger; CrossFilter.process_msg logs exceptions through it.
logger = logging.getLogger(__name__)

from jinja2 import Environment, FileSystemLoader

# Jinja2 environment whose templates are loaded from the "http/templates"
# directory two levels above this file; standard_doc() fetches page templates
# from it by name.
env = Environment(
    loader=FileSystemLoader(
        os.path.join(os.path.dirname(__file__), "..", "..", "http", "templates")
    )
)

# Bokeh theme read from "theme.yaml" in the parent directory; standard_doc()
# assigns it to every document it prepares.
BOKEH_THEME = Theme(
    filename=os.path.join(os.path.dirname(__file__), "..", "theme.yaml")
)

# Page names for the dashboard navigation.
# NOTE(review): not referenced by any code visible in this file -- presumably
# consumed by the code that registers these pages; confirm against callers.
template_variables = {"pages": ["status", "system", "profile", "crossfilter"]}
53
54
def standard_doc(title, active_page, *, template="simple.html"):
    """Build a decorator that applies the shared page setup to a doc function.

    The decorated function keeps its ``(arg, extra, doc)`` call signature.
    Before it runs, the wrapper sets the document's title, template, template
    variables and theme. Errors raised during setup or by the wrapped function
    pass through ``log_errors``.

    Parameters
    ----------
    title : str
        Assigned to ``doc.title``.
    active_page : str or None
        Stored as ``doc.template_variables["active_page"]`` unless ``None``.
    template : str, optional
        Name of the template fetched from the module-level ``env``.

    Returns
    -------
    callable
        A decorator for functions taking ``(arg, extra, doc)``.
    """
    # Imported locally so this function is self-contained.
    from functools import wraps

    def decorator(f):
        # Keep the wrapped function's __name__/__doc__ so that logs and
        # introspection identify the real doc function, not "wrapper".
        @wraps(f)
        def wrapper(arg, extra, doc):
            with log_errors():
                doc.title = title
                doc.template = env.get_template(template)
                if active_page is not None:
                    doc.template_variables["active_page"] = active_page
                doc.template_variables.update(extra)
                doc.theme = BOKEH_THEME
                return f(arg, extra, doc)

        return wrapper

    return decorator
70
71
class StateTable(DashboardComponent):
    """One-row table summarising the worker's task and connection counts"""

    def __init__(self, worker):
        self.worker = worker

        headings = (
            "Stored",
            "Executing",
            "Ready",
            "Waiting",
            "Connections",
            "Serving",
        )
        # Start with an empty column per heading; update() fills in the row.
        self.source = ColumnDataSource({heading: [] for heading in headings})

        self.root = DataTable(
            source=self.source,
            columns=[
                TableColumn(field=heading, title=heading) for heading in headings
            ],
            height=70,
        )

    @without_property_validation
    def update(self):
        """Refresh the single table row from the worker's current state."""
        with log_errors():
            worker = self.worker
            row_values = {
                "Stored": len(worker.data),
                "Executing": "%d / %d" % (worker.executing_count, worker.nthreads),
                "Ready": len(worker.ready),
                "Waiting": worker.waiting_for_data_count,
                "Connections": len(worker.in_flight_workers),
                "Serving": len(worker._comms),
            }
            # Each column holds exactly one value, hence the one-element lists.
            update(
                self.source, {heading: [value] for heading, value in row_values.items()}
            )
101
102
class CommunicatingStream(DashboardComponent):
    """Timeline of individual peer transfers, one horizontal lane per peer.

    Incoming transfers are drawn as red rectangles and outgoing transfers as
    blue ones. Each call to ``update`` picks up the entries that were added
    to the worker's transfer logs since the previous call.
    """

    def __init__(self, worker, height=300, **kwargs):
        """
        Parameters
        ----------
        worker
            Object exposing ``incoming_transfer_log`` / ``outgoing_transfer_log``
            and the matching ``incoming_count`` / ``outgoing_count`` counters.
        height : int, optional
            Height passed to the bokeh figure.
        **kwargs
            Forwarded unchanged to ``bokeh.plotting.figure``.
        """
        with log_errors():
            self.worker = worker
            names = [
                "start",
                "stop",
                "middle",
                "duration",
                "who",
                "y",
                "hover",
                "alpha",
                "bandwidth",
                "total",
            ]

            self.incoming = ColumnDataSource({name: [] for name in names})
            self.outgoing = ColumnDataSource({name: [] for name in names})

            x_range = DataRange1d(range_padding=0)
            y_range = DataRange1d(range_padding=0)

            fig = figure(
                title="Peer Communications",
                x_axis_type="datetime",
                x_range=x_range,
                y_range=y_range,
                height=height,
                tools="",
                **kwargs,
            )

            # Same glyph for both directions; only the source and colour differ.
            for source, color in ((self.incoming, "red"), (self.outgoing, "blue")):
                fig.rect(
                    source=source,
                    x="middle",
                    y="y",
                    width="duration",
                    height=0.9,
                    color=color,
                    alpha="alpha",
                )

            hover = HoverTool(point_policy="follow_mouse", tooltips="""@hover""")
            fig.add_tools(
                hover,
                ResetTool(),
                PanTool(dimensions="width"),
                WheelZoomTool(dimensions="width"),
            )

            self.root = fig

            # Counter values seen at the previous update(); the difference to
            # the worker's current counters is the number of new log entries.
            self.last_incoming = 0
            self.last_outgoing = 0
            # Maps peer identifier -> lane index on the y axis.
            self.who = dict()

    @without_property_validation
    def update(self):
        """Push transfers logged since the last call into the data sources."""
        with log_errors():
            worker = self.worker

            log = worker.outgoing_transfer_log
            # Clamp to the log length so a large backlog cannot index past
            # the oldest entry still held in the log.
            n = min(worker.outgoing_count - self.last_outgoing, len(log))
            outgoing = [log[-i].copy() for i in range(1, n + 1)]
            self.last_outgoing = worker.outgoing_count

            log = worker.incoming_transfer_log
            n = min(worker.incoming_count - self.last_incoming, len(log))
            incoming = [log[-i].copy() for i in range(1, n + 1)]
            self.last_incoming = worker.incoming_count

            for msgs, source in (
                (incoming, self.incoming),
                (outgoing, self.outgoing),
            ):
                for msg in msgs:
                    # Drop fields that have no column in the data source.
                    msg.pop("compressed", None)
                    del msg["keys"]

                    # A zero duration falls back to 0.5 so the division is safe.
                    bandwidth = msg["total"] / (msg["duration"] or 0.5)
                    # Opacity scales with bandwidth, clipped to [0.3, 1].
                    msg["alpha"] = max(min(bandwidth / 500e6, 1), 0.3)
                    # A peer gets the next free lane the first time it is seen.
                    msg["y"] = self.who.setdefault(msg["who"], len(self.who))

                    # Reuse the guarded bandwidth: dividing by the raw duration
                    # here would raise ZeroDivisionError for a zero duration.
                    msg["hover"] = "{} / {} = {}/s".format(
                        format_bytes(msg["total"]),
                        format_time(msg["duration"]),
                        format_bytes(bandwidth),
                    )

                    # Scale time fields by 1000 for the datetime axis
                    # (presumably seconds -> milliseconds; confirm units).
                    for k in ["middle", "duration", "start", "stop"]:
                        msg[k] = msg[k] * 1000

                if msgs:
                    msgs = transpose(msgs)
                    # If the new batch starts more than 10000 (scaled) time
                    # units after the last plotted stop, replace the columns
                    # instead of appending to them.
                    if (
                        len(source.data["stop"])
                        and min(msgs["start"]) > source.data["stop"][-1] + 10000
                    ):
                        source.data.update(msgs)
                    else:
                        source.stream(msgs, rollover=10000)
219
220
class CommunicatingTimeSeries(DashboardComponent):
    """Line chart of the worker's connection counts over time.

    The "in" series (red) tracks ``len(worker.in_flight_workers)`` and the
    "out" series (blue) tracks ``len(worker._comms)``.
    """

    def __init__(self, worker, **kwargs):
        self.worker = worker
        self.source = ColumnDataSource({"x": [], "in": [], "out": []})

        # The x range follows the most recent 20000 x-units of data.
        following_range = DataRange1d(
            follow="end", follow_interval=20000, range_padding=0
        )
        upper_bound = worker.total_out_connections + 0.5

        plot = figure(
            title="Communication History",
            x_axis_type="datetime",
            y_range=[-0.1, upper_bound],
            height=150,
            tools="",
            x_range=following_range,
            **kwargs,
        )
        for series, colour in (("in", "red"), ("out", "blue")):
            plot.line(source=self.source, x="x", y=series, color=colour)

        horizontal_tools = (
            ResetTool(),
            PanTool(dimensions="width"),
            WheelZoomTool(dimensions="width"),
        )
        plot.add_tools(*horizontal_tools)

        self.root = plot

    @without_property_validation
    def update(self):
        """Append one sample of the current connection counts."""
        with log_errors():
            sample = {
                "x": [time() * 1000],
                "out": [len(self.worker._comms)],
                "in": [len(self.worker.in_flight_workers)],
            }
            # Keep at most the latest 10000 samples.
            self.source.stream(sample, rollover=10000)
257
258
class ExecutingTimeSeries(DashboardComponent):
    """Line chart of ``worker.executing_count`` sampled over time."""

    def __init__(self, worker, **kwargs):
        self.worker = worker
        self.source = ColumnDataSource({"x": [], "y": []})

        # The x range follows the most recent 20000 x-units of data.
        following_range = DataRange1d(
            follow="end", follow_interval=20000, range_padding=0
        )
        upper_bound = worker.nthreads + 0.1

        plot = figure(
            title="Executing History",
            x_axis_type="datetime",
            y_range=[-0.1, upper_bound],
            height=150,
            tools="",
            x_range=following_range,
            **kwargs,
        )
        plot.line(source=self.source, x="x", y="y")

        horizontal_tools = (
            ResetTool(),
            PanTool(dimensions="width"),
            WheelZoomTool(dimensions="width"),
        )
        plot.add_tools(*horizontal_tools)

        self.root = plot

    @without_property_validation
    def update(self):
        """Append one sample of the current executing count."""
        with log_errors():
            sample = {"x": [time() * 1000], "y": [self.worker.executing_count]}
            # Keep at most the latest 1000 samples.
            self.source.stream(sample, rollover=1000)
289
290
class CrossFilter(DashboardComponent):
    """Scatter plot of transfers with selectable x, y and colour dimensions.

    Three ``Select`` widgets choose which columns of the shared data source
    drive the x axis, the y axis and the point colour. Changing any of them
    rebuilds the figure; ``update`` feeds new transfer-log entries into the
    data source.
    """

    def __init__(self, worker, **kwargs):
        """
        Parameters
        ----------
        worker
            Object exposing the transfer logs, their counters and ``types``.
        **kwargs
            Forwarded to ``bokeh.plotting.figure`` every time the figure is
            (re)built. If ``sizing_mode`` is present it is also applied to
            the widget column and the enclosing row.
        """
        with log_errors():
            self.worker = worker

            quantities = ["nbytes", "duration", "bandwidth", "count", "start", "stop"]
            colors = ["inout-color", "type-color", "key-color"]

            # Two placeholder rows so the figure has something to draw before
            # real transfers arrive. Every selectable column must exist here,
            # including "key-color", so that choosing it before the first
            # update does not reference a missing column and so that streamed
            # rows (which carry "key-color") match the source's columns.
            self.source = ColumnDataSource(
                {
                    "nbytes": [1, 2],
                    "duration": [0.01, 0.02],
                    "bandwidth": [0.01, 0.02],
                    "count": [1, 2],
                    "type": ["int", "str"],
                    "inout-color": ["blue", "red"],
                    "type-color": ["blue", "red"],
                    "key": ["add", "inc"],
                    "key-color": ["blue", "red"],
                    "start": [1, 2],
                    "stop": [1, 2],
                }
            )

            self.x = Select(title="X-Axis", value="nbytes", options=quantities)
            self.x.on_change("value", self.update_figure)

            self.y = Select(title="Y-Axis", value="bandwidth", options=quantities)
            self.y.on_change("value", self.update_figure)

            self.color = Select(
                title="Color", value="inout-color", options=["black"] + colors
            )
            self.color.on_change("value", self.update_figure)

            if "sizing_mode" in kwargs:
                kw = {"sizing_mode": kwargs["sizing_mode"]}
            else:
                kw = {}

            self.control = column([self.x, self.y, self.color], width=200, **kw)

            # Counter values seen at the previous update().
            self.last_outgoing = 0
            self.last_incoming = 0
            self.kwargs = kwargs

            self.layout = row(self.control, self.create_figure(**self.kwargs), **kw)

            self.root = self.layout

    @without_property_validation
    def update(self):
        """Add transfers logged since the last call to the scatter source."""
        with log_errors():
            outgoing = self.worker.outgoing_transfer_log
            # At most 1000 new entries per update, and never more than the
            # log currently holds.
            n = min(self.worker.outgoing_count - self.last_outgoing, 1000, len(outgoing))
            # range(1, n + 1) covers all n new entries; stopping at n would
            # drop the oldest one and show nothing when only one is new.
            outgoing = [outgoing[-i].copy() for i in range(1, n + 1)]
            self.last_outgoing = self.worker.outgoing_count

            incoming = self.worker.incoming_transfer_log
            n = min(self.worker.incoming_count - self.last_incoming, 1000, len(incoming))
            incoming = [incoming[-i].copy() for i in range(1, n + 1)]
            self.last_incoming = self.worker.incoming_count

            out = []

            # Entries without keys are skipped; direction decides the colour.
            for msgs, inout_color in ((incoming, "red"), (outgoing, "blue")):
                for msg in msgs:
                    if msg["keys"]:
                        d = self.process_msg(msg)
                        d["inout-color"] = inout_color
                        out.append(d)

            if out:
                out = transpose(out)
                # If the new batch starts more than 10 time units after the
                # last stored stop, hand the whole batch to update() rather
                # than streaming it onto the existing rows.
                if (
                    len(self.source.data["stop"])
                    and min(out["start"]) > self.source.data["stop"][-1] + 10
                ):
                    update(self.source, out)
                else:
                    self.source.stream(out, rollover=1000)

    def create_figure(self, **kwargs):
        """Build a scatter figure from the current widget selections.

        ``kwargs`` are forwarded to ``bokeh.plotting.figure``.
        """
        with log_errors():
            fig = figure(title="", tools="", **kwargs)
            fig.circle(
                source=self.source,
                x=self.x.value,
                y=self.y.value,
                color=self.color.value,
                size=10,
                alpha=0.5,
                hover_alpha=1,
            )
            fig.xaxis.axis_label = self.x.value
            fig.yaxis.axis_label = self.y.value

            fig.add_tools(
                ResetTool(),
                PanTool(),
                WheelZoomTool(),
                BoxZoomTool(),
            )
            return fig

    @without_property_validation
    def update_figure(self, attr, old, new):
        """Widget callback: rebuild the figure and swap it into the layout."""
        with log_errors():
            fig = self.create_figure(**self.kwargs)
            # The figure is the second child of the row; the first is the
            # widget column.
            self.layout.children[1] = fig

    def process_msg(self, msg):
        """Convert one transfer-log entry into a row for the scatter source.

        The key with the largest value in ``msg["keys"]`` is used to look up
        the type name in ``worker.types`` and to derive the key name. Any
        exception is logged and re-raised.
        """
        try:
            status_key = max(msg["keys"], key=lambda x: msg["keys"].get(x, 0))
            typ = self.worker.types.get(status_key, object).__name__
            keyname = key_split(status_key)
            d = {
                "nbytes": msg["total"],
                "duration": msg["duration"],
                "bandwidth": msg["bandwidth"],
                "count": len(msg["keys"]),
                "type": typ,
                "type-color": color_of(typ),
                "key": keyname,
                "key-color": color_of(keyname),
                "start": msg["start"],
                "stop": msg["stop"],
            }
            return d
        except Exception as e:
            logger.exception(e)
            raise
431
432
class Counters(DashboardComponent):
    """Grid of small figures, one per digest and one per counter on a server.

    Digest figures draw one line per interval from a 100-bin histogram of the
    corresponding digest component. Counter figures draw one bar series per
    interval from the corresponding counter component.
    """

    def __init__(self, server, sizing_mode="stretch_both", **kwargs):
        """
        Parameters
        ----------
        server
            Object exposing ``digests`` and ``counters`` mappings.
        sizing_mode : str, optional
            Bokeh sizing mode applied to every figure and layout.
        **kwargs
            Accepted but not used by this class.
        """
        self.server = server
        self.counter_figures = {}
        self.counter_sources = {}
        self.digest_figures = {}
        self.digest_sources = {}
        self.sizing_mode = sizing_mode

        # ``digests`` may be falsy (empty or None), in which case there is
        # nothing to plot for it.
        if self.server.digests:
            for name in self.server.digests:
                self.add_digest_figure(name)
        for name in self.server.counters:
            self.add_counter_figure(name)

        figures = merge(self.digest_figures, self.counter_figures)
        figures = [figures[k] for k in sorted(figures)]

        # Up to five figures go in a single column; more are laid out in
        # rows of two.
        if len(figures) <= 5:
            self.root = column(figures, sizing_mode=sizing_mode)
        else:
            self.root = column(
                *(
                    row(*pair, sizing_mode=sizing_mode)
                    for pair in partition_all(2, figures)
                ),
                sizing_mode=sizing_mode,
            )

    def add_digest_figure(self, name):
        """Create, register and return the figure for digest ``name``."""
        with log_errors():
            n = len(self.server.digests[name].intervals)
            sources = {i: ColumnDataSource({"x": [], "y": []}) for i in range(n)}

            kwargs = {}
            if name.endswith("duration"):
                kwargs["x_axis_type"] = "datetime"

            fig = figure(
                title=name, tools="", height=150, sizing_mode=self.sizing_mode, **kwargs
            )
            fig.yaxis.visible = False
            fig.ygrid.visible = False
            if name.endswith(("bandwidth", "bytes")):
                fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0b")

            # NOTE(review): index -0 is 0, so interval 0 takes the first
            # palette colour and later intervals count back from the end.
            palette = RdBu[max(n, 3)]
            for i in range(n):
                alpha = 0.3 + 0.3 * (n - i) / n
                fig.line(
                    source=sources[i],
                    x="x",
                    y="y",
                    alpha=alpha,
                    color=palette[-i],
                )

            fig.xaxis.major_label_orientation = math.pi / 12
            self.digest_sources[name] = sources
            self.digest_figures[name] = fig
            return fig

    def add_counter_figure(self, name):
        """Create, register and return the figure for counter ``name``."""
        with log_errors():
            n = len(self.server.counters[name].intervals)
            sources = {
                i: ColumnDataSource({"x": [], "y": [], "y-center": [], "counts": []})
                for i in range(n)
            }

            fig = figure(
                title=name,
                tools="",
                height=150,
                sizing_mode=self.sizing_mode,
                x_range=sorted(
                    str(x) for x in self.server.counters[name].components[0]
                ),
            )
            fig.ygrid.visible = False

            # NOTE(review): index -0 is 0, so interval 0 takes the first
            # palette colour and later intervals count back from the end.
            palette = RdBu[max(n, 3)]
            for i in range(n):
                width = 0.5 + 0.4 * i / n
                fig.rect(
                    source=sources[i],
                    x="x",
                    y="y-center",
                    width=width,
                    height="y",
                    alpha=0.3,
                    color=palette[-i],
                )

            # One hover tool and one label rotation per figure; repeating
            # them for every interval only stacked identical tools.
            if n:
                hover = HoverTool(
                    point_policy="follow_mouse", tooltips="""@x : @counts"""
                )
                fig.add_tools(hover)
                fig.xaxis.major_label_orientation = math.pi / 12

            self.counter_sources[name] = sources
            self.counter_figures[name] = fig
            return fig

    @without_property_validation
    def update(self):
        """Refresh every digest and counter figure from the server's state."""
        with log_errors():
            for name, fig in self.digest_figures.items():
                digest = self.server.digests[name]
                for i, component in enumerate(digest.components):
                    if component.size():
                        ys, xs = component.histogram(100)
                        # Drop the first x value so x and y line up
                        # (presumably xs are bin edges, one more than ys).
                        xs = xs[1:]
                        if name.endswith("duration"):
                            xs *= 1000
                        self.digest_sources[name][i].data.update({"x": xs, "y": ys})
                fig.title.text = "%s: %d" % (name, digest.size())

            for name, fig in self.counter_figures.items():
                counter = self.server.counters[name]
                for i, component in enumerate(counter.components):
                    # Skip empty components: there are no categories to plot,
                    # and the axis factors must not be set from values left
                    # over from a previous component or figure.
                    if not component:
                        continue
                    keys = sorted(component)
                    # Rescale each interval's counts relative to the first
                    # interval so the bar heights are comparable.
                    factor = counter.intervals[0] / counter.intervals[i]
                    counts = [component[k] for k in keys]
                    ys = [factor * c for c in counts]
                    y_centers = [y / 2 for y in ys]
                    labels = [str(k) for k in keys]
                    self.counter_sources[name][i].data.update(
                        {"x": labels, "y": ys, "y-center": y_centers, "counts": counts}
                    )
                    fig.x_range.factors = list(labels)
                fig.title.text = "%s: %d" % (name, counter.size())
564
565
@standard_doc("Dask Worker Internal Monitor", active_page="status")
def status_doc(worker, extra, doc):
    """Populate ``doc`` with the worker status page.

    The page stacks the state table, the executing and communication history
    charts and the peer communication stream; all four refresh every 200 ms.
    """
    table = StateTable(worker)
    executing = ExecutingTimeSeries(worker, sizing_mode="scale_width")
    comm_history = CommunicatingTimeSeries(worker, sizing_mode="scale_width")
    comm_stream = CommunicatingStream(worker, sizing_mode="scale_width")

    # All three time-based plots share the executing chart's x range so they
    # pan and zoom together.
    shared_x_range = executing.root.x_range
    for linked in (comm_history, comm_stream):
        linked.root.x_range = shared_x_range

    components = (table, executing, comm_history, comm_stream)
    for component in components:
        add_periodic_callback(doc, component, 200)

    page = column(
        *(component.root for component in components),
        sizing_mode="scale_width",
    )
    doc.add_root(page)
590
591
@standard_doc("Dask Worker Cross-filter", active_page="crossfilter")
def crossfilter_doc(worker, extra, doc):
    """Populate ``doc`` with the state table above the cross-filter plot.

    Both components refresh every 500 ms.
    """
    components = (StateTable(worker), CrossFilter(worker))
    for component in components:
        add_periodic_callback(doc, component, 500)
    doc.add_root(column(*(component.root for component in components)))
599
600
@standard_doc("Dask Worker Monitor", active_page="system")
def systemmonitor_doc(worker, extra, doc):
    """Populate ``doc`` with a ``SystemMonitor`` refreshed every 500 ms."""
    monitor = SystemMonitor(worker, sizing_mode="scale_width")
    add_periodic_callback(doc, monitor, 500)
    doc.add_root(monitor.root)
606
607
@standard_doc("Dask Work Counters", active_page="counters")
def counters_doc(server, extra, doc):
    """Populate ``doc`` with the ``Counters`` grid refreshed every 500 ms."""
    counters_view = Counters(server, sizing_mode="stretch_both")
    add_periodic_callback(doc, counters_view, 500)
    doc.add_root(counters_view.root)
613
614
@standard_doc("Dask Worker Profile", active_page="profile")
def profile_doc(server, extra, doc):
    """Populate ``doc`` with a ``ProfileTimePlot`` and trigger its first update."""
    plot = ProfileTimePlot(server, doc=doc, sizing_mode="stretch_both")
    doc.add_root(plot.root)
    # No periodic callback is registered here; request one update explicitly.
    plot.trigger_update()
620
621
@standard_doc("Dask: Profile of Event Loop", active_page=None)
def profile_server_doc(server, extra, doc):
    """Populate ``doc`` with a ``ProfileServer`` and trigger its first update."""
    plot = ProfileServer(server, doc=doc, sizing_mode="stretch_both")
    doc.add_root(plot.root)
    # No periodic callback is registered here; request one update explicitly.
    plot.trigger_update()
627