1import asyncio
2
3from opentracing.scope_managers.asyncio import AsyncioScopeManager
4import pytest
5
6import ddtrace
7from ddtrace.contrib.asyncio import context_provider
8from ddtrace.internal.compat import CONTEXTVARS_IS_AVAILABLE
9from ddtrace.opentracer.utils import get_context_provider_for_scope_manager
10
11
@pytest.mark.asyncio
async def test_trace_coroutine(test_spans):
    """A span started inside a coroutine is reported as a single-span trace."""
    # This test must be ``async def``: a plain ``def`` carrying the asyncio
    # marker is not run as a coroutine, so it would never exercise the task
    # context this test is meant to cover (recent pytest-asyncio versions
    # also warn about a sync function with this marker).
    with test_spans.tracer.start_span("coroutine"):
        pass

    traces = test_spans.pop_traces()

    assert len(traces) == 1
    assert len(traces[0]) == 1
    assert traces[0][0].name == "coroutine"
23
24
@pytest.mark.asyncio
async def test_trace_multiple_coroutines(ot_tracer, test_spans):
    """Nested OpenTracing spans across awaited coroutines share one trace."""

    async def traced_child():
        # a second traced coroutine, awaited while the outer span is active
        with ot_tracer.start_active_span("coroutine_2"):
            return 42

    with ot_tracer.start_active_span("coroutine_1"):
        answer = await traced_child()

    # the awaited coroutine handed back its result
    assert answer == 42

    # exactly one trace containing both spans was reported
    traces = test_spans.pop_traces()
    assert len(traces) == 1
    spans = traces[0]
    assert len(spans) == 2
    root, child = spans
    assert root.name == "coroutine_1"
    assert child.name == "coroutine_2"

    # the inner span is parented to the outer one
    assert root == child._parent
    assert root.trace_id == child.trace_id
49
50
@pytest.mark.asyncio
async def test_exception(ot_tracer, test_spans):
    """An exception raised inside an OT span marks the finished span as errored."""

    async def f1():
        with ot_tracer.start_span("f1"):
            raise Exception("f1 error")

    # Match the message so that an unrelated exception (e.g. a bug in the
    # tracer itself) cannot satisfy this check.
    with pytest.raises(Exception, match="f1 error"):
        await f1()

    traces = test_spans.pop_traces()
    assert len(traces) == 1
    spans = traces[0]
    assert len(spans) == 1
    span = spans[0]
    # the span records the error flag, message and stack of the exception
    assert span.error == 1
    assert span.get_tag("error.msg") == "f1 error"
    assert "Exception: f1 error" in span.get_tag("error.stack")
68
69
@pytest.mark.asyncio
async def test_trace_multiple_calls(ot_tracer, test_spans):
    """Independently scheduled tasks each produce their own single-span trace."""
    ot_tracer._dd_tracer.configure(context_provider=context_provider)

    num_tasks = 10

    async def coro():
        # each task opens a span and yields to the loop while it is open
        with ot_tracer.start_span("coroutine"):
            await asyncio.sleep(0.01)

    # Schedule every coroutine as its own task before awaiting any of them,
    # so we expect one trace per task rather than a single shared trace.
    futures = [asyncio.ensure_future(coro()) for _ in range(num_tasks)]
    for future in futures:
        await future

    traces = test_spans.pop_traces()

    assert len(traces) == num_tasks
    # verify every trace, not only the first one
    for trace in traces:
        assert len(trace) == 1
        assert trace[0].name == "coroutine"
90
91
@pytest.mark.asyncio
async def test_trace_multiple_coroutines_ot_dd(ot_tracer):
    """
    Ensure we can trace from opentracer to ddtracer across asyncio
    context switches.
    """

    # A ddtrace span opened in an awaited coroutine is expected to nest under
    # the OpenTracing span that is active in the caller.
    async def traced_child():
        with ot_tracer._dd_tracer.trace("coroutine_2"):
            return 42

    with ot_tracer.start_active_span("coroutine_1"):
        answer = await traced_child()

    # the awaited coroutine handed back its result
    assert answer == 42

    # exactly one trace containing both spans was reported
    traces = ot_tracer._dd_tracer.pop_traces()
    assert len(traces) == 1
    spans = traces[0]
    assert len(spans) == 2
    root, child = spans
    assert root.name == "coroutine_1"
    assert child.name == "coroutine_2"

    # parent/child linkage crosses the OT -> dd boundary
    assert root == child._parent
    assert root.trace_id == child.trace_id
119
120
@pytest.mark.asyncio
async def test_trace_multiple_coroutines_dd_ot(ot_tracer):
    """
    Ensure we can trace from ddtracer to opentracer across asyncio
    context switches.
    """

    # An OpenTracing span opened in an awaited coroutine is expected to nest
    # under the ddtrace span that is active in the caller.
    async def traced_child():
        with ot_tracer.start_span("coroutine_2"):
            return 42

    with ot_tracer._dd_tracer.trace("coroutine_1"):
        answer = await traced_child()

    # the awaited coroutine handed back its result
    assert answer == 42

    # exactly one trace containing both spans was reported
    traces = ot_tracer._dd_tracer.pop_traces()
    assert len(traces) == 1
    spans = traces[0]
    assert len(spans) == 2
    root, child = spans
    assert root.name == "coroutine_1"
    assert child.name == "coroutine_2"

    # parent/child linkage crosses the dd -> OT boundary
    assert root == child._parent
    assert root.trace_id == child.trace_id
148
149
@pytest.mark.skipif(CONTEXTVARS_IS_AVAILABLE, reason="only applicable to legacy asyncio provider")
def test_get_context_provider_for_scope_manager_asyncio():
    """An AsyncioScopeManager maps to the legacy AsyncioContextProvider."""
    provider = get_context_provider_for_scope_manager(AsyncioScopeManager())
    expected_cls = ddtrace.contrib.asyncio.provider.AsyncioContextProvider
    assert isinstance(provider, expected_cls)
155
156
@pytest.mark.skipif(CONTEXTVARS_IS_AVAILABLE, reason="only applicable to legacy asyncio provider")
def test_tracer_context_provider_config():
    """An OT tracer built with an AsyncioScopeManager gives its wrapped dd
    tracer the legacy AsyncioContextProvider."""
    ot = ddtrace.opentracer.Tracer("mysvc", scope_manager=AsyncioScopeManager())
    provider = ot._dd_tracer.context_provider
    assert isinstance(provider, ddtrace.contrib.asyncio.provider.AsyncioContextProvider)
164