1""" This module contains utility functions to construct and manipulate counting
2data structures for frames.
3
When performing statistical profiling we obtain many call stacks.  We aggregate
these call stacks into data structures that maintain counts of how many times
each function in those call stacks has been observed.  Because these stacks
overlap, this aggregated counting structure forms a tree, such as is commonly
visualized by profiling tools.
9
10We represent this tree as a nested dictionary with the following form:
11
12    {
13     'identifier': 'root',
14     'description': 'A long description of the line of code being run.',
15     'count': 10  # the number of times we have seen this line
16     'children': {  # callers of this line. Recursive dicts
17         'ident-b': {'description': ...
18                   'identifier': 'ident-a',
19                   'count': ...
20                   'children': {...}},
21         'ident-b': {'description': ...
22                   'identifier': 'ident-b',
23                   'count': ...
24                   'children': {...}}}
25    }
26"""
27from __future__ import annotations
28
29import bisect
30import linecache
31import sys
32import threading
33from collections import defaultdict, deque
34from time import sleep
35from typing import Any
36
37import tlz as toolz
38
39from dask.utils import format_time, parse_timedelta
40
41from .metrics import time
42from .utils import color_of
43
44
def identifier(frame):
    """Return a compact string key for the code object running in ``frame``

    The key is ``"<function name>;<file name>;<first line number>"``.
    Strings are cheaper to use as dict keys than tuples or dicts.
    A ``None`` frame maps to the string ``"None"``.
    """
    if frame is None:
        return "None"
    code = frame.f_code
    return f"{code.co_name};{code.co_filename};{code.co_firstlineno}"
60
61
def repr_frame(frame):
    """Render ``frame`` as a traceback-style entry

    Produces a ``File "...", line N, in name`` header, then a newline and a
    tab, then the current source line as found by ``linecache`` (leading
    whitespace stripped; empty when the source is unavailable).
    """
    code = frame.f_code
    lineno = frame.f_lineno
    source = linecache.getline(code.co_filename, lineno, frame.f_globals)
    header = '  File "%s", line %s, in %s' % (code.co_filename, lineno, code.co_name)
    return "\n\t".join((header, source.lstrip()))
68
69
def info_frame(frame):
    """Describe ``frame`` as a dict usable as a profile node ``description``

    Keys are ``filename``, ``name`` (function name), ``line_number`` (the
    frame's current line) and ``line`` (that source line from ``linecache``
    with leading whitespace stripped, or ``""`` if unavailable).
    """
    code = frame.f_code
    lineno = frame.f_lineno
    source = linecache.getline(code.co_filename, lineno, frame.f_globals)
    return dict(
        filename=code.co_filename,
        name=code.co_name,
        line_number=lineno,
        line=source.lstrip(),
    )
79
80
def process(frame, child, state, stop=None, omit=None):
    """Add counts from a frame stack onto existing state

    Records one sample of the call stack ending at ``frame`` into the
    ``state`` tree, creating new entries for functions not seen before.  The
    stack is walked from the outermost recorded frame inwards, and every node
    passed through on the way down has its ``count`` incremented once.

    The walk is iterative rather than recursive, so very deep stacks in the
    watched thread cannot raise ``RecursionError`` in the profiling thread.

    Parameters
    ----------
    frame : frame
        The innermost frame of the stack to record.
    child : object or None
        ``None`` records a complete sample, so the node for ``frame`` also
        has its count incremented.  Any other value leaves that node's count
        untouched and returns the node, so that further entries (e.g. the
        low-level frames added by ``llprocess``) can be attached beneath it.
    state : dict
        Profile tree to update in place, e.g. as returned by ``create``.
    stop : str or tuple of str, optional
        Stop walking outwards at the first caller whose filename ends with
        this suffix.  That caller and everything outside it are skipped.
        ``frame`` itself is always recorded.
    omit : iterable of str, optional
        If the filename of ``frame`` ends with any of these suffixes, nothing
        is recorded and ``False`` is returned.  Only ``frame`` is checked,
        not its callers.

    Returns
    -------
    ``False`` if the sample was omitted.  Otherwise the node for ``frame``
    if ``child`` is not None, else ``None``.

    Examples
    --------
    >>> import sys, threading
    >>> ident = threading.get_ident()  # replace with your thread of interest
    >>> frame = sys._current_frames()[ident]
    >>> state = {'children': {}, 'count': 0, 'description': 'root',
    ...          'identifier': 'root'}
    >>> process(frame, None, state)
    >>> state  # doctest: +SKIP
    {'count': 1,
     'identifier': 'root',
     'description': 'root',
     'children': {'...'}}
    """
    if omit is not None and any(frame.f_code.co_filename.endswith(o) for o in omit):
        return False

    # Gather the frames from the innermost one outwards, honouring ``stop``.
    stack = [frame]
    caller = frame.f_back
    while caller is not None and (
        stop is None or not caller.f_code.co_filename.endswith(stop)
    ):
        stack.append(caller)
        caller = caller.f_back

    # Descend from the outermost frame, creating nodes as needed and counting
    # one sample on each node we pass through.
    node = state
    for current in reversed(stack):
        ident = identifier(current)
        children = node["children"]
        try:
            d = children[ident]
        except KeyError:
            d = {
                "count": 0,
                "description": info_frame(current),
                "children": {},
                "identifier": ident,
            }
            children[ident] = d
        node["count"] += 1
        node = d

    if child is not None:
        # Caller wants to attach more entries below the leaf node.
        return node
    node["count"] += 1
    return None
