1from collections.abc import Sequence, Iterable
2from functools import total_ordering
3import fnmatch
4import linecache
5import os.path
6import pickle
7
8# Import types and functions implemented in C
9from _tracemalloc import *
10from _tracemalloc import _get_object_traceback, _get_traces
11
12
13def _format_size(size, sign):
14    for unit in ('B', 'KiB', 'MiB', 'GiB', 'TiB'):
15        if abs(size) < 100 and unit != 'B':
16            # 3 digits (xx.x UNIT)
17            if sign:
18                return "%+.1f %s" % (size, unit)
19            else:
20                return "%.1f %s" % (size, unit)
21        if abs(size) < 10 * 1024 or unit == 'TiB':
22            # 4 or 5 digits (xxxx UNIT)
23            if sign:
24                return "%+.0f %s" % (size, unit)
25            else:
26                return "%.0f %s" % (size, unit)
27        size /= 1024
28
29
class Statistic:
    """
    Statistic on memory allocations.

    Attributes:
      traceback -- key the memory blocks were grouped under
      size -- total size of the grouped memory blocks (rendered as a byte
              count by __str__)
      count -- number of grouped memory blocks
    """

    __slots__ = ('traceback', 'size', 'count')

    def __init__(self, traceback, size, count):
        self.traceback = traceback
        self.size = size
        self.count = count

    def __hash__(self):
        return hash((self.traceback, self.size, self.count))

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing
        # with AttributeError on the attribute lookups below.
        if not isinstance(other, Statistic):
            return NotImplemented
        return (self.traceback == other.traceback
                and self.size == other.size
                and self.count == other.count)

    def __str__(self):
        text = ("%s: size=%s, count=%i"
                 % (self.traceback,
                    _format_size(self.size, False),
                    self.count))
        # The average is only meaningful (and computable) for a non-zero count
        if self.count:
            average = self.size / self.count
            text += ", average=%s" % _format_size(average, False)
        return text

    def __repr__(self):
        return ('<Statistic traceback=%r size=%i count=%i>'
                % (self.traceback, self.size, self.count))

    def _sort_key(self):
        # Biggest size first, then count, then traceback as a tie-breaker
        return (self.size, self.count, self.traceback)
66
67
class StatisticDiff:
    """
    Statistic difference on memory allocations between an old and a new
    Snapshot instance.

    Attributes:
      traceback -- key the memory blocks were grouped under
      size, count -- values in the new snapshot
      size_diff, count_diff -- new value minus old value
    """
    __slots__ = ('traceback', 'size', 'size_diff', 'count', 'count_diff')

    def __init__(self, traceback, size, size_diff, count, count_diff):
        self.traceback = traceback
        self.size = size
        self.size_diff = size_diff
        self.count = count
        self.count_diff = count_diff

    def __hash__(self):
        return hash((self.traceback, self.size, self.size_diff,
                     self.count, self.count_diff))

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing
        # with AttributeError on the attribute lookups below.
        if not isinstance(other, StatisticDiff):
            return NotImplemented
        return (self.traceback == other.traceback
                and self.size == other.size
                and self.size_diff == other.size_diff
                and self.count == other.count
                and self.count_diff == other.count_diff)

    def __str__(self):
        text = ("%s: size=%s (%s), count=%i (%+i)"
                % (self.traceback,
                   _format_size(self.size, False),
                   _format_size(self.size_diff, True),
                   self.count,
                   self.count_diff))
        # The average is only meaningful (and computable) for a non-zero count
        if self.count:
            average = self.size / self.count
            text += ", average=%s" % _format_size(average, False)
        return text

    def __repr__(self):
        return ('<StatisticDiff traceback=%r size=%i (%+i) count=%i (%+i)>'
                % (self.traceback, self.size, self.size_diff,
                   self.count, self.count_diff))

    def _sort_key(self):
        # Largest absolute change first; current values break ties
        return (abs(self.size_diff), self.size,
                abs(self.count_diff), self.count,
                self.traceback)
114
115
def _compare_grouped_stats(old_group, new_group):
    """
    Build the list of StatisticDiff between two grouped statistics mappings.

    Both arguments map a traceback key to an object with ``size`` and
    ``count`` attributes.  *old_group* is consumed: every key that also
    exists in *new_group* is popped from it.
    """
    diffs = []
    for key, current in new_group.items():
        before = old_group.pop(key, None)
        if before is None:
            # Key only exists in the new group: everything is "added"
            size_delta = current.size
            count_delta = current.count
        else:
            size_delta = current.size - before.size
            count_delta = current.count - before.count
        diffs.append(StatisticDiff(key,
                                   current.size, size_delta,
                                   current.count, count_delta))

    # Whatever is left in old_group has no counterpart in new_group
    diffs.extend(StatisticDiff(key, 0, -gone.size, 0, -gone.count)
                 for key, gone in old_group.items())
    return diffs
