1# -*- encoding: utf-8 -*-
2import collections
3import email.parser
4import platform
5import socket
6import threading
7import time
8
9import pytest
10import six
11from six.moves import BaseHTTPServer
12from six.moves import http_client
13
14import ddtrace
15from ddtrace.internal import compat
16from ddtrace.profiling import exporter
17from ddtrace.profiling.exporter import http
18
19from . import test_pprof
20
21
# API key the fake intake servers accept; any other DD-API-KEY value is rejected with a 400.
_API_KEY = "my-api-key"
23
24
class _APIEndpointRequestHandlerTest(BaseHTTPServer.BaseHTTPRequestHandler):
    """Fake profiling intake that validates multipart profile uploads.

    Replies 400 with a plain-text reason when the API key or any expected
    multipart field is wrong, and 200 ("OK") when the upload is accepted.
    """

    error_message_format = "%(message)s\n"
    error_content_type = "text/plain"
    path_prefix = "/profiling/v1"

    @staticmethod
    def log_message(format, *args):  # noqa: A002
        # Suppress the base handler's per-request logging to keep test output clean.
        pass

    @staticmethod
    def _check_tags(tags):
        """Return True if ``tags`` holds exactly the expected default tags.

        ``tags`` is the list of raw ``key:value`` byte strings received as
        ``tags[]`` form fields. It is sorted in place so the checks can rely on
        a fixed order.
        """
        tags.sort()
        # No trailing comma inside these parentheses: a one-element tuple is always
        # truthy and would make this check pass unconditionally.
        # NOTE(review): assumes the exporter sends exactly these 7 tags, with the
        # interpreter version under "runtime_version" -- confirm against
        # PprofHTTPExporter._get_tags.
        return (
            len(tags) == 7
            and tags[0].startswith(b"host:")
            and tags[1] == b"language:python"
            and tags[2] == ("profiler_version:%s" % ddtrace.__version__).encode("utf-8")
            and tags[3].startswith(b"runtime-id:")
            and tags[4] == b"runtime:CPython"
            and tags[5] == ("runtime_version:%s" % platform.python_version()).encode("utf-8")
            and tags[6].startswith(b"service:")
        )

    def do_POST(self):
        """Validate the API key and every expected multipart field of the upload."""
        assert self.path.startswith(self.path_prefix)
        api_key = self.headers["DD-API-KEY"]
        if api_key != _API_KEY:
            self.send_error(400, "Wrong API Key")
            return
        length = int(self.headers["Content-Length"])
        body = self.rfile.read(length)
        # Prepend the request's Content-Type so the email parser sees a complete
        # multipart message, including its boundary parameter.
        mmpart = b"Content-Type: " + self.headers["Content-Type"].encode() + b"\r\n" + body
        if six.PY2:
            msg = email.parser.Parser().parsestr(mmpart)
        else:
            msg = email.parser.BytesParser().parsebytes(mmpart)
        if not msg.is_multipart():
            self.send_error(400, "No multipart")
            return
        # Group decoded part payloads by their form-field name; fields such as
        # "tags[]" may appear several times.
        items = collections.defaultdict(list)
        for part in msg.get_payload():
            items[part.get_param("name", header="content-disposition")].append(part.get_payload(decode=True))
        for key, check in {
            "recording-start": lambda x: x[0] == b"1970-01-01T00:00:00Z",
            "recording-end": lambda x: x[0].startswith(b"20"),
            "runtime": lambda x: x[0] == platform.python_implementation().encode(),
            "format": lambda x: x[0] == b"pprof",
            "type": lambda x: x[0] == b"cpu+alloc+exceptions",
            "tags[]": self._check_tags,
            "chunk-data": lambda x: x[0].startswith(b"\x1f\x8b\x08\x00"),
        }.items():
            values = items[key]
            # A missing field must produce a 400 reply, not an IndexError inside
            # a check that would abort the request without any response.
            if not values or not check(values):
                self.send_error(400, "Wrong value for %s: %r" % (key, values))
                return
        self.send_error(200, "OK")
80
81
class _TimeoutAPIEndpointRequestHandlerTest(_APIEndpointRequestHandlerTest):
    """Handler that stalls before answering, so the client times out first."""

    def do_POST(self):
        # Deliberately longer than the client timeouts used in these tests.
        stall_seconds = 5
        time.sleep(stall_seconds)
        self.send_error(500, "Argh")
87
88
class _ResetAPIEndpointRequestHandlerTest(_APIEndpointRequestHandlerTest):
    """Handler that drops every upload without writing any response."""

    def do_POST(self):
        # Deliberately send nothing: the client then sees the connection end
        # with no status line.
        return None
92
93
class _UnknownAPIEndpointRequestHandlerTest(_APIEndpointRequestHandlerTest):
    """Handler that answers every upload with 404, emulating a missing endpoint."""

    def do_POST(self):
        status, reason = 404, "Argh"
        self.send_error(status, reason)
