1# -*- coding: utf-8 -*-
2"""Utilities for debugging memory usage, blocking calls, etc."""
3from __future__ import absolute_import, print_function, unicode_literals
4
5import os
6import sys
7import traceback
8from contextlib import contextmanager
9from functools import partial
10from pprint import pprint
11
12from celery.five import WhateverIO, items, range
13from celery.platforms import signals
14
15try:
16    from psutil import Process
17except ImportError:
18    Process = None  # noqa
19
# Names exported by a star-import of this module.
__all__ = (
    'blockdetection', 'sample_mem', 'memdump', 'sample',
    'humanbytes', 'mem_rss', 'ps', 'cry',
)
24
# ``(threshold, suffix)`` pairs ordered from the largest unit to the smallest.
# :func:`humanbytes` uses the first pair whose threshold the value reaches and
# divides the value by that threshold.  The final ``0.0`` entry is the
# catch-all: its threshold is falsy, so such values are shown undivided.
UNITS = (
    (2 ** 40.0, 'TB'),
    (2 ** 30.0, 'GB'),
    (2 ** 20.0, 'MB'),
    (2 ** 10.0, 'KB'),
    (0.0, 'b'),
)
32
# Process handle created on first use by :func:`ps`; ``None`` until then.
_process = None
# RSS readings appended by :func:`sample_mem` and emptied by :func:`_memdump`.
_mem_sample = []
35
36
37def _on_blocking(signum, frame):
38    import inspect
39    raise RuntimeError(
40        'Blocking detection timed-out at: {0}'.format(
41            inspect.getframeinfo(frame)
42        )
43    )
44
45
@contextmanager
def blockdetection(timeout):
    """Context that raises an exception if process is blocking.

    Uses ``SIGALRM`` to detect blocking functions: while the context is
    active the ``ALRM`` handler is :func:`_on_blocking`, which raises
    :exc:`RuntimeError` from the interrupted frame.

    Arguments:
        timeout (float): Value passed to ``signals.arm_alarm``.  A falsy
            value (``0``, :const:`None`) disables detection, and the
            context then does nothing.

    Yields:
        The return value of ``signals.arm_alarm(timeout)``, or
        :const:`None` when detection is disabled.
    """
    if not timeout:
        yield
        return

    previous = signals['ALRM']
    if previous == _on_blocking:
        # Nested use: our handler is already installed, so there is
        # nothing to put back when this (inner) context exits.
        previous = None

    signals['ALRM'] = _on_blocking
    try:
        yield signals.arm_alarm(timeout)
    finally:
        # Disarm the alarm *before* reinstating the previous handler, so an
        # alarm still pending at exit cannot be delivered to a handler that
        # never armed it (assumes ``reset_alarm`` cancels the pending alarm).
        try:
            signals.reset_alarm()
        finally:
            # NOTE(review): a falsy previous handler is deliberately left
            # un-restored, as before -- confirm whether a handler value
            # comparing equal to 0 should be reinstated as well.
            if previous:
                signals['ALRM'] = previous
66
67
def sample_mem():
    """Record the current RSS memory usage and return it.

    The reading comes from :func:`mem_rss` and is appended to the
    module-level sample list; :func:`memdump` prints those readings.

    Returns:
        The value returned by :func:`mem_rss` for this reading.
    """
    reading = mem_rss()
    _mem_sample.append(reading)
    return reading
76
77
def _memdump(samples=10):  # pragma: no cover
    """Drain the recorded RSS samples and measure RSS after a GC pass.

    Arguments:
        samples (int): Maximum number of recorded readings to return.
            When more have been recorded, :func:`sample` thins them out.

    Returns:
        tuple: ``(prev, after_collect)`` where ``prev`` is a list of the
        (possibly thinned) readings recorded so far, and ``after_collect``
        is the value of :func:`mem_rss` taken after :func:`gc.collect`.
    """
    import gc

    recorded = _mem_sample
    if len(recorded) <= samples:
        prev = list(recorded)
    else:
        # ``sample`` is a lazy generator over the very list that is emptied
        # below; materialize it now or the readings are gone before anyone
        # iterates over them.
        prev = list(sample(recorded, samples))
    _mem_sample[:] = []
    gc.collect()
    after_collect = mem_rss()
    return prev, after_collect
86
87
def memdump(samples=10, file=None):  # pragma: no cover
    """Dump memory statistics.

    Prints a sample of the RSS readings recorded by :func:`sample_mem`,
    followed by the RSS memory in use after :func:`gc.collect`.

    Arguments:
        samples (int): Maximum number of recorded readings to print.
        file (IO): Stream handed to :func:`print`; :const:`None` lets
            :func:`print` pick its default.
    """
    def say(line):
        print(line, file=file)

    if ps() is None:
        # No process handle available, so there is nothing to measure.
        say('- rss: (psutil not installed).')
        return

    prev, after_collect = _memdump(samples)
    if prev:
        say('- rss (sample):')
        for reading in prev:
            say('-    > {0},'.format(reading))
    say('- rss (end): {0}.'.format(after_collect))
