1import collections
2import itertools
3import operator
4import typing
5
6import attr
7import six
8
9from ddtrace import ext
10from ddtrace.internal.utils import config
11from ddtrace.profiling import exporter
12from ddtrace.profiling.collector import memalloc
13from ddtrace.profiling.collector import stack
14from ddtrace.profiling.collector import threading
15
16
def _protobuf_post_312():
    # type: (...) -> bool
    """Return True if the installed protobuf runtime is version 3.12 or newer.

    The generated pprof module that ships with ddtrace differs depending on
    the protobuf code-generator generation, so the matching module must be
    selected at import time.
    """
    import google.protobuf

    from ddtrace.internal.utils.version import parse_version

    v = parse_version(google.protobuf.__version__)
    # Compare (major, minor) lexicographically. The previous check
    # (`v[0] >= 3 and v[1] >= 12`) wrongly returned False for protobuf 4.x
    # and later, whose minor version restarts below 12 despite the runtime
    # being newer than 3.12.
    return v[:2] >= (3, 12)
26
27
if _protobuf_post_312():
    from ddtrace.profiling.exporter import pprof_pb2
else:
    # Modules generated by protoc >= 3.12 cannot be loaded by older protobuf
    # runtimes, so fall back to a module generated with the older toolchain.
    from ddtrace.profiling.exporter import pprof_pre312_pb2 as pprof_pb2
32
33
# Pre-built key functions, shared by the various sorts below instead of
# re-creating equivalent lambdas at every call site.
_ITEMGETTER_ZERO = operator.itemgetter(0)
_ITEMGETTER_ONE = operator.itemgetter(1)
_ATTRGETTER_ID = operator.attrgetter("id")
37
38
@attr.s
class _Sequence(object):
    """Generator of monotonically increasing integer ids.

    Ids start at `start_at` (1 by default — pprof ids must be non-zero) and
    increase by one per call to :meth:`generate`.
    """

    start_at = attr.ib(default=1, type=int)
    next_id = attr.ib(init=False, default=None, type=int)

    def __attrs_post_init__(self) -> None:
        # Seed the counter with the configured starting value.
        self.next_id = self.start_at

    def generate(self) -> int:
        """Return the next id in the sequence and advance the counter."""
        current = self.next_id
        self.next_id = current + 1
        return current
52
53
@attr.s
class _StringTable(object):
    """Interning table mapping strings to stable integer ids.

    Id 0 is reserved for the empty string, as required by the pprof format.
    Iterating the table yields the interned strings ordered by their id.
    """

    _strings = attr.ib(init=False, factory=lambda: {"": 0})
    _seq_id = attr.ib(init=False, factory=_Sequence)

    def to_id(self, string: str) -> int:
        """Return the id for `string`, interning it on first use."""
        existing = self._strings.get(string)
        if existing is not None:
            return existing
        new_id = self._seq_id.generate()
        self._strings[string] = new_id
        return new_id

    def __iter__(self):
        """Yield every interned string, sorted by its assigned id."""
        by_id = sorted(self._strings.items(), key=_ITEMGETTER_ONE)
        return iter(string for string, _ in by_id)

    def __len__(self) -> int:
        """Return the number of interned strings (including the empty one)."""
        return len(self._strings)
