1from collections.abc import Sequence, Iterable
2from functools import total_ordering
3import fnmatch
4import linecache
5import os.path
6import pickle
7
8# Import types and functions implemented in C
9from _tracemalloc import *
10from _tracemalloc import _get_object_traceback, _get_traces
11
12
13def _format_size(size, sign):
14    for unit in ('B', 'KiB', 'MiB', 'GiB', 'TiB'):
15        if abs(size) < 100 and unit != 'B':
16            # 3 digits (xx.x UNIT)
17            if sign:
18                return "%+.1f %s" % (size, unit)
19            else:
20                return "%.1f %s" % (size, unit)
21        if abs(size) < 10 * 1024 or unit == 'TiB':
22            # 4 or 5 digits (xxxx UNIT)
23            if sign:
24                return "%+.0f %s" % (size, unit)
25            else:
26                return "%.0f %s" % (size, unit)
27        size /= 1024
28
29
class Statistic:
    """
    Statistic on memory allocations: the total size and the number of
    memory blocks attributed to one traceback.
    """

    __slots__ = ('traceback', 'size', 'count')

    def __init__(self, traceback, size, count):
        self.traceback, self.size, self.count = traceback, size, count

    def __hash__(self):
        key = (self.traceback, self.size, self.count)
        return hash(key)

    def __eq__(self, other):
        if isinstance(other, Statistic):
            return (self.traceback == other.traceback
                    and self.size == other.size
                    and self.count == other.count)
        return NotImplemented

    def __str__(self):
        pieces = ["%s: size=%s, count=%i"
                  % (self.traceback,
                     _format_size(self.size, False),
                     self.count)]
        if self.count:
            # The average block size is only meaningful with at least one
            # block (and this avoids a division by zero)
            per_block = self.size / self.count
            pieces.append(", average=%s" % _format_size(per_block, False))
        return "".join(pieces)

    def __repr__(self):
        fields = (self.traceback, self.size, self.count)
        return '<Statistic traceback=%r size=%i count=%i>' % fields

    def _sort_key(self):
        # Order by size, then by count, then by traceback
        return (self.size, self.count, self.traceback)
68
69
class StatisticDiff:
    """
    Statistic difference on memory allocations between an old and a new
    Snapshot instance.
    """
    __slots__ = ('traceback', 'size', 'size_diff', 'count', 'count_diff')

    def __init__(self, traceback, size, size_diff, count, count_diff):
        self.traceback = traceback
        self.size, self.size_diff = size, size_diff
        self.count, self.count_diff = count, count_diff

    def __hash__(self):
        key = (self.traceback, self.size, self.size_diff,
               self.count, self.count_diff)
        return hash(key)

    def __eq__(self, other):
        if isinstance(other, StatisticDiff):
            return (self.traceback == other.traceback
                    and self.size == other.size
                    and self.size_diff == other.size_diff
                    and self.count == other.count
                    and self.count_diff == other.count_diff)
        return NotImplemented

    def __str__(self):
        pieces = ["%s: size=%s (%s), count=%i (%+i)"
                  % (self.traceback,
                     _format_size(self.size, False),
                     _format_size(self.size_diff, True),
                     self.count,
                     self.count_diff)]
        if self.count:
            # The average block size is only meaningful with at least one
            # block (and this avoids a division by zero)
            per_block = self.size / self.count
            pieces.append(", average=%s" % _format_size(per_block, False))
        return "".join(pieces)

    def __repr__(self):
        fields = (self.traceback, self.size, self.size_diff,
                  self.count, self.count_diff)
        return ('<StatisticDiff traceback=%r size=%i (%+i) count=%i (%+i)>'
                % fields)

    def _sort_key(self):
        # Order by the magnitude of the size change first, then by size,
        # by the magnitude of the count change, by count and by traceback
        size_change = abs(self.size_diff)
        count_change = abs(self.count_diff)
        return (size_change, self.size,
                count_change, self.count,
                self.traceback)