105
106
def sample(x, n, k=0):
    """Yield a sample of at most ``n`` evenly spaced items from `x`.

    For example, if `n` is 10, and `x` has 100 items, every tenth
    item is yielded.

    Arguments:
        x (Sequence): Object supporting :func:`len` and integer indexing.
        n (int): Maximum number of items to yield.  Nothing is yielded
            when this is zero or negative.
        k (int): Index of the first item to yield (offset).

    Yields:
        ``x[k]``, ``x[k + step]``, ... where ``step`` is ``len(x) // n``
        (at least 1), stopping after ``n`` items or at the first index
        that is out of range.
    """
    if n <= 0:
        return
    # Never let the step drop to zero: when ``n`` exceeds ``len(x)`` that
    # would yield the same item ``n`` times instead of walking the sequence.
    step = max(len(x) // n, 1)
    for _ in range(n):
        try:
            yield x[k]
        except IndexError:
            break
        k += step
122
123
def hfloat(f, p=5):
    """Convert float to value suitable for humans.

    Arguments:
        f (float): The floating point number.
        p (int): Floating point precision (default is 5).

    Returns:
        ``int(f)`` when that compares equal to `f`; otherwise `f`
        formatted as a string with precision `p`.
    """
    whole = int(f)
    if whole == f:
        return whole
    return '{0:.{p}}'.format(f, p=p)
133
134
def humanbytes(s):
    """Convert bytes to human-readable form (e.g., KB, MB).

    The unit is the first entry of ``UNITS`` whose threshold the magnitude
    of `s` reaches; the value is divided by that threshold (entries with
    a falsy threshold leave it undivided) and rendered with :func:`hfloat`.

    Arguments:
        s (float): Number of bytes.  Negative values keep their sign and
            are scaled by magnitude.

    Returns:
        str: The scaled value immediately followed by the unit suffix.

    Raises:
        ValueError: If no entry in ``UNITS`` matches `s`.
    """
    # Compare by magnitude so a negative value still selects a unit instead
    # of exhausting the table (which used to leak ``StopIteration``).
    magnitude = abs(s)
    for div, unit in UNITS:
        if magnitude >= div:
            scaled = s / div if div else s
            return '{0}{1}'.format(hfloat(scaled), unit)
    raise ValueError('cannot humanize byte count: {0!r}'.format(s))
141
142
def mem_rss():
    """Return RSS memory usage as a humanized string.

    Returns:
        The ``rss`` field of the process memory info, formatted by
        :func:`humanbytes`; :const:`None` when :func:`ps` returns no
        process handle.
    """
    proc = ps()
    if proc is None:
        return None
    usage = _process_memory_info(proc)
    return humanbytes(usage.rss)
148
149
def ps():  # pragma: no cover
    """Return the global :class:`psutil.Process` instance.

    The instance is created for the current PID on first call and
    reused afterwards.

    Note:
        Returns :const:`None` if :pypi:`psutil` is not installed.
    """
    global _process
    if _process is not None or Process is None:
        # Either already created, or psutil is unavailable.
        return _process
    _process = Process(os.getpid())
    return _process
160
161
162def _process_memory_info(process):
163    try:
164        return process.memory_info()
165    except AttributeError:
166        return process.get_memory_info()
167
168
def cry(out=None, sepchr='=', seplen=49):  # pragma: no cover
    """Return stack-trace of all active threads.

    For every thread that is both in :func:`sys._current_frames` and in
    :func:`threading.enumerate`, writes the thread name, its stack and
    its local variables to `out`.

    Arguments:
        out (IO): Stream to write to; must also provide ``getvalue()``.
            A new ``WhateverIO`` is used when omitted.
        sepchr (str): Character the separator line is built from.
        seplen (int): Length of the separator line.

    Returns:
        The result of ``out.getvalue()``.

    See Also:
        Taken from https://gist.github.com/737056.
    """
    import threading

    if out is None:
        out = WhateverIO()
    emit = partial(print, file=out)

    # Index the live threads by ident so each frame can be labelled
    # with the name of the thread that owns it.
    threads_by_id = {}
    for live in threading.enumerate():
        threads_by_id[live.ident] = live

    divider = sepchr * seplen
    for thread_id, frame in items(sys._current_frames()):
        owner = threads_by_id.get(thread_id)
        if not owner:
            # skip old junk (left-overs from a fork)
            continue
        emit('{0.name}'.format(owner))
        emit(divider)
        traceback.print_stack(frame, file=out)
        emit(divider)
        emit('LOCAL VARIABLES')
        emit(divider)
        pprint(frame.f_locals, stream=out)
        emit('\n')
    return out.getvalue()
199