72
73
@attr.s
class _PprofConverter(object):
    """Convert stacks generated by a Profiler to pprof format.

    The converter accumulates samples via the `convert_*` methods and then
    serializes everything at once with `_build_profile`. Ids handed out by
    the `_Sequence` generators and the string table depend on call order, so
    the output is made reproducible by sorting before serialization.
    """

    # These attributes are serialized into a `pprof_pb2.Profile` by
    # `_build_profile`.
    _functions = attr.ib(init=False, factory=dict)
    _locations = attr.ib(init=False, factory=dict)
    _string_table = attr.ib(init=False, factory=_StringTable)

    # Id generators for pprof Location and Function messages.
    _last_location_id = attr.ib(init=False, factory=_Sequence)
    _last_func_id = attr.ib(init=False, factory=_Sequence)

    # A dict where the key is a (locations tuple, labels tuple) pair and the
    # value is a dict mapping a sample-type name (e.g. "cpu-time") to its
    # numeric value.
    _location_values = attr.ib(
        factory=lambda: collections.defaultdict(lambda: collections.defaultdict(lambda: 0)), init=False, repr=False
    )

    def _to_Function(self, filename, funcname):
        """Return the memoized pprof Function message for (filename, funcname)."""
        try:
            return self._functions[(filename, funcname)]
        except KeyError:
            func = pprof_pb2.Function(
                id=self._last_func_id.generate(),
                name=self._str(funcname),
                filename=self._str(filename),
            )
            self._functions[(filename, funcname)] = func
            return func

    def _to_Location(self, filename, lineno, funcname=None):
        """Return the memoized pprof Location for (filename, lineno, funcname).

        A missing function name is rendered as "<unknown function>" but the
        cache key keeps the original None so both forms stay distinct.
        """
        try:
            return self._locations[(filename, lineno, funcname)]
        except KeyError:
            if funcname is None:
                real_funcname = "<unknown function>"
            else:
                real_funcname = funcname
            location = pprof_pb2.Location(
                id=self._last_location_id.generate(),
                line=[
                    pprof_pb2.Line(
                        function_id=self._to_Function(filename, real_funcname).id,
                        line=lineno,
                    ),
                ],
            )
            self._locations[(filename, lineno, funcname)] = location
            return location

    def _str(self, string):
        """Convert a string to an id from the string table."""
        return self._string_table.to_id(str(string))

    def _to_locations(self, frames, nframes):
        """Return the tuple of Location ids for a stack of frames.

        When the collector truncated the stack (`nframes` > len(frames)), a
        synthetic "<N frames omitted>" location is appended.
        """
        locations = [self._to_Location(filename, lineno, funcname).id for filename, lineno, funcname in frames]

        omitted = nframes - len(frames)
        if omitted:
            locations.append(
                self._to_Location("", 0, "<%d frame%s omitted>" % (omitted, ("s" if omitted > 1 else ""))).id
            )

        return tuple(locations)

    def convert_stack_event(
        self,
        thread_id,  # type: int
        thread_native_id,  # type: int
        thread_name,  # type: str
        task_id,  # type: typing.Optional[int]
        task_name,  # type: str
        local_root_span_id,  # type: int
        span_id,  # type: int
        trace_resource,  # type: str
        trace_type,  # type: str
        frames,
        nframes,  # type: int
        samples,  # type: typing.Iterable[stack.StackSampleEvent]
    ):
        # type: (...) -> None
        """Record an aggregated group of stack sample events as one pprof sample.

        `samples` must share all the label values; callers group them with
        `PprofExporter._group_stack_events` first. NOTE: the id-like label
        values are expected to already be strings (the group key stringifies
        them via `_none_to_str`).
        """
        location_key = (
            self._to_locations(frames, nframes),
            (
                ("thread id", str(thread_id)),
                ("thread native id", str(thread_native_id)),
                ("thread name", thread_name),
                ("task id", task_id),
                ("task name", task_name),
                ("local root span id", local_root_span_id),
                ("span id", span_id),
                ("trace endpoint", trace_resource),
                ("trace type", trace_type),
            ),
        )

        self._location_values[location_key]["cpu-samples"] = len(samples)
        self._location_values[location_key]["cpu-time"] = sum(s.cpu_time_ns for s in samples)
        self._location_values[location_key]["wall-time"] = sum(s.wall_time_ns for s in samples)

    def convert_memalloc_event(self, thread_id, thread_native_id, thread_name, frames, nframes, events):
        """Record a group of memory allocation sample events as one pprof sample.

        The sampled sizes are scaled up by the capture percentage to estimate
        the total allocated space.
        """
        location_key = (
            self._to_locations(frames, nframes),
            (
                ("thread id", str(thread_id)),
                ("thread native id", str(thread_native_id)),
                ("thread name", thread_name),
            ),
        )

        self._location_values[location_key]["alloc-samples"] = sum(event.nevents for event in events)
        self._location_values[location_key]["alloc-space"] = round(
            sum(event.size / event.capture_pct * 100.0 for event in events)
        )

    def convert_memalloc_heap_event(self, event):
        """Record a single live-heap sample event.

        Unlike the other converters this accumulates (`+=`): several heap
        events may map to the same location/label key.
        """
        location_key = (
            self._to_locations(event.frames, event.nframes),
            (
                ("thread id", str(event.thread_id)),
                ("thread native id", str(event.thread_native_id)),
                ("thread name", event.thread_name),
            ),
        )

        self._location_values[location_key]["heap-space"] += event.size

    def convert_lock_acquire_event(
        self,
        lock_name,
        thread_id,
        thread_name,
        task_id,
        task_name,
        local_root_span_id,
        span_id,
        trace_resource,
        trace_type,
        frames,
        nframes,
        events,
        sampling_ratio,
    ):
        """Record a group of lock-acquire events as one pprof sample.

        Wait times are divided by `sampling_ratio` to extrapolate from the
        sampled events to an estimate of the total wait time.
        """
        location_key = (
            self._to_locations(frames, nframes),
            (
                ("thread id", str(thread_id)),
                ("thread name", thread_name),
                ("task id", str(task_id)),
                ("task name", task_name),
                ("local root span id", local_root_span_id),
                ("span id", span_id),
                ("trace endpoint", trace_resource),
                ("trace type", trace_type),
                ("lock name", lock_name),
            ),
        )

        self._location_values[location_key]["lock-acquire"] = len(events)
        self._location_values[location_key]["lock-acquire-wait"] = int(
            sum(e.wait_time_ns for e in events) / sampling_ratio
        )

    def convert_lock_release_event(
        self,
        lock_name,
        thread_id,
        thread_name,
        task_id,
        task_name,
        local_root_span_id,
        span_id,
        trace_resource,
        trace_type,
        frames,
        nframes,
        events,
        sampling_ratio,
    ):
        """Record a group of lock-release events as one pprof sample.

        Hold times are divided by `sampling_ratio` to extrapolate from the
        sampled events to an estimate of the total hold time.
        """
        location_key = (
            self._to_locations(frames, nframes),
            (
                ("thread id", str(thread_id)),
                ("thread name", thread_name),
                ("task id", str(task_id)),
                ("task name", task_name),
                ("local root span id", local_root_span_id),
                ("span id", span_id),
                ("trace endpoint", trace_resource),
                ("trace type", trace_type),
                ("lock name", lock_name),
            ),
        )

        self._location_values[location_key]["lock-release"] = len(events)
        self._location_values[location_key]["lock-release-hold"] = int(
            sum(e.locked_for_ns for e in events) / sampling_ratio
        )

    def convert_stack_exception_event(
        self,
        thread_id,
        thread_native_id,
        thread_name,
        local_root_span_id,
        span_id,
        trace_resource,
        trace_type,
        frames,
        nframes,
        exc_type_name,
        events,
    ):
        """Record a group of exception sample events as one pprof sample."""
        location_key = (
            self._to_locations(frames, nframes),
            (
                ("thread id", str(thread_id)),
                ("thread native id", str(thread_native_id)),
                ("thread name", thread_name),
                ("local root span id", local_root_span_id),
                ("span id", span_id),
                ("trace endpoint", trace_resource),
                ("trace type", trace_type),
                ("exception type", exc_type_name),
            ),
        )

        self._location_values[location_key]["exception-samples"] = len(events)

    def _build_profile(self, start_time_ns, duration_ns, period, sample_types, program_name) -> pprof_pb2.Profile:
        """Serialize everything accumulated so far into a pprof Profile.

        :param start_time_ns: Profile start timestamp, in nanoseconds.
        :param duration_ns: Profile duration, in nanoseconds.
        :param period: Average sampling period, in nanoseconds (may be None).
        :param sample_types: Ordered iterable of (type name, unit) pairs;
            the order defines the order of each sample's value vector.
        :param program_name: Name recorded in the single Mapping entry.
        """
        pprof_sample_type = [
            pprof_pb2.ValueType(type=self._str(type_), unit=self._str(unit)) for type_, unit in sample_types
        ]

        # Sort samples by their (locations, labels) key for reproducibility;
        # missing sample types default to 0 so every value vector has the
        # same length and ordering as `sample_types`.
        sample = [
            pprof_pb2.Sample(
                location_id=locations,
                value=[values.get(sample_type_name, 0) for sample_type_name, unit in sample_types],
                label=[pprof_pb2.Label(key=self._str(key), str=self._str(s)) for key, s in labels],
            )
            for (locations, labels), values in sorted(six.iteritems(self._location_values), key=_ITEMGETTER_ZERO)
        ]

        period_type = pprof_pb2.ValueType(type=self._str("time"), unit=self._str("nanoseconds"))

        # WARNING: no code should use _str() here as once the _string_table is serialized below,
        # it won't be updated if you call _str later in the code here
        return pprof_pb2.Profile(
            sample_type=pprof_sample_type,
            sample=sample,
            mapping=[
                pprof_pb2.Mapping(
                    id=1,
                    filename=self._str(program_name),
                ),
            ],
            # Sort location and function by id so the output is reproducible
            location=sorted(self._locations.values(), key=_ATTRGETTER_ID),
            function=sorted(self._functions.values(), key=_ATTRGETTER_ID),
            string_table=list(self._string_table),
            time_nanos=start_time_ns,
            duration_nanos=duration_ns,
            period=period,
            period_type=period_type,
        )