118
119
def _compare_grouped_stats(old_group, new_group):
    """
    Compare two groups of statistics and return a list of StatisticDiff.

    *old_group* and *new_group* are dicts mapping a traceback to an object
    with ``size`` and ``count`` attributes.  The result contains one entry
    per traceback of *new_group* (in its iteration order), followed by one
    entry per traceback which is only present in *old_group*.

    Neither dict is modified.
    """
    statistics = []
    for traceback, stat in new_group.items():
        previous = old_group.get(traceback)
        if previous is not None:
            # Present in both groups: report the change
            stat = StatisticDiff(traceback,
                                 stat.size, stat.size - previous.size,
                                 stat.count, stat.count - previous.count)
        else:
            # Only in the new group: everything is an increase
            stat = StatisticDiff(traceback,
                                 stat.size, stat.size,
                                 stat.count, stat.count)
        statistics.append(stat)

    for traceback, stat in old_group.items():
        if traceback in new_group:
            # Already reported by the loop above
            continue
        # Only in the old group: nothing is left, everything is a decrease
        stat = StatisticDiff(traceback, 0, -stat.size, 0, -stat.count)
        statistics.append(stat)
    return statistics
138
139
@total_ordering
class Frame:
    """
    Frame of a traceback: a filename and a line number.

    Frames compare, sort and hash like their underlying
    (filename, lineno) tuple.
    """
    __slots__ = ("_frame",)

    def __init__(self, frame):
        # frame is a tuple: (filename: str, lineno: int)
        self._frame = frame

    @property
    def filename(self):
        raw = self._frame
        return raw[0]

    @property
    def lineno(self):
        raw = self._frame
        return raw[1]

    def __eq__(self, other):
        if isinstance(other, Frame):
            return (self._frame == other._frame)
        return NotImplemented

    def __lt__(self, other):
        # total_ordering derives the other rich comparisons from this one
        if isinstance(other, Frame):
            return (self._frame < other._frame)
        return NotImplemented

    def __hash__(self):
        return hash(self._frame)

    def __str__(self):
        location = (self.filename, self.lineno)
        return "%s:%s" % location

    def __repr__(self):
        location = (self.filename, self.lineno)
        return "<Frame filename=%r lineno=%r>" % location
177
178
@total_ordering
class Traceback(Sequence):
    """
    Sequence of Frame instances sorted from the oldest frame
    to the most recent frame.
    """
    __slots__ = ("_frames", '_total_nframe')

    def __init__(self, frames, total_nframe=None):
        Sequence.__init__(self)
        # frames is a tuple of frame tuples: see Frame constructor for the
        # format of a frame tuple; it is reversed, because _tracemalloc
        # returns frames sorted from most recent to oldest, but the
        # Python API expects oldest to most recent
        self._frames = tuple(reversed(frames))
        # Stored as given; None when the caller does not provide it
        self._total_nframe = total_nframe

    @property
    def total_nframe(self):
        return self._total_nframe

    def __len__(self):
        return len(self._frames)

    def __getitem__(self, index):
        # Frames are stored as raw tuples and wrapped on access
        if isinstance(index, slice):
            return tuple(Frame(trace) for trace in self._frames[index])
        else:
            return Frame(self._frames[index])

    def __contains__(self, frame):
        # An object without a raw frame tuple cannot be one of our frames:
        # answer False instead of failing with AttributeError
        try:
            raw_frame = frame._frame
        except AttributeError:
            return False
        return raw_frame in self._frames

    def __hash__(self):
        return hash(self._frames)

    def __eq__(self, other):
        if not isinstance(other, Traceback):
            return NotImplemented
        return (self._frames == other._frames)

    def __lt__(self, other):
        if not isinstance(other, Traceback):
            return NotImplemented
        return (self._frames < other._frames)

    def __str__(self):
        # Only the first (oldest) stored frame is shown
        return str(self[0])

    def __repr__(self):
        s = f"<Traceback {tuple(self)}"
        if self._total_nframe is None:
            s += ">"
        else:
            s += f" total_nframe={self.total_nframe}>"
        return s

    def format(self, limit=None, most_recent_first=False):
        """
        Format the traceback as a list of lines.

        Each frame produces a '  File "...", line ...' line, followed by the
        stripped source line when linecache can retrieve it.

        If *limit* is positive, only the *limit* most recent frames are
        formatted.  If it is negative, only the abs(limit) oldest frames are
        formatted.  If it is zero, no frame is formatted.  If
        *most_recent_first* is true, the frames are listed from the most
        recent to the oldest instead of the other way around.
        """
        lines = []
        if limit is not None:
            if limit > 0:
                frame_slice = self[-limit:]
            else:
                # -limit == abs(limit) here.  Slicing with [:limit] would
                # instead drop the abs(limit) most recent frames.
                frame_slice = self[:-limit]
        else:
            frame_slice = self

        if most_recent_first:
            frame_slice = reversed(frame_slice)
        for frame in frame_slice:
            lines.append('  File "%s", line %s'
                         % (frame.filename, frame.lineno))
            line = linecache.getline(frame.filename, frame.lineno).strip()
            if line:
                lines.append('    %s' % line)
        return lines