134
135
@total_ordering
class Frame:
    """
    Frame of a traceback.

    Wraps a (filename, lineno) tuple; equality, ordering and hashing all
    delegate to that tuple.
    """
    __slots__ = ("_frame",)

    def __init__(self, frame):
        # frame is a tuple: (filename: str, lineno: int)
        self._frame = frame

    @property
    def filename(self):
        return self._frame[0]

    @property
    def lineno(self):
        return self._frame[1]

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to its default
        # handling instead of raising AttributeError for foreign types.
        if not isinstance(other, Frame):
            return NotImplemented
        return (self._frame == other._frame)

    def __lt__(self, other):
        if not isinstance(other, Frame):
            return NotImplemented
        return (self._frame < other._frame)

    def __hash__(self):
        return hash(self._frame)

    def __str__(self):
        return "%s:%s" % (self.filename, self.lineno)

    def __repr__(self):
        return "<Frame filename=%r lineno=%r>" % (self.filename, self.lineno)
169
170
@total_ordering
class Traceback(Sequence):
    """
    Sequence of Frame instances sorted from the oldest frame
    to the most recent frame.

    Frames are stored as raw (filename, lineno) tuples and wrapped in
    Frame objects on item access.
    """
    __slots__ = ("_frames",)

    def __init__(self, frames):
        Sequence.__init__(self)
        # frames is a tuple of frame tuples: see Frame constructor for the
        # format of a frame tuple; it is reversed, because _tracemalloc
        # returns frames sorted from most recent to oldest, but the
        # Python API expects oldest to most recent
        self._frames = tuple(reversed(frames))

    def __len__(self):
        return len(self._frames)

    def __getitem__(self, index):
        # A slice yields a plain tuple of Frame objects, not a Traceback
        if isinstance(index, slice):
            return tuple(Frame(trace) for trace in self._frames[index])
        else:
            return Frame(self._frames[index])

    def __contains__(self, frame):
        return frame._frame in self._frames

    def __hash__(self):
        return hash(self._frames)

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to its default
        # handling instead of raising AttributeError for foreign types.
        if not isinstance(other, Traceback):
            return NotImplemented
        return (self._frames == other._frames)

    def __lt__(self, other):
        if not isinstance(other, Traceback):
            return NotImplemented
        return (self._frames < other._frames)

    def __str__(self):
        return str(self[0])

    def __repr__(self):
        return "<Traceback %r>" % (tuple(self),)

    def format(self, limit=None, most_recent_first=False):
        """
        Format the traceback as a list of lines.

        Each frame produces a '  File "...", line N' entry, followed by the
        stripped source line when linecache returns a non-empty one.

        limit -- None formats every frame.  A positive value keeps the
                 *limit* most recent frames.  Zero or a negative value
                 slices the frames with [:limit]: zero keeps nothing and a
                 negative value leaves out the abs(limit) most recent frames.
        most_recent_first -- when true, emit the selected frames in reverse
                 order (most recent frame first).
        """
        lines = []
        if limit is not None:
            if limit > 0:
                frame_slice = self[-limit:]
            else:
                frame_slice = self[:limit]
        else:
            frame_slice = self

        if most_recent_first:
            frame_slice = reversed(frame_slice)
        for frame in frame_slice:
            lines.append('  File "%s", line %s'
                         % (frame.filename, frame.lineno))
            line = linecache.getline(frame.filename, frame.lineno).strip()
            if line:
                lines.append('    %s' % line)
        return lines
233
234
def get_object_traceback(obj):
    """
    Return the Traceback recorded for the allocation of the Python
    object *obj*.

    Return None when the tracemalloc module is not tracing memory
    allocations, or when it did not trace the allocation of the object.
    """
    raw_frames = _get_object_traceback(obj)
    return None if raw_frames is None else Traceback(raw_frames)
248
249
class Trace:
    """
    Trace of a memory block.

    Wraps a (domain, size, traceback) tuple; equality and hashing delegate
    to that tuple.
    """
    __slots__ = ("_trace",)

    def __init__(self, trace):
        # trace is a tuple: (domain: int, size: int, traceback: tuple).
        # See Traceback constructor for the format of the traceback tuple.
        self._trace = trace

    @property
    def domain(self):
        return self._trace[0]

    @property
    def size(self):
        return self._trace[1]

    @property
    def traceback(self):
        # A new Traceback object is built on every access
        return Traceback(self._trace[2])

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to its default
        # handling instead of raising AttributeError for foreign types.
        if not isinstance(other, Trace):
            return NotImplemented
        return (self._trace == other._trace)

    def __hash__(self):
        return hash(self._trace)

    def __str__(self):
        return "%s: %s" % (self.traceback, _format_size(self.size, False))

    def __repr__(self):
        return ("<Trace domain=%s size=%s, traceback=%r>"
                % (self.domain, _format_size(self.size, False), self.traceback))
