1import random
2import warnings
3from bisect import bisect_left
4from itertools import cycle
5from operator import add, itemgetter
6
7from packaging.version import parse as parse_version
8from tlz import accumulate, groupby, pluck, unique
9
10from ..core import istask
11from ..utils import apply, funcname, import_required
12
# Error message handed to ``import_required`` by every plotting helper in this
# module when the optional ``bokeh`` dependency cannot be imported.
_BOKEH_MISSING_MSG = "Diagnostics plots require `bokeh` to be installed"
14
15
def unquote(expr):
    """Turn a collection-building task back into the literal it constructs.

    ``(tuple | list | set, [items...])`` becomes the corresponding collection
    and ``(dict, [[key, value], ...])`` becomes a dict; elements are unquoted
    recursively. Anything else (including malformed or argument-less tasks)
    is returned unchanged.
    """
    # A task needs at least a callable and one argument to be unquotable;
    # indexing ``expr[1]`` on e.g. ``(tuple,)`` would raise IndexError.
    if istask(expr) and len(expr) > 1:
        func, arg = expr[0], expr[1]
        if func in (tuple, list, set):
            return func(map(unquote, arg))
        elif (
            func == dict
            and isinstance(arg, list)
            # An empty pair list is a valid empty dict; only peek at the
            # first element when one exists.
            and (not arg or isinstance(arg[0], list))
        ):
            return dict(map(unquote, arg))
    return expr