97
98
# Base port for the fake intake; each misbehaving server variant takes the next port up.
_PORT = 8992
_TIMEOUT_PORT, _RESET_PORT, _UNKNOWN_PORT = range(_PORT + 1, _PORT + 4)

# Base URL for each fake server, in the same order as the ports above.
_ENDPOINT, _TIMEOUT_ENDPOINT, _RESET_ENDPOINT, _UNKNOWN_ENDPOINT = (
    "http://localhost:%d" % port for port in (_PORT, _TIMEOUT_PORT, _RESET_PORT, _UNKNOWN_PORT)
)
107
108
def _make_server(port, request_handler):
    """Start an HTTP server on localhost:``port`` in a background thread.

    Returns ``(server, thread)``; the caller must stop the server and join the thread.
    """
    httpd = BaseHTTPServer.HTTPServer(("localhost", port), request_handler)
    serve_thread = threading.Thread(target=httpd.serve_forever)
    # Daemonized so a server left running after a failed test cannot block interpreter exit.
    serve_thread.daemon = True
    serve_thread.start()
    return httpd, serve_thread
116
117
@pytest.fixture(scope="module")
def endpoint_test_server():
    """Run the validating fake intake on ``_PORT`` for the whole module.

    Yields the serving thread. On teardown the server is stopped, its thread
    joined, and its listening socket closed.
    """
    server, thread = _make_server(_PORT, _APIEndpointRequestHandlerTest)
    try:
        yield thread
    finally:
        server.shutdown()
        thread.join()
        # shutdown() only stops serve_forever(); the listening socket stays open
        # until server_close() releases it.
        server.server_close()
126
127
@pytest.fixture(scope="module")
def endpoint_test_timeout_server():
    """Run the slow (timeout-inducing) fake intake on ``_TIMEOUT_PORT`` for the module.

    Yields the serving thread. On teardown the server is stopped, its thread
    joined, and its listening socket closed.
    """
    server, thread = _make_server(_TIMEOUT_PORT, _TimeoutAPIEndpointRequestHandlerTest)
    try:
        yield thread
    finally:
        server.shutdown()
        thread.join()
        # shutdown() only stops serve_forever(); the listening socket stays open
        # until server_close() releases it.
        server.server_close()
136
137
@pytest.fixture(scope="module")
def endpoint_test_reset_server():
    """Run the no-response fake intake on ``_RESET_PORT`` for the module.

    Yields the serving thread. On teardown the server is stopped, its thread
    joined, and its listening socket closed.
    """
    server, thread = _make_server(_RESET_PORT, _ResetAPIEndpointRequestHandlerTest)
    try:
        yield thread
    finally:
        server.shutdown()
        thread.join()
        # shutdown() only stops serve_forever(); the listening socket stays open
        # until server_close() releases it.
        server.server_close()
146
147
@pytest.fixture(scope="module")
def endpoint_test_unknown_server():
    """Run the always-404 fake intake on ``_UNKNOWN_PORT`` for the module.

    Yields the serving thread. On teardown the server is stopped, its thread
    joined, and its listening socket closed.
    """
    server, thread = _make_server(_UNKNOWN_PORT, _UnknownAPIEndpointRequestHandlerTest)
    try:
        yield thread
    finally:
        server.shutdown()
        thread.join()
        # shutdown() only stops serve_forever(); the listening socket stays open
        # until server_close() releases it.
        server.server_close()
156
157
def test_wrong_api_key(endpoint_test_server):
    """A wrong API key is rejected with a 400 and an explanatory ExportError."""
    # This mostly exercises the fake server itself rather than the exporter.
    bad_key_exporter = http.PprofHTTPExporter(_ENDPOINT, "this is not the right API key", max_retry_delay=2)
    with pytest.raises(exporter.ExportError) as excinfo:
        bad_key_exporter.export(test_pprof.TEST_EVENTS, 0, 1)
    assert "Server returned 400, check your API key" == str(excinfo.value)
164
165
def test_export(endpoint_test_server):
    """A well-formed upload with the correct API key is accepted by the fake intake."""
    pprof_exporter = http.PprofHTTPExporter(_ENDPOINT, _API_KEY)
    end_time_ns = compat.time_ns()
    pprof_exporter.export(test_pprof.TEST_EVENTS, 0, end_time_ns)
169
170
def test_export_server_down():
    """Uploading to a port with no listener fails with an OS-level socket error."""
    # Assumes nothing listens on localhost port 2 on the test machine.
    unreachable = http.PprofHTTPExporter("http://localhost:2", _API_KEY, max_retry_delay=2)
    with pytest.raises(http.UploadFailed) as excinfo:
        unreachable.export(test_pprof.TEST_EVENTS, 0, 1)
    last_error = excinfo.value.last_attempt.exception()
    assert isinstance(last_error, (IOError, OSError))
    assert str(excinfo.value).startswith("[Errno ")