255
256
def get_object_traceback(obj):
    """
    Get the traceback where the Python object *obj* was allocated.
    Return a Traceback instance.

    Return None if the tracemalloc module is not tracing memory allocations or
    did not trace the allocation of the object.
    """
    raw_frames = _get_object_traceback(obj)
    if raw_frames is None:
        return None
    return Traceback(raw_frames)
270
271
class Trace:
    """
    Trace of a memory block.
    """
    __slots__ = ("_trace",)

    def __init__(self, trace):
        # trace is a tuple starting with (domain: int, size: int,
        # traceback: tuple); anything after the size is handed to the
        # Traceback constructor, see there for the format.
        self._trace = trace

    @property
    def domain(self):
        raw = self._trace
        return raw[0]

    @property
    def size(self):
        raw = self._trace
        return raw[1]

    @property
    def traceback(self):
        # A new Traceback is built on every access
        traceback_args = self._trace[2:]
        return Traceback(*traceback_args)

    def __eq__(self, other):
        if isinstance(other, Trace):
            return (self._trace == other._trace)
        return NotImplemented

    def __hash__(self):
        return hash(self._trace)

    def __str__(self):
        readable_size = _format_size(self.size, False)
        return "%s: %s" % (self.traceback, readable_size)

    def __repr__(self):
        readable_size = _format_size(self.size, False)
        return ("<Trace domain=%s size=%s, traceback=%r>"
                % (self.domain, readable_size, self.traceback))
309
310
311class _Traces(Sequence):
312    def __init__(self, traces):
313        Sequence.__init__(self)
314        # traces is a tuple of trace tuples: see Trace constructor
315        self._traces = traces
316
317    def __len__(self):
318        return len(self._traces)
319
320    def __getitem__(self, index):
321        if isinstance(index, slice):
322            return tuple(Trace(trace) for trace in self._traces[index])
323        else:
324            return Trace(self._traces[index])
325
326    def __contains__(self, trace):
327        return trace._trace in self._traces
328
329    def __eq__(self, other):
330        if not isinstance(other, _Traces):
331            return NotImplemented
332        return (self._traces == other._traces)
333
334    def __repr__(self):
335        return "<Traces len=%s>" % len(self)
336
337
338def _normalize_filename(filename):
339    filename = os.path.normcase(filename)
340    if filename.endswith('.pyc'):
341        filename = filename[:-1]
342    return filename
343
344
class BaseFilter:
    """
    Base class of trace filters.

    Subclasses implement _match(); the *inclusive* flag is stored as given.
    """

    def __init__(self, inclusive):
        self.inclusive = inclusive

    def _match(self, trace):
        """Tell whether *trace* passes the filter (subclass hook)."""
        raise NotImplementedError()