27
28
def pprint_task(task, keys, label_size=60):
    """Render a dask graph value as a compact, human-readable label.

    Function calls are shown as ``name(arg, ...)``; references to graph keys
    are shown as ``_`` and any other literal as ``*``. Each nesting level
    shares the remaining character budget among its arguments, and when the
    per-argument budget drops to 5 or fewer the arguments collapse to ``...``.

    Parameters
    ----------
    task:
        Value within dask graph to render as text
    keys: iterable
        List of keys within dask graph
    label_size: int (optional)
        Maximum size of output label, defaults to 60

    Examples
    --------
    >>> from operator import add, mul
    >>> dsk = {'a': 1,
    ...        'b': 2,
    ...        'c': (add, 'a', 'b'),
    ...        'd': (add, (mul, 'a', 'b'), 'c'),
    ...        'e': (sum, ['a', 'b', 5]),
    ...        'f': (add,),
    ...        'g': []}

    >>> pprint_task(dsk['c'], dsk)
    'add(_, _)'
    >>> pprint_task(dsk['d'], dsk)
    'add(mul(_, _), _)'
    >>> pprint_task(dsk['e'], dsk)
    'sum([_, _, *])'
    >>> pprint_task(dsk['f'], dsk)
    'add()'
    >>> pprint_task(dsk['g'], dsk)
    '[]'
    """
    if istask(task):
        fn = task[0]
        if fn is apply:
            # ``(apply, func, args, kwargs)``: show ``func`` with its real args.
            head, tail = funcname(task[1]), ")"
            call_args = unquote(task[2]) if len(task) > 2 else ()
            call_kwargs = unquote(task[3]) if len(task) > 3 else {}
        else:
            if hasattr(fn, "funcs"):
                # Composed function: render as nested calls f(g(h(...))).
                head = "(".join(funcname(g) for g in fn.funcs)
                tail = ")" * len(fn.funcs)
            else:
                head, tail = funcname(fn), ")"
            call_args, call_kwargs = task[1:], {}

        arg_text = ""
        kwarg_text = ""
        if call_args or call_kwargs:
            # Split what is left of the budget evenly across all arguments.
            budget = int(
                (label_size - len(head) - len(tail))
                // (len(call_args) + len(call_kwargs))
            )
            roomy = budget > 5
            if call_args:
                if roomy:
                    arg_text = ", ".join(
                        pprint_task(a, keys, budget) for a in call_args
                    )
                else:
                    arg_text = "..."
            if call_kwargs:
                if roomy:
                    pairs = sorted(call_kwargs.items())
                    kwarg_text = ", " + ", ".join(
                        f"{name}={pprint_task(val, keys, budget)}"
                        for name, val in pairs
                    )
                else:
                    kwarg_text = ", ..."
        return f"{head}({arg_text}{kwarg_text}{tail}"

    if isinstance(task, list):
        count = len(task)
        if count == 0:
            return "[]"
        if count > 3:
            # Show the first three items and elide the rest.
            return pprint_task(task[:3], keys, label_size)[:-1] + ", ...]"
        budget = int((label_size - 2 - 2 * count) // count)
        return "[" + ", ".join(pprint_task(t, keys, budget) for t in task) + "]"

    # Leaf value: a graph key is "_", anything else (or unhashable) is "*".
    try:
        return "_" if task in keys else "*"
    except TypeError:
        return "*"
119
120
def get_colors(palette, funcs):
    """Get a list of colors, one per entry in ``funcs``, drawn from a palette.

    Equal entries in ``funcs`` always receive the same color. The palette
    variant chosen is the smallest one with at least as many colors as there
    are distinct entries (or the largest available); colors are reused
    cyclically if there are still more entries than colors.

    Parameters
    ----------
    palette : string
        Name of the bokeh palette to use, must be a member of
        bokeh.palettes.all_palettes.
    funcs : iterable
        Iterable of function names. May be a one-shot iterator.

    Returns
    -------
    list
        Colors aligned with ``funcs`` (same length, same order).
    """
    palettes = import_required("bokeh.palettes", _BOKEH_MISSING_MSG)

    # ``funcs`` is traversed twice (to build the lookup and to produce the
    # result); materialize it so generators/iterators are not exhausted.
    funcs = list(funcs)
    unique_funcs = sorted(unique(funcs))
    n_funcs = len(unique_funcs)
    palette_lookup = palettes.all_palettes[palette]
    sizes = sorted(palette_lookup)
    index = sizes[min(bisect_left(sizes, n_funcs), len(sizes) - 1)]
    # Some bokeh palettes repeat colors, we want just the unique set
    colors = list(unique(palette_lookup[index]))
    if len(colors) > n_funcs:
        # Consistently shuffle palette - prevents just using low-range
        random.Random(42).shuffle(colors)
    color_lookup = dict(zip(unique_funcs, cycle(colors)))
    return [color_lookup[name] for name in funcs]
147
148
def visualize(
    profilers, filename="profile.html", show=True, save=None, mode=None, **kwargs
):
    """Visualize the results of profiling in a bokeh plot.

    If multiple profilers are passed in, the plots are stacked vertically
    and share the x-range of the first plot.

    Parameters
    ----------
    profilers : profiler or list
        Profiler or list/tuple of profilers.
    filename : string, optional
        Name of the plot output file. The deprecated ``file_path`` keyword
        is still accepted and overrides this value.
    show : boolean, optional
        If True (default), the plot is opened in a browser.
    save : boolean, optional
        If True, the plot is saved to ``filename``. Defaults to True when not
        running in a notebook.
    mode : str, optional
        Mode passed to bokeh.output_file()
    **kwargs
        Other keyword arguments, passed to bokeh.figure. These will override
        all defaults set by visualize.

    Returns
    -------
    The completed bokeh plot object.

    Raises
    ------
    ValueError
        If an empty list of profilers is given.
    """
    bp = import_required("bokeh.plotting", _BOKEH_MISSING_MSG)
    from bokeh.io import state

    if "file_path" in kwargs:
        warnings.warn(
            "The file_path keyword argument is deprecated "
            "and will be removed in a future release. "
            "Please use filename instead.",
            category=FutureWarning,
            stacklevel=2,
        )
        filename = kwargs.pop("file_path")

    if save is None:
        save = not state.curstate().notebook

    if isinstance(profilers, tuple):
        profilers = list(profilers)
    elif not isinstance(profilers, list):
        profilers = [profilers]
    if not profilers:
        raise ValueError("visualize requires at least one profiler")

    figs = [prof._plot(**kwargs) for prof in profilers]
    # Stack the plots
    if len(figs) == 1:
        p = figs[0]
    else:
        top = figs[0]
        for f in figs[1:]:
            f.x_range = top.x_range
            f.title = None
            f.min_border_top = 20
            # ``height`` is the property name used throughout this module;
            # the old ``plot_height`` alias was removed in bokeh 3.
            f.height -= 30
        for f in figs[:-1]:
            f.xaxis.axis_label = None
            f.min_border_bottom = 20
            f.height -= 30
        for f in figs:
            f.min_border_left = 75
            f.min_border_right = 75
        p = bp.gridplot([[f] for f in figs])
    if show:
        bp.show(p)
    if save:
        bp.output_file(filename, mode=mode)
        bp.save(p)
    return p
219
220
def plot_tasks(results, dsk, palette="Viridis", label_size=60, **kwargs):
    """Draw a per-worker timeline of profiled tasks as a bokeh figure.

    Every task becomes a rectangle spanning its start/end time on the row of
    the worker that ran it, colored by its rendered task label; hovering
    shows the task key and label.

    Parameters
    ----------
    results : sequence
        Output of Profiler.results
    dsk : dict
        The dask graph being profiled.
    palette : string, optional
        Name of the bokeh palette to use, must be a member of
        bokeh.palettes.all_palettes.
    label_size: int (optional)
        Maximum size of output labels in plot, defaults to 60
    **kwargs
        Other keyword arguments, passed to bokeh.figure. These will override
        all defaults set by visualize.

    Returns
    -------
    The completed bokeh plot object.
    """
    bp = import_required("bokeh.plotting", _BOKEH_MISSING_MSG)
    from bokeh.models import HoverTool

    figure_opts = {
        "title": "Profile Results",
        "tools": "hover,save,reset,xwheel_zoom,xpan",
        "toolbar_location": "above",
        "width": 800,
        "height": 300,
    }
    # Map the legacy ``plot_width`` / ``plot_height`` names onto the current ones.
    for legacy, modern in (("plot_width", "width"), ("plot_height", "height")):
        if legacy in kwargs:
            kwargs[modern] = kwargs.pop(legacy)
    figure_opts.update(kwargs)

    if not results:
        p = bp.figure(
            y_range=[str(i) for i in range(8)], x_range=[0, 10], **figure_opts
        )
    else:
        keys, tasks, starts, ends, ids = zip(*results)

        # Rank workers by their (descending-sorted) list of task durations.
        durations = {
            worker: [rec.end_time - rec.start_time for rec in records]
            for worker, records in groupby(itemgetter(4), results).items()
        }
        ranked = sorted(durations.items(), key=itemgetter(1), reverse=True)
        worker_row = {worker: row for row, (worker, _) in enumerate(ranked)}

        origin = min(starts)
        finish = max(ends)

        p = bp.figure(
            y_range=[str(i) for i in range(len(worker_row))],
            x_range=[0, finish - origin],
            **figure_opts,
        )

        widths = [stop - begin for begin, stop in zip(starts, ends)]
        labels = [pprint_task(task, dsk, label_size) for task in tasks]
        source = bp.ColumnDataSource(
            data={
                "width": widths,
                "x": [w / 2 + begin - origin for w, begin in zip(widths, starts)],
                "y": [worker_row[w] + 1 for w in ids],
                "function": labels,
                "color": get_colors(palette, labels),
                "key": [str(k) for k in keys],
            }
        )

        p.rect(
            source=source,
            x="x",
            y="y",
            height=1,
            width="width",
            color="color",
            line_color="gray",
        )

    p.grid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.yaxis.axis_label = "Worker ID"
    p.xaxis.axis_label = "Time (s)"

    tip_tool = p.select(HoverTool)
    tip_tool.tooltips = """
    <div>
        <span style="font-size: 14px; font-weight: bold;">Key:</span>&nbsp;
        <span style="font-size: 10px; font-family: Monaco, monospace;">@key</span>
    </div>
    <div>
        <span style="font-size: 14px; font-weight: bold;">Task:</span>&nbsp;
        <span style="font-size: 10px; font-family: Monaco, monospace;">@function</span>
    </div>
    """
    tip_tool.point_policy = "follow_mouse"

    return p
324
325
def plot_resources(results, palette="Viridis", **kwargs):
    """Chart CPU and memory usage over time as a bokeh figure.

    CPU percentage is drawn against the left axis; memory is drawn against a
    secondary "memory" range shown on the right axis.

    Parameters
    ----------
    results : sequence
        Output of ResourceProfiler.results
    palette : string, optional
        Name of the bokeh palette to use, must be a member of
        bokeh.palettes.all_palettes.
    **kwargs
        Other keyword arguments, passed to bokeh.figure. These will override
        all defaults set by plot_resources.

    Returns
    -------
    The completed bokeh plot object.
    """
    bp = import_required("bokeh.plotting", _BOKEH_MISSING_MSG)
    import bokeh
    from bokeh import palettes
    from bokeh.models import LinearAxis, Range1d

    figure_opts = {
        "title": "Profile Results",
        "tools": "save,reset,xwheel_zoom,xpan",
        "toolbar_location": "above",
        "width": 800,
        "height": 300,
    }
    # Map the legacy ``plot_width`` / ``plot_height`` names onto the current ones.
    for legacy, modern in (("plot_width", "width"), ("plot_height", "height")):
        if legacy in kwargs:
            kwargs[modern] = kwargs.pop(legacy)

    # ``label_size`` is only meaningful to plot_tasks / plot_cache; accept
    # and discard it so callers can pass the same kwargs to all of them.
    kwargs.pop("label_size", None)

    figure_opts.update(kwargs)

    if not results:
        t = mem = cpu = []
        p = bp.figure(y_range=(0, 100), x_range=(0, 1), **figure_opts)
    else:
        stamps, mem, cpu = zip(*results)
        t0 = min(stamps)
        t = [stamp - t0 for stamp in stamps]
        p = bp.figure(
            y_range=fix_bounds(0, max(cpu), 100),
            x_range=fix_bounds(0, max(stamps) - t0, 1),
            **figure_opts,
        )

    colors = palettes.all_palettes[palette][6]
    # bokeh 1.4 renamed the ``legend`` glyph keyword to ``legend_label``.
    if parse_version(bokeh.__version__) >= parse_version("1.4"):
        legend_kw = "legend_label"
    else:
        legend_kw = "legend"

    p.line(t, cpu, color=colors[0], line_width=4, **{legend_kw: "% CPU"})
    p.yaxis.axis_label = "% CPU"

    mem_low = min(mem) if mem else 0
    mem_high = max(mem) if mem else 100
    p.extra_y_ranges = {"memory": Range1d(*fix_bounds(mem_low, mem_high, 100))}
    p.line(
        t,
        mem,
        color=colors[2],
        y_range_name="memory",
        line_width=4,
        **{legend_kw: "Memory"},
    )
    p.add_layout(LinearAxis(y_range_name="memory", axis_label="Memory (MB)"), "right")
    p.xaxis.axis_label = "Time (s)"
    return p
413
414
def fix_bounds(start, end, min_span):
    """Return ``(start, end)`` with ``end`` pushed out to span at least ``min_span``.

    ``start`` is never moved; ``end`` is raised to ``start + min_span`` only
    when that value is strictly larger than the given ``end``.
    """
    floor_end = start + min_span
    return start, (floor_end if floor_end > end else end)
418
419
def plot_cache(
    results, dsk, start_time, metric_name, palette="Viridis", label_size=60, **kwargs
):
    """Plot cache size over time, one line per rendered task label, in bokeh.

    Each cached result adds its metric at its cache time and removes it at its
    free time; the running sum per task label is drawn as a line.

    Parameters
    ----------
    results : sequence
        Output of CacheProfiler.results
    dsk : dict
        The dask graph being profiled.
    start_time : float
        Start time of the profile.
    metric_name : string
        Metric used to measure cache size
    palette : string, optional
        Name of the bokeh palette to use, must be a member of
        bokeh.palettes.all_palettes.
    label_size: int (optional)
        Maximum size of output labels in plot, defaults to 60
    **kwargs
        Other keyword arguments, passed to bokeh.figure. These will override
        all defaults set by visualize.

    Returns
    -------
    The completed bokeh plot object.
    """
    bp = import_required("bokeh.plotting", _BOKEH_MISSING_MSG)
    from bokeh.models import HoverTool

    figure_opts = {
        "title": "Profile Results",
        "tools": "hover,save,reset,wheel_zoom,xpan",
        "toolbar_location": "above",
        "width": 800,
        "height": 300,
    }
    # Map the legacy ``plot_width`` / ``plot_height`` names onto the current ones.
    for legacy, modern in (("plot_width", "width"), ("plot_height", "height")):
        if legacy in kwargs:
            kwargs[modern] = kwargs.pop(legacy)
    figure_opts.update(kwargs)

    if not results:
        p = bp.figure(y_range=[0, 10], x_range=[0, 10], **figure_opts)
    else:
        columns = list(zip(*results))
        starts, ends = columns[3:]
        ticks = sorted(unique(starts + ends))
        by_label = groupby(lambda rec: pprint_task(rec[1], dsk, label_size), results)

        series = {}
        for label, records in by_label.items():
            # Net change in cached metric at every event time for this label.
            deltas = dict.fromkeys(ticks, 0)
            for rec in records:
                deltas[rec.cache_time] += rec.metric
                deltas[rec.free_time] -= rec.metric
            series[label] = [0, *accumulate(add, pluck(1, sorted(deltas.items())))]

        xs = [0] + [tick - start_time for tick in ticks]
        p = bp.figure(x_range=[0, max(xs)], **figure_opts)

        line_colors = get_colors(palette, series.keys())
        for (label, ys), color in zip(series.items(), line_colors):
            p.line(
                "x",
                "y",
                line_color=color,
                line_width=3,
                source=bp.ColumnDataSource(
                    {"x": xs, "y": ys, "label": [label] * len(ys)}
                ),
            )

    p.yaxis.axis_label = f"Cache Size ({metric_name})"
    p.xaxis.axis_label = "Time (s)"

    tip_tool = p.select(HoverTool)
    tip_tool.tooltips = """
    <div>
        <span style="font-size: 14px; font-weight: bold;">Task:</span>&nbsp;
        <span style="font-size: 10px; font-family: Monaco, monospace;">@label</span>
    </div>
    """
    return p
504