178
179
def test_export_timeout(endpoint_test_timeout_server):
    """A server slower than the client timeout makes the upload fail with socket.timeout."""
    impatient = http.PprofHTTPExporter(_TIMEOUT_ENDPOINT, _API_KEY, timeout=1, max_retry_delay=2)
    with pytest.raises(http.UploadFailed) as excinfo:
        impatient.export(test_pprof.TEST_EVENTS, 0, 1)
    last_error = excinfo.value.last_attempt.exception()
    assert isinstance(last_error, socket.timeout)
    assert "timed out" == str(excinfo.value)
187
188
def test_export_reset(endpoint_test_reset_server):
    """A server that closes the connection without replying makes the upload fail."""
    pprof_exporter = http.PprofHTTPExporter(_RESET_ENDPOINT, _API_KEY, timeout=1, max_retry_delay=2)
    with pytest.raises(http.UploadFailed) as excinfo:
        pprof_exporter.export(test_pprof.TEST_EVENTS, 0, 1)
    last_error = excinfo.value.last_attempt.exception()
    if not six.PY3:
        # Python 2's httplib reports the dropped connection as a missing status line.
        assert isinstance(last_error, http_client.BadStatusLine)
        assert str(last_error) == "No status line received - the server has closed the connection"
    else:
        assert isinstance(last_error, ConnectionResetError)
199
200
def test_export_404_agent(endpoint_test_unknown_server):
    """Without an API key (agent mode), a 404 is reported as an outdated Agent."""
    agent_exporter = http.PprofHTTPExporter(_UNKNOWN_ENDPOINT)
    expected = (
        "Datadog Agent is not accepting profiles. "
        "Agent-based profiling deployments require Datadog Agent >= 7.20"
    )
    with pytest.raises(exporter.ExportError) as excinfo:
        agent_exporter.export(test_pprof.TEST_EVENTS, 0, 1)
    assert expected == str(excinfo.value)
208
209
def test_export_404_agentless(endpoint_test_unknown_server):
    """With an API key (agentless mode), a 404 is reported as a plain HTTP error."""
    agentless_exporter = http.PprofHTTPExporter(_UNKNOWN_ENDPOINT, api_key="123", timeout=1)
    with pytest.raises(exporter.ExportError) as excinfo:
        agentless_exporter.export(test_pprof.TEST_EVENTS, 0, 1)
    assert "HTTP Error 404" == str(excinfo.value)
215
216
def test_export_tracer_base_path(endpoint_test_server):
    """A relative endpoint path is appended to the base URL's path."""
    # "v1/input" has no leading slash, so it is joined onto "/profiling/".
    base_url = _ENDPOINT + "/profiling/"
    pprof_exporter = http.PprofHTTPExporter(base_url, _API_KEY, endpoint_path="v1/input")
    pprof_exporter.export(test_pprof.TEST_EVENTS, 0, compat.time_ns())
222
223
def test_export_tracer_base_path_agent_less(endpoint_test_server):
    """An absolute endpoint path replaces the base URL's path."""
    # The leading slash makes the exporter ignore the "/profiling/" base path.
    base_url = _ENDPOINT + "/profiling/"
    pprof_exporter = http.PprofHTTPExporter(base_url, _API_KEY, endpoint_path="/profiling/v1/input")
    pprof_exporter.export(test_pprof.TEST_EVENTS, 0, compat.time_ns())
229
230
def _check_tags_types(tags):
    """Assert that every tag key is a ``str`` and every tag value is ``bytes``."""
    for name in tags:
        assert isinstance(name, str)
        assert isinstance(tags[name], bytes)
235
236
def test_get_tags():
    """Default tags plus an explicit env, with no version tag."""
    tags = http.PprofHTTPExporter(env="foobar", endpoint="")._get_tags("foobar")
    _check_tags_types(tags)
    assert len(tags) == 8
    expected = {
        "service": b"foobar",
        "language": b"python",
        "env": b"foobar",
        "runtime": b"CPython",
        "profiler_version": ddtrace.__version__.encode("utf-8"),
    }
    for name, value in expected.items():
        assert tags[name] == value
    # Host and runtime id vary by machine/process; only require them to be non-empty.
    assert len(tags["host"])
    assert len(tags["runtime-id"])
    assert "version" not in tags