131
132
def merge(*args):
    """Merge multiple frame states together

    All states must share one ``identifier``.  Matching nodes (matched by
    their key in the parent's ``children`` dict) have their ``count`` summed.
    Their children are merged the same way, all the way down.  Each merged
    node takes ``description`` and ``identifier`` from the first state that
    contains it.

    The trees are walked with an explicit stack instead of recursion, so
    arbitrarily deep profiles are merged completely rather than having their
    deepest levels dropped when the recursion limit is reached.

    Parameters
    ----------
    *args : dict
        Profile states, e.g. as produced by ``create`` and ``process``.

    Returns
    -------
    dict
        A new merged state, or an empty root state if no arguments are given.

    Raises
    ------
    ValueError
        If the nodes merged at some position do not share one identifier.
    """
    if not args:
        return create()

    def _combine(nodes):
        # Build the merged node for a group of nodes; children are filled in
        # by the loop below.
        idents = {node["identifier"] for node in nodes}
        if len(idents) != 1:
            raise ValueError("Expected identifiers, got %s" % str(idents))
        return {
            "description": nodes[0]["description"],
            "children": {},
            "count": sum(node["count"] for node in nodes),
            "identifier": nodes[0]["identifier"],
        }

    result = _combine(args)
    pending = [(args, result)]
    while pending:
        nodes, merged = pending.pop()
        # Group children by key, preserving first-seen order.
        groups = defaultdict(list)
        for node in nodes:
            for key, sub in node["children"].items():
                groups[key].append(sub)
        for key, group in groups.items():
            merged_child = _combine(group)
            merged["children"][key] = merged_child
            pending.append((group, merged_child))
    return result
156
157
def create() -> dict[str, Any]:
    """Return a fresh, empty profile state whose identifier is ``"root"``"""
    return dict(
        count=0,
        children={},
        identifier="root",
        description=dict(filename="", name="", line_number=0, line=""),
    )
165
166
def call_stack(frame):
    """Render the stack ending at ``frame`` as traceback-style text

    Returns
    -------
    list of strings
        One ``repr_frame`` rendering per frame, outermost frame first.
    """
    rendered = []
    current = frame
    while current:
        rendered.append(repr_frame(current))
        current = current.f_back
    rendered.reverse()
    return rendered
179
180
def plot_data(state, profile_interval=0.010):
    """Convert a profile state into data useful by Bokeh

    Lays the profile tree out as stacked rectangles.  The root spans
    ``[0, 1]`` horizontally at height 0.  Each node's horizontal span is
    split among its children in proportion to their ``count``, one level
    higher per depth.  Nodes with a zero count, and their subtrees, are
    skipped.

    Parameters
    ----------
    state : dict
        Profile tree, e.g. as returned by ``create`` or ``merge``.
    profile_interval : float
        Seconds represented by one sample; used for the ``time`` column.

    Returns
    -------
    dict
        Column name -> list, one entry per drawn node, in depth-first
        pre-order.

    See Also
    --------
    plot_figure
    distributed.bokeh.components.ProfilePlot
    """
    starts = []
    stops = []
    heights = []
    widths = []
    colors = []
    states = []
    times = []

    filenames = []
    lines = []
    line_numbers = []
    names = []

    # Depth-first pre-order walk with an explicit stack rather than
    # recursion, so very deep profiles cannot exceed the recursion limit.
    pending = [(state, 0, 1, 0)]
    while pending:
        node, start, stop, height = pending.pop()
        count = node["count"]
        if not count:
            continue
        starts.append(start)
        stops.append(stop)
        heights.append(height)
        widths.append(stop - start)
        states.append(node)
        times.append(format_time(count * profile_interval))

        desc = node["description"]
        fn = desc["filename"]
        filenames.append(fn)
        lines.append(desc["line"])
        line_numbers.append(desc["line_number"])
        names.append(desc["name"])
        colors.append("lightgray" if fn == "<low-level>" else color_of(fn))

        delta = (stop - start) / count
        x = start
        child_entries = []
        for child in node["children"].values():
            width = child["count"] * delta
            child_entries.append((child, x, x + width, height + 1))
            x += width
        # Push in reverse so children are visited in their original order.
        pending.extend(reversed(child_entries))

    percentages = [f"{100 * w:.1f}%" for w in widths]
    return {
        "left": starts,
        "right": stops,
        "bottom": heights,
        "width": widths,
        "top": [x + 1 for x in heights],
        "color": colors,
        "states": states,
        "filename": filenames,
        "line": lines,
        "line_number": line_numbers,
        "name": names,
        "time": times,
        "percentage": percentages,
    }