351
352
class Filter(BaseFilter):
    """
    Filter on traces by filename pattern, with an optional line number
    and an optional domain.
    """

    def __init__(self, inclusive, filename_pattern,
                 lineno=None, all_frames=False, domain=None):
        super().__init__(inclusive)
        self.inclusive = inclusive
        # The pattern is normalized once here; filenames are normalized the
        # same way before each comparison
        self._filename_pattern = _normalize_filename(filename_pattern)
        self.lineno = lineno
        self.all_frames = all_frames
        self.domain = domain

    @property
    def filename_pattern(self):
        return self._filename_pattern

    def _match_frame_impl(self, filename, lineno):
        # Raw match: ignores the inclusive flag
        normalized = _normalize_filename(filename)
        if not fnmatch.fnmatch(normalized, self._filename_pattern):
            return False
        # Without a line number, any line of a matching file is accepted
        return self.lineno is None or (lineno == self.lineno)

    def _match_frame(self, filename, lineno):
        # An exclusive filter inverts the raw match
        matched = self._match_frame_impl(filename, lineno)
        inverted = not self.inclusive
        return matched ^ inverted

    def _match_traceback(self, traceback):
        # traceback is a raw tuple of (filename, lineno) frame tuples
        if not self.all_frames:
            # Only the first frame of the raw traceback is checked
            filename, lineno = traceback[0]
            return self._match_frame(filename, lineno)

        for filename, lineno in traceback:
            if self._match_frame_impl(filename, lineno):
                return self.inclusive
        return (not self.inclusive)

    def _match(self, trace):
        domain, size, traceback, total_nframe = trace
        res = self._match_traceback(traceback)
        if self.domain is None:
            return res
        if self.inclusive:
            # Inclusive: the traceback and the domain must both match
            return res and (domain == self.domain)
        # Exclusive: traces of any other domain always pass
        return res or (domain != self.domain)
399
400
class DomainFilter(BaseFilter):
    """
    Filter on traces by domain.
    """

    def __init__(self, inclusive, domain):
        super().__init__(inclusive)
        self._domain = domain

    @property
    def domain(self):
        return self._domain

    def _match(self, trace):
        domain, size, traceback, total_nframe = trace
        same_domain = (domain == self.domain)
        # An exclusive filter inverts the comparison
        inverted = not self.inclusive
        return same_domain ^ inverted
