from itertools import chain
import multiprocessing as mp


try:
    from multiprocessing import SimpleQueue as MPQueue
except ImportError:
    from multiprocessing.queues import SimpleQueue as MPQueue

import os
import threading

from ddtrace import Span
from ddtrace import tracer
from ddtrace.internal import _rand
from ddtrace.internal import forksafe
from ddtrace.internal.compat import Queue


def test_random():
    m = set()
    for i in range(0, 2 ** 16):
        n = _rand.rand64bits()
        assert 0 <= n <= 2 ** 64 - 1
        assert n not in m
        m.add(n)


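# A rough sketch (not part of the original suite) of why the uniqueness
# assertion in test_random() is sound: by the birthday bound, the chance of
# any collision among n draws from a space of d values is approximately
# 1 - exp(-n * (n - 1) / (2 * d)), which for 2**16 draws from 2**64 values
# is on the order of 1e-10. The helper name below is illustrative only.
def _approx_collision_probability(n=2 ** 16, d=2 ** 64):
    import math

    return 1.0 - math.exp(-n * (n - 1) / (2.0 * d))

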
def test_fork_no_pid_check():
    q = MPQueue()
    pid = os.fork()

    # Generate random numbers in the parent and child processes after forking.
    # The child sends its numbers back to the parent, where we check for
    # collisions.
    if pid > 0:
        # parent
        rns = {_rand.rand64bits() for _ in range(100)}
        child_rns = q.get()

        assert rns & child_rns == set()

    else:
        # child
        try:
            rngs = {_rand.rand64bits() for _ in range(100)}
            q.put(rngs)
        finally:
            # Kill the process so it doesn't continue running the rest of the
            # test suite in a separate process. Note we can't use sys.exit()
            # as it raises an exception that pytest will detect as an error.
            os._exit(0)


def test_fork_pid_check():
    q = MPQueue()
    pid = os.fork()

    # Generate random numbers in the parent and child processes after forking.
    # The child sends its numbers back to the parent, where we check for
    # collisions.
    if pid > 0:
        # parent
        rns = {_rand.rand64bits() for _ in range(100)}
        child_rns = q.get()

        assert rns & child_rns == set()

    else:
        # child
        try:
            rngs = {_rand.rand64bits() for _ in range(100)}
            q.put(rngs)
        finally:
            # Kill the process so it doesn't continue running the rest of the
            # test suite in a separate process. Note we can't use sys.exit()
            # as it raises an exception that pytest will detect as an error.
            os._exit(0)


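# The two fork tests above share the same shape: fork, draw numbers on both
# sides, ship the child's numbers back over a queue, and assert the two sets
# are disjoint. A sketch of that pattern factored into a helper follows; it
# is illustrative only and not used by the tests in this module.
def _fork_and_check_disjoint(generate, n=100):
    q = MPQueue()
    pid = os.fork()
    if pid > 0:
        # parent: draw locally, then compare against the child's numbers
        parent_ns = {generate() for _ in range(n)}
        child_ns = q.get()
        os.waitpid(pid, 0)
        assert parent_ns & child_ns == set()
    else:
        # child: draw and report back, then exit without running more tests
        try:
            q.put({generate() for _ in range(n)})
        finally:
            os._exit(0)

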
def test_multiprocess():
    q = MPQueue()

    def target(q):
        assert sum((_ is _rand.seed for _ in forksafe._registry)) == 1
        q.put([_rand.rand64bits() for _ in range(100)])

    ps = [mp.Process(target=target, args=(q,)) for _ in range(30)]
    for p in ps:
        p.start()

    for p in ps:
        p.join()
        assert p.exitcode == 0

    ids_list = [_rand.rand64bits() for _ in range(1000)]
    ids = set(ids_list)
    assert len(ids_list) == len(ids), "Collisions found in ids"

    while not q.empty():
        child_ids_list = q.get()
        child_ids = set(child_ids_list)

        assert len(child_ids_list) == len(child_ids), "Collisions found in subprocess ids"

        assert ids & child_ids == set()
        ids = ids | child_ids  # accumulate the ids


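# The assertion on forksafe._registry above checks that _rand.seed is
# registered exactly once as a fork hook, i.e. each child re-seeds its PRNG
# rather than inheriting the parent's state. Conceptually this resembles the
# stdlib's os.register_at_fork; a minimal sketch of that general mechanism
# follows (an illustration under that assumption, not ddtrace's actual
# implementation). It is inert unless called explicitly.
def _reseed_on_fork_sketch():
    import random

    def _reseed():
        # give the child process its own PRNG state after fork()
        random.seed()

    os.register_at_fork(after_in_child=_reseed)

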
def test_threadsafe():
    # Check that the PRNG is thread-safe.
    # This obviously won't guarantee thread safety, but it's something
    # at least.
    # To provide some validation of this method I wrote a slow, unsafe RNG:
    #
    # state = 4101842887655102017
    #
    # def bad_random():
    #     global state
    #     state ^= state >> 21
    #     state ^= state << 35
    #     state ^= state >> 4
    #     return state * 2685821657736338717
    #
    # which consistently fails this test.

    q = Queue()

    def _target():
        # Generate a bunch of numbers to try to maximize the chance that
        # two threads will be calling rand64bits at the same time.
        rngs = [_rand.rand64bits() for _ in range(200000)]
        q.put(rngs)

    ts = [threading.Thread(target=_target) for _ in range(5)]

    for t in ts:
        t.start()

    for t in ts:
        t.join()

    ids = set()

    while not q.empty():
        new_ids_list = q.get()

        new_ids = set(new_ids_list)
        assert len(new_ids) == len(new_ids_list), "Collision found in ids"
        assert ids & new_ids == set()
        ids = ids | new_ids

    assert len(ids) > 0


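# A runnable version of the bad_random sketch from the comment in
# test_threadsafe(). The 64-bit mask is added here so Python's unbounded
# ints don't grow across the left shift, which the commented sketch glosses
# over. It is not used by the tests; it only illustrates the kind of
# lock-free, shared-state generator whose interleaved read-modify-write the
# test above is meant to catch.
_bad_state = 4101842887655102017


def _bad_random():
    global _bad_state
    _bad_state ^= _bad_state >> 21
    _bad_state ^= (_bad_state << 35) & (2 ** 64 - 1)
    _bad_state ^= _bad_state >> 4
    return (_bad_state * 2685821657736338717) & (2 ** 64 - 1)

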
def test_tracer_usage_fork():
    q = MPQueue()
    pid = os.fork()

    # Similar to the fork tests above, except we use the tracer API.
    # In this case we expect to never have collisions.
    if pid > 0:
        # parent
        parent_ids_list = list(
            chain.from_iterable((s.span_id, s.trace_id) for s in [tracer.start_span("s") for _ in range(100)])
        )
        parent_ids = set(parent_ids_list)
        assert len(parent_ids) == len(parent_ids_list), "Collisions found in parent process ids"

        child_ids_list = q.get()

        child_ids = set(child_ids_list)

        assert len(child_ids) == len(child_ids_list), "Collisions found in child process ids"
        assert parent_ids & child_ids == set()
    else:
        # child
        try:
            child_ids = list(
                chain.from_iterable((s.span_id, s.trace_id) for s in [tracer.start_span("s") for _ in range(100)])
            )
            q.put(child_ids)
        finally:
            # Kill the process so it doesn't continue running the rest of the
            # test suite in a separate process. Note we can't use sys.exit()
            # as it raises an exception that pytest will detect as an error.
            os._exit(0)


def test_tracer_usage_multiprocess():
    q = MPQueue()

    # Similar to test_multiprocess(), ensures that no collisions are
    # generated between parent and child processes while using
    # multiprocessing.

    # Note that we have to be wary of the size of the underlying
    # pipe in the queue: https://bugs.python.org/msg143081

    def target(q):
        ids_list = list(
            chain.from_iterable((s.span_id, s.trace_id) for s in [tracer.start_span("s") for _ in range(10)])
        )
        q.put(ids_list)

    ps = [mp.Process(target=target, args=(q,)) for _ in range(30)]
    for p in ps:
        p.start()

    for p in ps:
        p.join()

    ids_list = list(chain.from_iterable((s.span_id, s.trace_id) for s in [tracer.start_span("s") for _ in range(100)]))
    ids = set(ids_list)
    assert len(ids) == len(ids_list), "Collisions found in ids"

    while not q.empty():
        child_ids_list = q.get()
        child_ids = set(child_ids_list)

        assert len(child_ids) == len(child_ids_list), "Collisions found in subprocess ids"

        assert ids & child_ids == set()
        ids = ids | child_ids  # accumulate the ids


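# The pipe-size note above refers to the standard multiprocessing caveat: a
# child that puts a large payload can block until the parent drains the
# pipe, so joining before reading may deadlock. The test sidesteps this by
# keeping each child's payload small. A minimal sketch of the safer
# drain-before-join ordering, using a plain multiprocessing.Queue purely for
# illustration (not part of the original suite):
def _drain_then_join_sketch(make_payload, nprocs=4):
    q = mp.Queue()

    def _target():
        q.put(make_payload())

    ps = [mp.Process(target=_target) for _ in range(nprocs)]
    for p in ps:
        p.start()

    results = [q.get() for _ in ps]  # drain everything first ...
    for p in ps:
        p.join()  # ... then joining cannot deadlock on a full pipe
    return results

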
def test_span_api_fork():
    q = MPQueue()
    pid = os.fork()

    if pid > 0:
        # parent
        parent_ids_list = list(
            chain.from_iterable((s.span_id, s.trace_id) for s in [Span(None, None) for _ in range(100)])
        )
        parent_ids = set(parent_ids_list)
        assert len(parent_ids) == len(parent_ids_list), "Collisions found in parent process ids"

        child_ids_list = q.get()

        child_ids = set(child_ids_list)

        assert len(child_ids) == len(child_ids_list), "Collisions found in child process ids"
        assert parent_ids & child_ids == set()
    else:
        # child
        try:
            child_ids = list(
                chain.from_iterable((s.span_id, s.trace_id) for s in [Span(None, None) for _ in range(100)])
            )
            q.put(child_ids)
        finally:
            os._exit(0)