255
256
def _watch(thread_id, log, interval="20ms", cycle="2s", omit=None, stop=lambda: False):
    """Sample the stack of thread ``thread_id`` until ``stop()`` returns true

    Loop run by the background thread that ``watch`` starts.  Every
    ``interval`` the watched thread's current stack is added to an
    accumulating state via ``process``.  Once more than ``cycle`` has
    elapsed, that state is appended to ``log`` as ``(timestamp, state)`` and
    a fresh state is started.  The function returns when ``stop()`` is true
    or the watched thread is no longer present in ``sys._current_frames()``.
    The partially filled state at that point is not appended.
    """
    interval = parse_timedelta(interval)
    cycle = parse_timedelta(cycle)

    recent = create()
    last = time()

    while not stop():
        if time() > last + cycle:
            log.append((time(), recent))
            recent = create()
            last = time()
        try:
            frame = sys._current_frames()[thread_id]
        except KeyError:
            return

        process(frame, None, recent, omit=omit)
        # Release the frame before sleeping.  Holding it would keep the
        # watched thread's frames, and whatever they reference, alive for
        # the whole interval, even after those functions have returned.
        del frame
        sleep(interval)
276
277
def watch(
    thread_id=None,
    interval="20ms",
    cycle="2s",
    maxlen=1000,
    omit=None,
    stop=lambda: False,
):
    """Gather profile information on a particular thread

    Starts a daemon thread named ``"Profile"`` that periodically samples the
    watched thread's stack.  It returns a deque to which that background
    thread appends ``(timestamp, profile_state)`` pairs, one per ``cycle``.

    Parameters
    ----------
    thread_id : int, optional
        Identifier of the thread to watch; defaults to the calling thread.
    interval : str
        Time per sample
    cycle : str
        Time per refreshing to a new profile state
    maxlen : int
        Passed onto deque, maximum number of periods
    omit : iterable of str, optional
        Filename suffixes; a sample is skipped when the filename of the
        innermost frame ends with one of them
    stop : callable
        Function to call to see if we should stop

    Returns
    -------
    deque
    """
    target_id = threading.get_ident() if thread_id is None else thread_id
    history = deque(maxlen=maxlen)

    options = dict(
        thread_id=target_id,
        interval=interval,
        cycle=cycle,
        log=history,
        omit=omit,
        stop=stop,
    )
    sampler = threading.Thread(
        target=_watch, name="Profile", kwargs=options, daemon=True
    )
    sampler.start()

    return history
330
331
def get_profile(history, recent=None, start=None, stop=None, key=None):
    """Collect profile information from a sequence of profile states

    Parameters
    ----------
    history : Sequence[Tuple[time, Dict]]
        A list or deque of ``(timestamp, state)`` profile states, ordered by
        timestamp
    recent : dict
        The most recent accumulating state; merged in when truthy
    start : time
        If given, only entries timestamped at or after ``start`` are used
    stop : time
        If given, entries are used up to and including the first one
        timestamped at or after ``stop``
    key
        Accepted but not used.

    Returns
    -------
    dict
        The merged profile state.  If no history entries are selected, an
        empty state is returned and ``recent`` is not included.
    """
    istart = 0 if start is None else bisect.bisect_left(history, (start,))

    istop = None
    if stop is not None:
        istop = bisect.bisect_right(history, (stop,)) + 1
        if istop >= len(history):
            istop = None  # include end

    selected = list(history)[istart:istop]
    if not selected:
        return create()

    prof = merge(*(entry[1] for entry in selected))
    if recent:
        prof = merge(prof, recent)
    return prof
