# -*- coding: utf-8 -*-
"""Utilities for debugging memory usage, blocking calls, etc."""
from __future__ import absolute_import, print_function, unicode_literals

import os
import sys
import traceback
from contextlib import contextmanager
from functools import partial
from pprint import pprint

from celery.five import WhateverIO, items, range
from celery.platforms import signals

try:
    from psutil import Process
except ImportError:
    Process = None  # noqa

__all__ = (
    'blockdetection', 'sample_mem', 'memdump', 'sample',
    'humanbytes', 'mem_rss', 'ps', 'cry',
)

#: ``(divisor, suffix)`` pairs used by :func:`humanbytes`, largest first.
#: The final ``0.0`` entry matches every non-negative size.
UNITS = (
    (2 ** 40.0, 'TB'),
    (2 ** 30.0, 'GB'),
    (2 ** 20.0, 'MB'),
    (2 ** 10.0, 'KB'),
    (0.0, 'b'),
)

#: Cached :class:`psutil.Process` for this process (see :func:`ps`).
_process = None
#: RSS readings appended by :func:`sample_mem`, drained by :func:`memdump`.
_mem_sample = []


def _on_blocking(signum, frame):
    """``SIGALRM`` handler installed by :func:`blockdetection`.

    Raises:
        RuntimeError: always, describing the frame that was executing
            when the alarm fired.
    """
    import inspect
    raise RuntimeError(
        'Blocking detection timed-out at: {0}'.format(
            inspect.getframeinfo(frame)
        )
    )


@contextmanager
def blockdetection(timeout):
    """Context that raises an exception if process is blocking.

    Uses ``SIGALRM`` to detect blocking functions.

    Arguments:
        timeout (int): Seconds before the alarm fires.  A falsy value
            disables detection and the context does nothing.
    """
    if not timeout:
        yield
    else:
        old_handler = signals['ALRM']
        # When nested inside another blockdetection() the handler is
        # already ours; leave restoring it to the outer context.
        old_handler = None if old_handler == _on_blocking else old_handler

        signals['ALRM'] = _on_blocking

        try:
            yield signals.arm_alarm(timeout)
        finally:
            # Cancel the pending alarm *before* restoring the previous
            # handler so a late alarm cannot reach e.g. SIG_DFL, which
            # would terminate the process.
            signals.reset_alarm()
            # SIG_DFL compares equal to 0 (falsy): test against None
            # explicitly, otherwise the default handler is never restored.
            if old_handler is not None:
                signals['ALRM'] = old_handler


def sample_mem():
    """Sample RSS memory usage.

    Statistics can then be output by calling :func:`memdump`.

    Returns:
        The humanized RSS string from :func:`mem_rss` (``None`` when
        :pypi:`psutil` is not installed).
    """
    current_rss = mem_rss()
    _mem_sample.append(current_rss)
    return current_rss


def _memdump(samples=10):  # pragma: no cover
    """Drain the collected RSS samples and measure RSS after a GC pass.

    Returns:
        Tuple[List, Any]: ``(prev, after_collect)`` where ``prev`` holds at
            most ``samples`` of the collected readings and ``after_collect``
            is :func:`mem_rss` measured after :func:`gc.collect`.
    """
    S = _mem_sample
    # Materialize before clearing the buffer below: sample() is a lazy
    # generator over ``S`` and would otherwise see an empty list.
    prev = list(S) if len(S) <= samples else list(sample(S, samples))
    _mem_sample[:] = []
    import gc
    gc.collect()
    after_collect = mem_rss()
    return prev, after_collect


def memdump(samples=10, file=None):  # pragma: no cover
    """Dump memory statistics.

    Will print a sample of all RSS memory samples added by
    calling :func:`sample_mem`, and in addition print
    used RSS memory after :func:`gc.collect`.

    Arguments:
        samples (int): Maximum number of collected readings to print.
        file: Stream to print to (``None`` means :data:`sys.stdout`).
    """
    say = partial(print, file=file)
    if ps() is None:
        say('- rss: (psutil not installed).')
        return
    prev, after_collect = _memdump(samples)
    if prev:
        say('- rss (sample):')
        for mem in prev:
            say('- > {0},'.format(mem))
    say('- rss (end): {0}.'.format(after_collect))


def sample(x, n, k=0):
    """Given a list `x` a sample of length ``n`` of that list is returned.

    For example, if `n` is 10, and `x` has 100 items, a list of every tenth
    item is returned.  If `x` has fewer than `n` items, every item from
    offset ``k`` onward is yielded once.

    ``k`` can be used as offset.
    """
    if n <= 0:
        return
    # Step at least one item, so a list shorter than ``n`` does not
    # yield ``x[k]`` over and over again.
    j = max(len(x) // n, 1)
    for _ in range(n):
        try:
            item = x[k]
        except IndexError:
            break
        yield item
        k += j


def hfloat(f, p=5):
    """Convert float to value suitable for humans.

    Arguments:
        f (float): The floating point number.
        p (int): Floating point precision (default is 5).

    Returns:
        The ``int`` value when ``f`` is integral, otherwise a string
        formatted with ``p`` significant digits.
    """
    i = int(f)
    return i if i == f else '{0:.{p}}'.format(f, p=p)


def humanbytes(s):
    """Convert bytes to human-readable form (e.g., KB, MB).

    Uses the largest unit from :data:`UNITS` whose divisor is <= ``s``.
    """
    return next(
        '{0}{1}'.format(hfloat(s / div if div else s), unit)
        for div, unit in UNITS if s >= div
    )


def mem_rss():
    """Return RSS memory usage as a humanized string.

    Returns ``None`` when :pypi:`psutil` is not installed.
    """
    p = ps()
    if p is not None:
        return humanbytes(_process_memory_info(p).rss)


def ps():  # pragma: no cover
    """Return the global :class:`psutil.Process` instance.

    Note:
        Returns :const:`None` if :pypi:`psutil` is not installed.
    """
    global _process
    if Process is None:
        return None
    pid = os.getpid()
    # Recreate after fork(): an instance cached in the parent would keep
    # reporting the parent's memory usage.
    if _process is None or _process.pid != pid:
        _process = Process(pid)
    return _process


def _process_memory_info(process):
    """Return ``process`` memory info across psutil API versions.

    Falls back to ``get_memory_info()`` when ``memory_info()`` is absent
    (presumably older psutil releases).
    """
    try:
        return process.memory_info()
    except AttributeError:
        return process.get_memory_info()


def cry(out=None, sepchr='=', seplen=49):  # pragma: no cover
    """Return stack-trace of all active threads.

    Arguments:
        out: Stream to write to; a new :class:`WhateverIO` is used when
            ``None``.  Must provide ``getvalue()``.
        sepchr (str): Character used to build separator lines.
        seplen (int): Length of separator lines.

    See Also:
        Taken from https://gist.github.com/737056.
    """
    import threading

    out = WhateverIO() if out is None else out
    P = partial(print, file=out)

    # get a map of threads by their ID so we can print their names
    # during the traceback dump
    tmap = {t.ident: t for t in threading.enumerate()}

    sep = sepchr * seplen
    for tid, frame in items(sys._current_frames()):
        thread = tmap.get(tid)
        if not thread:
            # skip old junk (left-overs from a fork)
            continue
        P('{0.name}'.format(thread))
        P(sep)
        traceback.print_stack(frame, file=out)
        P(sep)
        P('LOCAL VARIABLES')
        P(sep)
        pprint(frame.f_locals, stream=out)
        P('\n')
    return out.getvalue()