339
340
# Key used to aggregate stack sample events into a single pprof sample.
# Produced by `PprofExporter._stack_event_group_key`; the id-like fields are
# stringified (None becomes "") so the tuple is hashable and sortable.
_stack_event_group_key_T = typing.Tuple[
    int,  # thread_id
    int,  # thread_native_id
    str,  # thread name
    str,  # task_id
    str,  # task_name
    str,  # local_root_span_id
    str,  # span_id
    str,  # trace resource
    str,  # trace type
    typing.Tuple,  # frames
    int,  # nframes
]
354
355
@attr.s
class PprofExporter(exporter.Exporter):
    """Export recorder events to pprof format.

    Events of each kind are grouped by an aggregation key (thread, task,
    span, stack, ...) and each group is converted into a single pprof sample
    by a `_PprofConverter`.
    """

    @staticmethod
    def _none_to_str(
        value,  # type: typing.Optional[typing.Any]
    ):
        # type: (...) -> str
        """Return `value` as a string, mapping None to the empty string."""
        if value is None:
            return ""
        return str(value)

    @staticmethod
    def _get_thread_name(thread_id, thread_name):
        """Return `thread_name`, or a placeholder derived from `thread_id`."""
        if thread_name is None:
            return "Anonymous Thread %d" % thread_id
        return thread_name

    def _stack_event_group_key(
        self,
        event,  # type: stack.StackSampleEvent
    ):
        # type: (...) -> _stack_event_group_key_T
        """Build the aggregation key for a stack sample event.

        Events sharing this key are merged into one pprof sample.
        """
        return (
            event.thread_id,
            event.thread_native_id,
            self._get_thread_name(event.thread_id, event.thread_name),
            self._none_to_str(event.task_id),
            self._none_to_str(event.task_name),
            self._none_to_str(event.local_root_span_id),
            self._none_to_str(event.span_id),
            self._none_to_str(self._get_event_trace_resource(event)),
            self._none_to_str(event.trace_type),
            tuple(event.frames),
            event.nframes,
        )

    def _group_stack_events(
        self,
        events,  # type: typing.Iterable[stack.StackSampleEvent]
    ):
        # type: (...) -> typing.Iterator[typing.Tuple[_stack_event_group_key_T, typing.Iterator[stack.StackSampleEvent]]]
        """Group stack events by their aggregation key.

        `itertools.groupby` only merges adjacent items, so the events are
        sorted by the same key first.
        """
        return itertools.groupby(
            sorted(events, key=self._stack_event_group_key),
            key=self._stack_event_group_key,
        )

    def _lock_event_group_key(
        self, event  # a LockAcquireEvent or LockReleaseEvent from collector.threading
    ):
        """Build the aggregation key for a lock event."""
        return (
            event.lock_name,
            event.thread_id,
            self._get_thread_name(event.thread_id, event.thread_name),
            self._none_to_str(event.task_id),
            self._none_to_str(event.task_name),
            self._none_to_str(event.local_root_span_id),
            self._none_to_str(event.span_id),
            self._none_to_str(self._get_event_trace_resource(event)),
            self._none_to_str(event.trace_type),
            tuple(event.frames),
            event.nframes,
        )

    def _group_lock_events(self, events):
        """Group lock events by their aggregation key (sorted first)."""
        return itertools.groupby(
            sorted(events, key=self._lock_event_group_key),
            key=self._lock_event_group_key,
        )

    def _stack_exception_group_key(self, event):
        """Build the aggregation key for an exception sample event.

        The exception type is identified by its fully-qualified name.
        """
        exc_type = event.exc_type
        exc_type_name = exc_type.__module__ + "." + exc_type.__name__

        return (
            event.thread_id,
            event.thread_native_id,
            self._get_thread_name(event.thread_id, event.thread_name),
            self._none_to_str(event.local_root_span_id),
            self._none_to_str(event.span_id),
            self._none_to_str(self._get_event_trace_resource(event)),
            self._none_to_str(event.trace_type),
            tuple(event.frames),
            event.nframes,
            exc_type_name,
        )

    def _group_stack_exception_events(self, events):
        """Group exception sample events by their aggregation key (sorted first)."""
        return itertools.groupby(
            sorted(events, key=self._stack_exception_group_key),
            key=self._stack_exception_group_key,
        )

    def _get_event_trace_resource(self, event):
        """Return the event's trace resource, or None when it must be withheld."""
        trace_resource = None
        # Do not export trace_resource for non Web spans for privacy concerns.
        if event.trace_resource_container and event.trace_type == ext.SpanTypes.WEB.value:
            (trace_resource,) = event.trace_resource_container
        return trace_resource

    def export(self, events, start_time_ns, end_time_ns) -> pprof_pb2.Profile:  # type: ignore[valid-type]
        """Convert events to pprof format.

        :param events: The event dictionary from a `ddtrace.profiling.recorder.Recorder`.
        :param start_time_ns: The start time of recording.
        :param end_time_ns: The end time of recording.
        :return: A protobuf Profile object.
        """
        program_name = config.get_application_name()

        sum_period = 0
        nb_event = 0

        converter = _PprofConverter()

        # Handle StackSampleEvent
        stack_events = []
        for event in events.get(stack.StackSampleEvent, []):
            stack_events.append(event)
            sum_period += event.sampling_period
            nb_event += 1

        # NOTE: the grouped-events variable must NOT be named `stack_events`:
        # that would clobber the list above inside the loop (harmless only
        # because groupby already consumed it, but an obvious trap).
        for (
            (
                thread_id,
                thread_native_id,
                thread_name,
                task_id,
                task_name,
                local_root_span_id,
                span_id,
                trace_resource,
                trace_type,
                frames,
                nframes,
            ),
            grouped_stack_events,
        ) in self._group_stack_events(stack_events):
            converter.convert_stack_event(
                thread_id,
                thread_native_id,
                thread_name,
                task_id,
                task_name,
                local_root_span_id,
                span_id,
                trace_resource,
                trace_type,
                frames,
                nframes,
                list(grouped_stack_events),
            )

        # Handle Lock events
        for event_class, convert_fn in (
            (threading.LockAcquireEvent, converter.convert_lock_acquire_event),
            (threading.LockReleaseEvent, converter.convert_lock_release_event),
        ):
            lock_events = events.get(event_class, [])
            sampling_sum_pct = sum(event.sampling_pct for event in lock_events)

            if lock_events:
                # Average capture percentage across events, as a ratio in (0, 1].
                sampling_ratio_avg = sampling_sum_pct / (len(lock_events) * 100.0)

                for (
                    lock_name,
                    thread_id,
                    thread_name,
                    task_id,
                    task_name,
                    local_root_span_id,
                    span_id,
                    trace_resource,
                    trace_type,
                    frames,
                    nframes,
                ), grouped_lock_events in self._group_lock_events(lock_events):
                    convert_fn(
                        lock_name,
                        thread_id,
                        thread_name,
                        task_id,
                        task_name,
                        local_root_span_id,
                        span_id,
                        trace_resource,
                        trace_type,
                        frames,
                        nframes,
                        list(grouped_lock_events),
                        sampling_ratio_avg,
                    )

        # Handle StackExceptionSampleEvent
        for (
            (
                thread_id,
                thread_native_id,
                thread_name,
                local_root_span_id,
                span_id,
                trace_resource,
                trace_type,
                frames,
                nframes,
                exc_type_name,
            ),
            se_events,
        ) in self._group_stack_exception_events(events.get(stack.StackExceptionSampleEvent, [])):
            converter.convert_stack_exception_event(
                thread_id,
                thread_native_id,
                thread_name,
                local_root_span_id,
                span_id,
                trace_resource,
                trace_type,
                frames,
                nframes,
                exc_type_name,
                list(se_events),
            )

        # Handle memory events only when the memalloc collector is available.
        if memalloc._memalloc:
            for (
                (
                    thread_id,
                    thread_native_id,
                    thread_name,
                    task_id,
                    task_name,
                    local_root_span_id,
                    span_id,
                    trace_resource,
                    trace_type,
                    frames,
                    nframes,
                ),
                memalloc_events,
            ) in self._group_stack_events(events.get(memalloc.MemoryAllocSampleEvent, [])):
                converter.convert_memalloc_event(
                    thread_id,
                    thread_native_id,
                    thread_name,
                    frames,
                    nframes,
                    list(memalloc_events),
                )

            for event in events.get(memalloc.MemoryHeapSampleEvent, []):
                converter.convert_memalloc_heap_event(event)

        # Compute some metadata
        if nb_event:
            period = int(sum_period / nb_event)
        else:
            period = None

        duration_ns = end_time_ns - start_time_ns

        # Order here defines the value-vector layout of every pprof sample.
        sample_types = (
            ("cpu-samples", "count"),
            ("cpu-time", "nanoseconds"),
            ("wall-time", "nanoseconds"),
            ("exception-samples", "count"),
            ("lock-acquire", "count"),
            ("lock-acquire-wait", "nanoseconds"),
            ("lock-release", "count"),
            ("lock-release-hold", "nanoseconds"),
            ("alloc-samples", "count"),
            ("alloc-space", "bytes"),
            ("heap-space", "bytes"),
        )

        return converter._build_profile(
            start_time_ns=start_time_ns,
            duration_ns=duration_ns,
            period=period,
            sample_types=sample_types,
            program_name=program_name,
        )
637