371
372
def plot_figure(data, **kwargs):
    """Plot profile data using Bokeh

    Takes the dict produced by ``plot_data`` and draws one quad per profile
    node, with a hover tooltip showing its name, filename, line number,
    line, time and percentage.  Extra keyword arguments are forwarded to
    ``bokeh.plotting.figure``.

    Returns
    -------
    (figure, ColumnDataSource)

    See Also
    --------
    plot_data
    """
    from bokeh.models import HoverTool
    from bokeh.plotting import ColumnDataSource, figure

    # ``states`` holds the profile tree nodes themselves; keep them out of
    # the data source.
    columns = toolz.dissoc(data, "states") if "states" in data else data
    source = ColumnDataSource(data=columns)

    fig = figure(tools="tap,box_zoom,xwheel_zoom,reset", **kwargs)
    quads = fig.quad(
        "left", "right", "top", "bottom",
        color="color", line_color="black", line_width=2, source=source,
    )
    # Clear the renderer's selection / non-selection glyphs.
    quads.selection_glyph = None
    quads.nonselection_glyph = None

    hover = HoverTool(
        point_policy="follow_mouse",
        tooltips="""
            <div>
                <span style="font-size: 14px; font-weight: bold;">Name:</span>&nbsp;
                <span style="font-size: 10px; font-family: Monaco, monospace;">@name</span>
            </div>
            <div>
                <span style="font-size: 14px; font-weight: bold;">Filename:</span>&nbsp;
                <span style="font-size: 10px; font-family: Monaco, monospace;">@filename</span>
            </div>
            <div>
                <span style="font-size: 14px; font-weight: bold;">Line number:</span>&nbsp;
                <span style="font-size: 10px; font-family: Monaco, monospace;">@line_number</span>
            </div>
            <div>
                <span style="font-size: 14px; font-weight: bold;">Line:</span>&nbsp;
                <span style="font-size: 10px; font-family: Monaco, monospace;">@line</span>
            </div>
            <div>
                <span style="font-size: 14px; font-weight: bold;">Time:</span>&nbsp;
                <span style="font-size: 10px; font-family: Monaco, monospace;">@time</span>
            </div>
            <div>
                <span style="font-size: 14px; font-weight: bold;">Percentage:</span>&nbsp;
                <span style="font-size: 10px; font-family: Monaco, monospace;">@percentage</span>
            </div>
            """,
    )
    fig.add_tools(hover)

    # Hide the axes and grid.
    for axis in (fig.xaxis, fig.yaxis):
        axis.visible = False
    fig.grid.visible = False

    return fig, source
442
443
444def _remove_py_stack(frames):
445    for entry in frames:
446        if entry.is_python:
447            break
448        yield entry
449
450
def llprocess(frames, child, state):
    """Add counts from low level profile information onto existing state

    Records one sample of a native (low-level) stack, as returned by
    ``ll_get_stack``, into the ``state`` tree.  ``frames[0]`` is attached
    directly beneath ``state``, each following frame beneath the previous
    one, and ``frames[-1]`` is the innermost (leaf) entry.  Every node
    passed through on the way down has its ``count`` incremented once.

    It is configured with the ``distributed.worker.profile.low-level`` config
    entry.

    The walk is iterative, and ``frames`` is not modified.

    Parameters
    ----------
    frames : list
        Low-level frame objects exposing ``name``, ``addr`` and ``offset``.
    child : object or None
        ``None`` also increments the leaf node's count.  Any other value
        leaves it untouched and returns the leaf node instead.
    state : dict
        Profile tree node to attach the frames under, updated in place.

    Returns
    -------
    The leaf node if ``frames`` is non-empty and ``child`` is not None,
    otherwise ``None``.

    See Also
    --------
    process
    ll_get_stack
    """
    if not frames:
        return None

    node = state
    for frame in frames:
        addr = hex(frame.addr - frame.offset)
        ident = ";".join(map(str, (frame.name, "<low-level>", addr)))
        children = node["children"]
        try:
            d = children[ident]
        except KeyError:
            d = {
                "count": 0,
                "description": {
                    "filename": "<low-level>",
                    "name": frame.name,
                    "line_number": 0,
                    "line": str(frame),
                },
                "children": {},
                "identifier": ident,
            }
            children[ident] = d
        node["count"] += 1
        node = d

    if child is not None:
        return node
    node["count"] += 1
    return None
495
496
def ll_get_stack(tid):
    """Collect low level stack information from thread id

    Uses ``stacktrace.get_thread_stack`` without Python frames.  It keeps
    the leading entries up to the first Python frame and returns them in
    reversed order.
    """
    from stacktrace import get_thread_stack

    native = list(_remove_py_stack(get_thread_stack(tid, show_python=False)))
    native.reverse()
    return native
504