285
286
287class _Traces(Sequence):
288    def __init__(self, traces):
289        Sequence.__init__(self)
290        # traces is a tuple of trace tuples: see Trace constructor
291        self._traces = traces
292
293    def __len__(self):
294        return len(self._traces)
295
296    def __getitem__(self, index):
297        if isinstance(index, slice):
298            return tuple(Trace(trace) for trace in self._traces[index])
299        else:
300            return Trace(self._traces[index])
301
302    def __contains__(self, trace):
303        return trace._trace in self._traces
304
305    def __eq__(self, other):
306        return (self._traces == other._traces)
307
308    def __repr__(self):
309        return "<Traces len=%s>" % len(self)
310
311
312def _normalize_filename(filename):
313    filename = os.path.normcase(filename)
314    if filename.endswith('.pyc'):
315        filename = filename[:-1]
316    return filename
317
318
class BaseFilter:
    """
    Common base of the trace filters.

    *inclusive* is stored as-is on the instance; subclasses provide the
    actual matching logic in _match().
    """

    def __init__(self, inclusive):
        self.inclusive = inclusive

    def _match(self, trace):
        """Tell whether *trace* passes the filter; subclasses must override."""
        raise NotImplementedError
325
326
class Filter(BaseFilter):
    """
    Filter traces on the filename (fnmatch pattern) and, optionally, the
    line number of their frames and on their domain.
    """

    def __init__(self, inclusive, filename_pattern,
                 lineno=None, all_frames=False, domain=None):
        # The base class stores the inclusive flag
        super().__init__(inclusive)
        self._filename_pattern = _normalize_filename(filename_pattern)
        self.lineno = lineno
        self.all_frames = all_frames
        self.domain = domain

    @property
    def filename_pattern(self):
        return self._filename_pattern

    def _match_frame_impl(self, filename, lineno):
        # Raw match of one frame, ignoring the inclusive flag
        normalized = _normalize_filename(filename)
        if not fnmatch.fnmatch(normalized, self._filename_pattern):
            return False
        # Without a configured line number, the filename match is enough
        return self.lineno is None or lineno == self.lineno

    def _match_frame(self, filename, lineno):
        # XOR flips the raw result for an exclusive filter
        raw_match = self._match_frame_impl(filename, lineno)
        return raw_match ^ (not self.inclusive)

    def _match_traceback(self, traceback):
        if not self.all_frames:
            # Only the first frame of the raw traceback tuple is checked
            filename, lineno = traceback[0]
            return self._match_frame(filename, lineno)

        any_frame_hit = any(self._match_frame_impl(filename, lineno)
                            for filename, lineno in traceback)
        return self.inclusive if any_frame_hit else (not self.inclusive)

    def _match(self, trace):
        domain, _size, traceback = trace
        matched = self._match_traceback(traceback)
        if self.domain is None:
            return matched
        if self.inclusive:
            # Inclusive: the trace must also belong to the domain
            return matched and (domain == self.domain)
        # Exclusive: traces of any other domain always pass
        return matched or (domain != self.domain)
373
374
class DomainFilter(BaseFilter):
    """
    Filter traces on their domain (first item of the trace tuple).
    """

    def __init__(self, inclusive, domain):
        super().__init__(inclusive)
        self._domain = domain

    @property
    def domain(self):
        return self._domain

    def _match(self, trace):
        trace_domain, _size, _traceback = trace
        same_domain = (trace_domain == self.domain)
        # XOR flips the result for an exclusive filter
        return same_domain ^ (not self.inclusive)