413
414
class Snapshot:
    """
    Snapshot of traces of memory blocks allocated by Python.
    """

    def __init__(self, traces, traceback_limit):
        # traces is a sequence of trace tuples: see _Traces constructor for
        # the exact format
        self.traces = _Traces(traces)
        self.traceback_limit = traceback_limit

    def dump(self, filename):
        """
        Write the snapshot into a file.
        """
        with open(filename, "wb") as fp:
            pickle.dump(self, fp, pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(filename):
        """
        Load a snapshot from a file.
        """
        # NOTE: unpickling can execute arbitrary code; only load snapshot
        # files coming from a trusted source.
        with open(filename, "rb") as fp:
            return pickle.load(fp)

    def _filter_trace(self, include_filters, exclude_filters, trace):
        # Keep the trace if at least one inclusive filter accepts it (when
        # there are inclusive filters) and no exclusive filter rejects it
        if include_filters:
            if not any(trace_filter._match(trace)
                       for trace_filter in include_filters):
                return False
        if exclude_filters:
            if any(not trace_filter._match(trace)
                   for trace_filter in exclude_filters):
                return False
        return True

    def filter_traces(self, filters):
        """
        Create a new Snapshot instance with a filtered traces sequence, filters
        is a list of Filter or DomainFilter instances.  If filters is an empty
        list, return a new Snapshot instance with a copy of the traces.
        """
        if not isinstance(filters, Iterable):
            raise TypeError("filters must be a list of filters, not %s"
                            % type(filters).__name__)
        if filters:
            include_filters = []
            exclude_filters = []
            for trace_filter in filters:
                if trace_filter.inclusive:
                    include_filters.append(trace_filter)
                else:
                    exclude_filters.append(trace_filter)
            new_traces = [trace for trace in self.traces._traces
                          if self._filter_trace(include_filters,
                                                exclude_filters,
                                                trace)]
        else:
            # list() copies any sequence, whereas .copy() would fail on a
            # snapshot created from a tuple of traces
            new_traces = list(self.traces._traces)
        return Snapshot(new_traces, self.traceback_limit)

    def _group_by(self, key_type, cumulative):
        """
        Group the traces by *key_type* ('traceback', 'filename' or 'lineno').

        Return a dict mapping a Traceback to a Statistic holding the summed
        size and the number of the traces of the group.  In cumulative mode
        (only for 'filename' and 'lineno'), a trace is accounted to every
        frame of its traceback instead of a single one.

        Raise ValueError on an unknown key type, or if the cumulative mode
        is used with the 'traceback' key type.
        """
        if key_type not in ('traceback', 'filename', 'lineno'):
            raise ValueError("unknown key_type: %r" % (key_type,))
        if cumulative and key_type not in ('lineno', 'filename'):
            raise ValueError("cumulative mode cannot be used "
                             "with key type %r" % key_type)

        stats = {}
        # Cache: raw traceback tuple (or raw frame tuple in cumulative mode)
        # => Traceback used as the group key
        tracebacks = {}
        if not cumulative:
            for trace in self.traces._traces:
                domain, size, trace_traceback, total_nframe = trace
                try:
                    traceback = tracebacks[trace_traceback]
                except KeyError:
                    if key_type == 'traceback':
                        frames = trace_traceback
                    elif key_type == 'lineno':
                        # Only the first frame of the raw traceback
                        frames = trace_traceback[:1]
                    else: # key_type == 'filename':
                        # Line number 0: all lines of a file share one key
                        frames = ((trace_traceback[0][0], 0),)
                    traceback = Traceback(frames)
                    tracebacks[trace_traceback] = traceback
                try:
                    stat = stats[traceback]
                    stat.size += size
                    stat.count += 1
                except KeyError:
                    stats[traceback] = Statistic(traceback, size, 1)
        else:
            # cumulative statistics
            for trace in self.traces._traces:
                domain, size, trace_traceback, total_nframe = trace
                for frame in trace_traceback:
                    try:
                        traceback = tracebacks[frame]
                    except KeyError:
                        if key_type == 'lineno':
                            frames = (frame,)
                        else: # key_type == 'filename':
                            # Line number 0: all lines of a file share
                            # one key
                            frames = ((frame[0], 0),)
                        traceback = Traceback(frames)
                        tracebacks[frame] = traceback
                    try:
                        stat = stats[traceback]
                        stat.size += size
                        stat.count += 1
                    except KeyError:
                        stats[traceback] = Statistic(traceback, size, 1)
        return stats

    def statistics(self, key_type, cumulative=False):
        """
        Group statistics by key_type. Return a sorted list of Statistic
        instances.
        """
        grouped = self._group_by(key_type, cumulative)
        statistics = list(grouped.values())
        # Biggest entries first
        statistics.sort(reverse=True, key=Statistic._sort_key)
        return statistics

    def compare_to(self, old_snapshot, key_type, cumulative=False):
        """
        Compute the differences with an old snapshot old_snapshot. Get
        statistics as a sorted list of StatisticDiff instances, grouped by
        key_type.
        """
        new_group = self._group_by(key_type, cumulative)
        old_group = old_snapshot._group_by(key_type, cumulative)
        statistics = _compare_grouped_stats(old_group, new_group)
        # Biggest changes first
        statistics.sort(reverse=True, key=StatisticDiff._sort_key)
        return statistics
549
550
def take_snapshot():
    """
    Take a snapshot of traces of memory blocks allocated by Python.

    Return a new Snapshot instance.  Raise RuntimeError if the tracemalloc
    module is not tracing memory allocations.
    """
    if is_tracing():
        # The traces are collected first, then the traceback limit is read
        return Snapshot(_get_traces(), get_traceback_limit())
    raise RuntimeError("the tracemalloc module must be tracing memory "
                       "allocations to take a snapshot")
561