1"""
A variety of test cases ensuring that ddtrace does not leak memory.
3"""
4import gc
5import os
6from threading import Thread
7from typing import TYPE_CHECKING
8from weakref import WeakValueDictionary
9
10import pytest
11
12from ddtrace import Tracer
13
14
15if TYPE_CHECKING:
16    from ddtrace.span import Span
17
18
@pytest.fixture
def tracer():
    # type: (...) -> Tracer
    """Provide a newly constructed ``Tracer`` to the requesting test."""
    fresh_tracer = Tracer()
    return fresh_tracer
23
24
def trace(weakdict, tracer, *args, **kwargs):
    # type: (WeakValueDictionary, Tracer, ...) -> Span
    """Start a span through ``tracer.trace`` and record it in ``weakdict``.

    All extra positional and keyword arguments are forwarded unchanged to
    ``tracer.trace``. The resulting span is stored in ``weakdict`` under its
    ``span_id`` and then handed back to the caller.

    Note: the caller must delete the returned reference once it is done with
    the span so that no additional reference to the span is kept around.
    """
    span = tracer.trace(*args, **kwargs)
    span_id = span.span_id
    weakdict[span_id] = span
    return span
36
37
def test_leak(tracer):
    """
    Ensure open spans survive a collection and finished, unreferenced spans
    do not.
    """
    tracked = WeakValueDictionary()
    first_span = trace(tracked, tracer, "span1")
    second_span = trace(tracked, tracer, "span2")
    assert len(tracked) == 2

    # Both spans are unfinished and still bound to local names here, so a
    # collection pass must leave them in the weak dictionary.
    gc.collect()
    assert len(tracked) == 2

    # Finish the most recently started span first, then the earlier one.
    second_span.finish()
    first_span.finish()

    # Drop the last local references before collecting again.
    del first_span
    del second_span
    gc.collect()
    assert len(tracked) == 0
52
53
def test_single_thread_single_trace(tracer):
    """
    Ensure span objects of one simple two-span trace are not leaked.
    """
    tracked = WeakValueDictionary()
    # "span1" is entered first and exited last, exactly like nested blocks.
    with trace(tracked, tracer, "span1"), trace(tracked, tracer, "span2"):
        pass

    # This test keeps no reference to either span after the ``with`` block,
    # so an immediate collection is expected to remove every span object.
    gc.collect()
    assert len(tracked) == 0
67
68
def test_single_thread_multi_trace(tracer):
    """
    Ensure spans from many consecutive traces are garbage collected.
    """
    tracked = WeakValueDictionary()
    iterations = 1000
    inner_names = ("span2", "span3")

    for _ in range(iterations):
        with trace(tracked, tracer, "span1"):
            # Open and close the inner spans one after the other while
            # "span1" is still open; only their names are kept in locals.
            for inner_name in inner_names:
                with trace(tracked, tracer, inner_name):
                    pass

    # No span reference is held by this test at this point, so nothing
    # should survive a collection.
    gc.collect()
    assert len(tracked) == 0
85
86
def test_multithread_trace(tracer):
    """
    Ensure a trace that crosses a thread boundary is garbage collected.
    """
    tracked = WeakValueDictionary()
    completed = []

    def _target(parent_ctx):
        # Activate the context handed over from the main thread, then open a
        # span in this thread while the main thread's span is still open.
        tracer.context_provider.activate(parent_ctx)
        with trace(tracked, tracer, "thread"):
            assert len(tracked) == 2
        # Only reached when the assertion above did not raise.
        completed.append(1)

    root_span = trace(tracked, tracer, "")
    worker = Thread(target=_target, args=(root_span.context,))
    worker.start()
    worker.join()
    # The marker proves the worker ran to completion.
    assert completed == [1]

    root_span.finish()
    del root_span
    gc.collect()
    assert len(tracked) == 0
111
112
def test_fork_open_span(tracer):
    """
    When a fork occurs with an open span then the child process should not have
    a strong reference to the span because it might never be closed.

    The child signals success by exiting with status 12; any failure inside
    the child makes it exit with status 1 instead. The parent always reaps the
    child and then checks that exit status.
    """
    wd = WeakValueDictionary()
    span = trace(wd, tracer, "span")
    pid = os.fork()

    if pid == 0:
        # Child process. It must never return into the test runner inherited
        # from the parent, so every path below ends in ``os._exit``.
        import sys
        import traceback

        exit_code = 1
        try:
            assert len(wd) == 1
            gc.collect()
            # span is still open and in the context
            assert len(wd) == 1
            span2 = trace(wd, tracer, "span2")
            assert span2._parent is None
            assert len(wd) == 2
            span2.finish()

            del span2
            # Expect there to be one span left (the original from before the
            # fork) which is inherited into the child process but will never
            # be closed. The important thing in this test case is all new
            # spans created in the child will be gc'd.
            gc.collect()
            assert len(wd) == 1

            # Normally, if the child process leaves this function frame the
            # span reference would be lost and it would be free to be gc'd.
            # We delete the reference explicitly here to mimic this scenario.
            del span
            gc.collect()
            assert len(wd) == 0
            exit_code = 12
        except BaseException:
            # Report the failure before exiting; ``os._exit`` does not flush
            # Python-level buffers, so flush explicitly.
            traceback.print_exc()
            sys.stderr.flush()
        finally:
            os._exit(exit_code)

    try:
        assert len(wd) == 1
        span.finish()
        del span
        gc.collect()
        assert len(wd) == 0
    finally:
        # Always reap the child, even if a parent-side assertion failed.
        _, status = os.waitpid(pid, 0)

    # The exit status is only meaningful when the child exited normally
    # (as opposed to being terminated by a signal).
    assert os.WIFEXITED(status)
    exit_code = os.WEXITSTATUS(status)
    assert exit_code == 12
157