387
388
class Snapshot:
    """
    Snapshot of traces of memory blocks allocated by Python.

    Attributes:
      traces -- _Traces sequence wrapping the raw trace tuples
      traceback_limit -- value passed to the constructor, stored as-is
    """

    def __init__(self, traces, traceback_limit):
        # traces is a sequence of trace tuples: see _Traces constructor for
        # the exact format
        self.traces = _Traces(traces)
        self.traceback_limit = traceback_limit

    def dump(self, filename):
        """
        Write the snapshot into a file.
        """
        with open(filename, "wb") as fp:
            pickle.dump(self, fp, pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(filename):
        """
        Load a snapshot from a file.
        """
        # NOTE(security): the file is read with pickle, which can execute
        # arbitrary code while loading; only load files from trusted sources.
        with open(filename, "rb") as fp:
            return pickle.load(fp)

    def _filter_trace(self, include_filters, exclude_filters, trace):
        # A trace is kept when it matches at least one inclusive filter
        # (if there are any) and is rejected by no exclusive filter.
        if include_filters:
            if not any(trace_filter._match(trace)
                       for trace_filter in include_filters):
                return False
        if exclude_filters:
            if any(not trace_filter._match(trace)
                   for trace_filter in exclude_filters):
                return False
        return True

    def filter_traces(self, filters):
        """
        Create a new Snapshot instance with a filtered traces sequence, filters
        is a list of Filter or DomainFilter instances.  If filters is an empty
        list, return a new Snapshot instance with a copy of the traces.

        Raise TypeError if filters is not iterable.
        """
        if not isinstance(filters, Iterable):
            raise TypeError("filters must be a list of filters, not %s"
                            % type(filters).__name__)
        if filters:
            include_filters = []
            exclude_filters = []
            for trace_filter in filters:
                if trace_filter.inclusive:
                    include_filters.append(trace_filter)
                else:
                    exclude_filters.append(trace_filter)
            new_traces = [trace for trace in self.traces._traces
                          if self._filter_trace(include_filters,
                                                exclude_filters,
                                                trace)]
        else:
            # list() copies any sequence; tuples have no .copy() method
            new_traces = list(self.traces._traces)
        return Snapshot(new_traces, self.traceback_limit)

    def _group_by(self, key_type, cumulative):
        """
        Group the traces by key_type ('traceback', 'filename' or 'lineno')
        and return a dict mapping each Traceback key to its Statistic.

        In cumulative mode (only for 'lineno' and 'filename') every frame of
        a trace contributes to the statistic of its key, not just the first
        frame of the raw traceback tuple.

        Raise ValueError on an unknown key_type or an unsupported
        key_type/cumulative combination.
        """
        if key_type not in ('traceback', 'filename', 'lineno'):
            raise ValueError("unknown key_type: %r" % (key_type,))
        if cumulative and key_type not in ('lineno', 'filename'):
            raise ValueError("cumulative mode cannot by used "
                             "with key type %r" % key_type)

        stats = {}
        # Cache of raw key -> Traceback, so each key object is built once
        tracebacks = {}
        if not cumulative:
            for trace in self.traces._traces:
                domain, size, trace_traceback = trace
                try:
                    traceback = tracebacks[trace_traceback]
                except KeyError:
                    if key_type == 'traceback':
                        frames = trace_traceback
                    elif key_type == 'lineno':
                        frames = trace_traceback[:1]
                    else: # key_type == 'filename':
                        # Line number 0 collapses all lines of a file
                        frames = ((trace_traceback[0][0], 0),)
                    traceback = Traceback(frames)
                    tracebacks[trace_traceback] = traceback
                try:
                    stat = stats[traceback]
                    stat.size += size
                    stat.count += 1
                except KeyError:
                    stats[traceback] = Statistic(traceback, size, 1)
        else:
            # cumulative statistics
            for trace in self.traces._traces:
                domain, size, trace_traceback = trace
                for frame in trace_traceback:
                    try:
                        traceback = tracebacks[frame]
                    except KeyError:
                        if key_type == 'lineno':
                            frames = (frame,)
                        else: # key_type == 'filename':
                            # Line number 0 collapses all lines of a file
                            frames = ((frame[0], 0),)
                        traceback = Traceback(frames)
                        tracebacks[frame] = traceback
                    try:
                        stat = stats[traceback]
                        stat.size += size
                        stat.count += 1
                    except KeyError:
                        stats[traceback] = Statistic(traceback, size, 1)
        return stats

    def statistics(self, key_type, cumulative=False):
        """
        Group statistics by key_type. Return a sorted list of Statistic
        instances.

        The list is sorted in descending order of Statistic._sort_key().
        """
        grouped = self._group_by(key_type, cumulative)
        statistics = list(grouped.values())
        statistics.sort(reverse=True, key=Statistic._sort_key)
        return statistics

    def compare_to(self, old_snapshot, key_type, cumulative=False):
        """
        Compute the differences with an old snapshot old_snapshot. Get
        statistics as a sorted list of StatisticDiff instances, grouped by
        key_type.

        The list is sorted in descending order of StatisticDiff._sort_key().
        """
        new_group = self._group_by(key_type, cumulative)
        old_group = old_snapshot._group_by(key_type, cumulative)
        statistics = _compare_grouped_stats(old_group, new_group)
        statistics.sort(reverse=True, key=StatisticDiff._sort_key)
        return statistics
523
524
def take_snapshot():
    """
    Return a Snapshot of the traces of memory blocks allocated by Python.

    Raise RuntimeError if the tracemalloc module is not tracing memory
    allocations.
    """
    if is_tracing():
        raw_traces = _get_traces()
        limit = get_traceback_limit()
        return Snapshot(raw_traces, limit)
    raise RuntimeError("the tracemalloc module must be tracing memory "
                       "allocations to take a snapshot")
535