249
250
def test_get_malformed(monkeypatch):
    """Malformed DD_TAGS entries are ignored while valid ones are kept."""

    def _tags_for(dd_tags):
        # Set DD_TAGS, build the tags for service "foobar" and check their types.
        monkeypatch.setenv("DD_TAGS", dd_tags)
        result = http.PprofHTTPExporter(endpoint="")._get_tags("foobar")
        _check_tags_types(result)
        return result

    def _assert_defaults(result, expected_len):
        # Check the tag count and the default tags that are always present.
        assert len(result) == expected_len
        assert result["service"] == b"foobar"
        assert len(result["host"])
        assert len(result["runtime-id"])
        assert result["language"] == b"python"
        assert result["runtime"] == b"CPython"
        assert result["profiler_version"] == ddtrace.__version__.encode("utf-8")

    # An entry without a colon, and empty entries, add no tag.
    for malformed in ("mytagfoobar", "mytagfoobar,", ","):
        _assert_defaults(_tags_for(malformed), 7)

    # A trailing comma after a valid entry still keeps that entry.
    tags = _tags_for("foo:bar,")
    _assert_defaults(tags, 8)
    assert tags["foo"] == b"bar"
296
297
def test_get_tags_override(monkeypatch):
    """DD_TAGS and constructor arguments add to or override the default tags."""

    def _tags_from_env(dd_tags):
        # Set DD_TAGS and build the tags for service "foobar" with default arguments.
        monkeypatch.setenv("DD_TAGS", dd_tags)
        return http.PprofHTTPExporter(endpoint="")._get_tags("foobar")

    def _check_common(result, expected_len, service):
        # Check types, tag count, service, and the default tags that are always present.
        _check_tags_types(result)
        assert len(result) == expected_len
        assert result["service"] == service
        assert len(result["host"])
        assert len(result["runtime-id"])
        assert result["language"] == b"python"
        assert result["runtime"] == b"CPython"
        assert result["profiler_version"] == ddtrace.__version__.encode("utf-8")

    # A single user tag is added on top of the defaults.
    tags = _tags_from_env("mytag:foobar")
    _check_common(tags, 8, b"foobar")
    assert tags["mytag"] == b"foobar"
    assert "version" not in tags

    # Several comma-separated user tags are all added.
    tags = _tags_from_env("mytag:foobar,author:jd")
    _check_common(tags, 9, b"foobar")
    assert tags["mytag"] == b"foobar"
    assert tags["author"] == b"jd"
    assert "version" not in tags

    # An empty DD_TAGS leaves only the defaults.
    tags = _tags_from_env("")
    _check_common(tags, 7, b"foobar")
    assert "version" not in tags

    # A "service" entry in DD_TAGS overrides the service passed to _get_tags().
    tags = _tags_from_env("foobar:baz,service:mycustomservice")
    _check_common(tags, 8, b"mycustomservice")
    assert tags["foobar"] == b"baz"
    assert "version" not in tags

    # Non-ASCII values come back UTF-8 encoded.
    unicode_service = u"��".encode("utf-8")
    tags = _tags_from_env("foobar:baz,service:��")
    _check_common(tags, 8, unicode_service)
    assert tags["foobar"] == b"baz"
    assert "version" not in tags

    # DD_TAGS keeps its last value below; constructor arguments add version, env and extra tags.
    tags = http.PprofHTTPExporter(endpoint="", version="123")._get_tags("foobar")
    _check_common(tags, 9, unicode_service)
    assert tags["foobar"] == b"baz"
    assert tags["version"] == b"123"
    assert "env" not in tags

    tags = http.PprofHTTPExporter(endpoint="", version="123", env="prod")._get_tags("foobar")
    _check_common(tags, 10, unicode_service)
    assert tags["foobar"] == b"baz"
    assert tags["version"] == b"123"
    assert tags["env"] == b"prod"

    tags = http.PprofHTTPExporter(endpoint="", version="123", env="prod", tags={"mytag": "123"})._get_tags("foobar")
    _check_common(tags, 11, unicode_service)
    assert tags["foobar"] == b"baz"
    assert tags["version"] == b"123"
    assert tags["env"] == b"prod"
    assert tags["mytag"] == b"123"
403
404
def test_get_tags_legacy(monkeypatch):
    """DD_PROFILING_TAGS is honoured and wins over DD_TAGS for duplicate keys."""

    def _fresh_tags():
        # Build the tags for service "foobar" from the current environment.
        result = http.PprofHTTPExporter(endpoint="")._get_tags("foobar")
        _check_tags_types(result)
        return result

    monkeypatch.setenv("DD_PROFILING_TAGS", "mytag:baz")
    assert _fresh_tags()["mytag"] == b"baz"

    # With both variables set, keys from both are merged and DD_PROFILING_TAGS takes precedence.
    monkeypatch.setenv("DD_TAGS", "mytag:val1,ddtag:hi")
    monkeypatch.setenv("DD_PROFILING_TAGS", "mytag:val2,ddptag:lo")
    tags = _fresh_tags()
    assert tags["mytag"] == b"val2"
    assert tags["ddtag"] == b"hi"
    assert tags["ddptag"] == b"lo"
419