1import contextlib
2from contextlib import contextmanager
3import inspect
4import os
5import subprocess
6import sys
7from typing import List
8
9import attr
10import pytest
11
12import ddtrace
13from ddtrace import Span
14from ddtrace import Tracer
15from ddtrace.constants import SPAN_MEASURED_KEY
16from ddtrace.ext import http
17from ddtrace.internal.compat import httplib
18from ddtrace.internal.compat import parse
19from ddtrace.internal.compat import to_unicode
20from ddtrace.internal.encoding import JSONEncoder
21from ddtrace.internal.encoding import MsgpackEncoderV03 as Encoder
22from ddtrace.internal.utils.formats import parse_tags_str
23from ddtrace.internal.writer import AgentWriter
24from ddtrace.vendor import wrapt
25from tests.subprocesstest import SubprocessTestCase
26
27
# Sentinel default for the ``children`` argument of ``assert_structure``.
# It is distinct from ``None``: passing ``None`` skips all assertions about
# children, while ``NO_CHILDREN`` is treated as an empty tuple, i.e. "assert
# that this span has no children".
NO_CHILDREN = object()
29
30
def assert_is_measured(span):
    """Check that ``span`` is flagged as measured.

    The ``_dd.measured`` key must be stored as a metric (never as a meta tag)
    and its metric value must be exactly 1.
    """
    metrics = span.metrics
    assert SPAN_MEASURED_KEY in metrics
    assert SPAN_MEASURED_KEY not in span.meta
    assert span.get_metric(SPAN_MEASURED_KEY) == 1
36
37
def assert_is_not_measured(span):
    """Check that ``span`` is not flagged as measured.

    ``_dd.measured`` must never appear as a meta tag. It may appear as a
    metric, but then only with a value of 0.
    """
    assert SPAN_MEASURED_KEY not in span.meta
    if SPAN_MEASURED_KEY not in span.metrics:
        # Absent entirely: nothing else to check.
        return
    assert span.get_metric(SPAN_MEASURED_KEY) == 0
45
46
def assert_span_http_status_code(span, code):
    """Check the span's ``http.status_code`` tag against ``code``.

    ``code`` is converted with ``str()`` first, so ints and strings are both accepted.
    """
    actual = span.get_tag(http.STATUS_CODE)
    expected = str(code)
    assert actual == expected, "%r != %r" % (actual, expected)
52
53
@contextlib.contextmanager
def override_env(env):
    """
    Temporarily override ``os.environ`` with provided values::

        >>> with self.override_env(dict(DD_TRACE_DEBUG="true")):
            # Your test

    On exit, whether normal or via an exception, the environment is reset to an
    exact copy of what it was on entry. Keys added inside the block are removed
    and modified keys get their previous values back.

    :param env: Mapping of environment variable names to string values
    :type env: dict
    """
    # Snapshot the full original environment so it can be restored exactly.
    original = dict(os.environ)

    try:
        # Apply the overrides inside ``try`` so that a partial update is still
        # rolled back. ``os.environ`` rejects non-str values with ``TypeError``,
        # possibly after earlier keys were already set.
        os.environ.update(env)
        yield
    finally:
        # Full clear the environment out and reset back to the original
        os.environ.clear()
        os.environ.update(original)
73
74
@contextlib.contextmanager
def override_global_config(values):
    """
    Temporarily override global configuration values::

        >>> with self.override_global_config(dict(name=value,...)):
            # Your test

    Only keys listed in ``global_config_keys`` below are applied. Any other key in
    ``values`` is silently ignored. On exit, every allowed key is reset to the
    value it had on entry, even if the block (or one of the overrides) raises.

    :param values: Mapping of ``ddtrace.config`` attribute names to temporary values
    :type values: dict
    """
    # List of global variables we allow overriding
    # DEV: We do not do `ddtrace.config.keys()` because we have all of our integrations
    global_config_keys = [
        "analytics_enabled",
        "report_hostname",
        "health_metrics_enabled",
        "env",
        "version",
        "service",
        "_raise",
    ]

    # Grab the current values of all keys
    originals = dict((key, getattr(ddtrace.config, key)) for key in global_config_keys)

    try:
        # Override from the passed in keys. This is done inside ``try`` so that a
        # failing ``setattr`` midway still restores the keys already overridden.
        for key, value in values.items():
            if key in global_config_keys:
                setattr(ddtrace.config, key, value)
        yield
    finally:
        # Reset all to their original values
        for key, value in originals.items():
            setattr(ddtrace.config, key, value)
108
109
@contextlib.contextmanager
def override_config(integration, values):
    """
    Temporarily override an integration configuration value::

        >>> with self.override_config('flask', dict(service_name='test-service')):
            # Your test

    The previous value of each key in ``values`` is captured with
    ``options.get(key)`` and written back with ``options.update`` on exit. If
    ``get`` returns ``None`` for missing keys (as ``dict.get`` does), a key that
    was previously absent is restored as ``None`` rather than removed.

    :param integration: Name of the integration attribute on ``ddtrace.config``
    :type integration: str
    :param values: Mapping of option names to temporary values
    :type values: dict
    """
    options = getattr(ddtrace.config, integration)

    saved = {}
    for key in values.keys():
        saved[key] = options.get(key)

    options.update(values)
    try:
        yield
    finally:
        options.update(saved)
127
128
@contextlib.contextmanager
def override_http_config(integration, values):
    """
    Temporarily override an integration configuration for HTTP value::

        >>> with self.override_http_config('flask', dict(trace_query_string=True)):
            # Your test

    Every attribute that was successfully overridden is restored on exit. This
    holds even if a later key in ``values`` does not exist on the HTTP config
    (``getattr`` raising ``AttributeError``) or if the block itself raises.

    :param integration: Name of the integration attribute on ``ddtrace.config``
    :type integration: str
    :param values: Mapping of HTTP config attribute names to temporary values
    :type values: dict
    """
    options = getattr(ddtrace.config, integration).http

    original = {}
    try:
        for key, value in values.items():
            # Record the old value before overriding, so ``finally`` restores
            # exactly the attributes that were touched.
            original[key] = getattr(options, key)
            setattr(options, key, value)
        yield
    finally:
        for key, value in original.items():
            setattr(options, key, value)
149
150
@contextlib.contextmanager
def override_sys_modules(modules):
    """
    Temporarily inject entries into ``sys.modules``::

        >>> mock_module = mock.MagicMock()
        >>> mock_module.fn.side_effect = lambda: 'test'
        >>> with self.override_sys_modules(dict(A=mock_module)):
            # Your test

    On exit ``sys.modules`` is reset to an exact copy of its contents on entry,
    so modules first imported inside the block are dropped as well.

    :param modules: Mapping of module names to module objects
    :type modules: dict
    """
    snapshot = sys.modules.copy()
    sys.modules.update(modules)
    try:
        yield
    finally:
        # Mutate the existing dict in place instead of rebinding ``sys.modules``;
        # replacing the dictionary object is not guaranteed to work.
        sys.modules.clear()
        sys.modules.update(snapshot)
169
170
class BaseTestCase(SubprocessTestCase):
    """
    Extends ``SubprocessTestCase`` with configuration-override helpers and span assertions.

    The module-level helpers are attached as static methods so tests can call
    them through ``self``.

    Example::

        from tests.utils import BaseTestCase


        class MyTestCase(BaseTestCase):
            def test_case(self):
                with self.override_config('flask', dict(distributed_tracing_enabled=True)):
                    pass
    """

    # Context managers that temporarily override global state.
    override_env = staticmethod(override_env)
    override_sys_modules = staticmethod(override_sys_modules)
    override_global_config = staticmethod(override_global_config)
    override_config = staticmethod(override_config)
    override_http_config = staticmethod(override_http_config)

    # Assertions about the ``_dd.measured`` marker.
    assert_is_measured = staticmethod(assert_is_measured)
    assert_is_not_measured = staticmethod(assert_is_not_measured)
194
195
def _build_tree(
    spans,  # type: List[Span]
    root,  # type: Span
):
    # type: (...) -> TestSpanNode
    """Recursively build a ``TestSpanNode`` tree rooted at ``root``.

    A span is a child of ``root`` when its ``parent_id`` equals ``root.span_id``.
    Children keep the order in which they appear in ``spans``.
    """
    child_nodes = [
        _build_tree(spans, candidate)
        for candidate in spans
        if candidate.parent_id == root.span_id
    ]
    return TestSpanNode(root, child_nodes)
208
209
def get_root_span(
    spans,  # type: List[Span]
):
    # type: (...) -> TestSpanNode
    """
    Build the span tree for the single root span in ``spans``.

    A root span is one whose ``parent_id`` is ``None``.

    :param spans: The spans to search
    :type spans: list
    :returns: The root span, wrapped as a tree together with its descendants
    :rtype: :class:`tests.utils.span.TestSpanNode`
    :raises: AssertionError if no root span, or more than one root span, is found
    """
    root = None
    for span in spans:
        if span.parent_id is not None:
            continue
        if root is not None:
            raise AssertionError("Multiple root spans found {0!r} {1!r}".format(root, span))
        root = span

    # Compare against ``None`` explicitly instead of relying on the span's truthiness.
    assert root is not None, "No root span found in {0!r}".format(spans)

    return _build_tree(spans, root)
231
232
class TestSpanContainer(object):
    """
    Helper class for a container of Spans.

    Subclasses of this class must implement a `get_spans` method::

        def get_spans(self):
            return []

    This class provides methods and assertions over a list of spans::

        class TestCases(TracerTestCase):
            def test_spans(self):
                # TODO: Create spans

                self.assert_has_spans()
                self.assert_span_count(3)
                self.assert_structure( ... )

                # Grab only the `requests.request` spans
                spans = self.filter_spans(name='requests.request')
    """

    def _ensure_test_spans(self, spans):
        """
        internal helper to ensure the list of spans are all :class:`tests.utils.span.TestSpan`

        :param spans: List of :class:`ddtrace.span.Span` or :class:`tests.utils.span.TestSpan`
        :type spans: list
        :returns: A list of :class:`tests.utils.span.TestSpan`
        :rtype: list
        """
        return [span if isinstance(span, TestSpan) else TestSpan(span) for span in spans]

    @property
    def spans(self):
        """The spans from :meth:`get_spans`, each wrapped as a TestSpan (a new list on every access)."""
        return self._ensure_test_spans(self.get_spans())

    def get_spans(self):
        """Return the raw spans held by this container; subclasses must implement this."""
        raise NotImplementedError

    def get_root_span(self):
        # type: (...) -> TestSpanNode
        """
        Helper to get the root span from the list of spans in this container

        :returns: The single root span, as a tree including its descendants
        :rtype: :class:`tests.utils.span.TestSpanNode`
        :raises: AssertionError if there is no root span or more than one
        """
        return get_root_span(self.spans)

    def get_root_spans(self):
        # type: (...) -> List[TestSpanNode]
        """
        Helper to get all root spans from the list of spans in this container

        A root span is one whose ``parent_id`` is ``None``.

        :returns: One tree per root span, ordered by span ``start``; empty if there are none
        :rtype: list of :class:`tests.utils.span.TestSpanNode`
        """
        # ``self.spans`` re-wraps every span on each access, so read it once
        # rather than once per root.
        spans = self.spans
        roots = [_build_tree(spans, span) for span in spans if span.parent_id is None]
        return sorted(roots, key=lambda s: s.start)

    def assert_trace_count(self, count):
        """Assert the number of traces, counted as the number of root spans in this container"""
        trace_count = len(self.get_root_spans())
        assert trace_count == count, "Trace count {0} != {1}".format(trace_count, count)

    def assert_span_count(self, count):
        """Assert this container has the expected number of spans"""
        span_count = len(self.spans)
        assert span_count == count, "Span count {0} != {1}".format(span_count, count)

    def assert_has_spans(self):
        """Assert this container has spans"""
        assert len(self.spans), "No spans found"

    def assert_has_no_spans(self):
        """Assert this container does not have any spans"""
        span_count = len(self.spans)
        assert span_count == 0, "Span count {0}".format(span_count)

    def filter_spans(self, *args, **kwargs):
        """
        Helper to filter current spans by provided parameters.

        This function will yield all spans whose `TestSpan.matches` function return `True`.

        :param args: Positional arguments forwarded to :meth:`tests.utils.span.TestSpan.matches`
            (which only accepts keyword arguments, so this should normally be empty)
        :type args: list
        :param kwargs: Keyword arguments to pass to :meth:`tests.utils.span.TestSpan.matches`
        :type kwargs: dict
        :returns: generator for the matched :class:`tests.utils.span.TestSpan`
        :rtype: generator
        """
        for span in self.spans:
            # ensure we have a TestSpan even if a subclass overrides ``spans``
            if not isinstance(span, TestSpan):
                span = TestSpan(span)

            if span.matches(*args, **kwargs):
                yield span

    def find_span(self, *args, **kwargs):
        """
        Find a single span matches the provided filter parameters.

        This function will find the first span whose `TestSpan.matches` function return `True`.

        :param args: Positional arguments to pass to :meth:`tests.utils.span.TestSpan.matches`
        :type args: list
        :param kwargs: Keyword arguments to pass to :meth:`tests.utils.span.TestSpan.matches`
        :type kwargs: dict
        :returns: The first matching span
        :rtype: :class:`tests.TestSpan`
        :raises: AssertionError if no span matches
        """
        span = next(self.filter_spans(*args, **kwargs), None)
        assert span is not None, "No span found for filter {0!r} {1!r}, have {2} spans".format(
            args, kwargs, len(self.spans)
        )
        return span
357
358
class TracerTestCase(TestSpanContainer, BaseTestCase):
    """
    Test case that provides a :class:`DummyTracer` as ``self.tracer`` plus span assertions.

    Spans written by ``self.tracer`` are exposed through the :class:`TestSpanContainer`
    helpers (``self.spans``, ``assert_span_count``, ``assert_structure``, ...).
    """

    def setUp(self):
        """Install a fresh DummyTracer, then run the parent ``setUp``."""
        self.tracer = DummyTracer()
        super(TracerTestCase, self).setUp()

    def tearDown(self):
        """Run the parent ``tearDown``, then drop buffered spans and the tracer itself."""
        super(TracerTestCase, self).tearDown()
        self.reset()
        del self.tracer

    def get_spans(self):
        """Spans buffered by the tracer's writer (the TestSpanContainer hook)."""
        return self.tracer.writer.spans

    def pop_spans(self):
        # type: () -> List[Span]
        """Return all buffered spans and clear them."""
        return self.tracer.pop()

    def pop_traces(self):
        # type: () -> List[List[Span]]
        """Return all buffered traces and clear them."""
        return self.tracer.pop_traces()

    def reset(self):
        """Discard the spans buffered by the tracer's writer."""
        self.tracer.writer.pop()

    def trace(self, *args, **kwargs):
        """Start a span with ``self.tracer.trace`` and wrap it in a TestSpan."""
        span = self.tracer.trace(*args, **kwargs)
        return TestSpan(span)

    def start_span(self, *args, **kwargs):
        """Start a span with ``self.tracer.start_span`` and wrap it in a TestSpan."""
        span = self.tracer.start_span(*args, **kwargs)
        return TestSpan(span)

    def assert_structure(self, root, children=NO_CHILDREN):
        """Run ``TestSpanNode.assert_structure`` on this container's single root span."""
        self.get_root_span().assert_structure(root, children)

    @contextlib.contextmanager
    def override_global_tracer(self, tracer=None):
        """Temporarily install ``tracer`` (or ``self.tracer`` when falsy) as ``ddtrace.tracer``."""
        saved = ddtrace.tracer
        ddtrace.tracer = tracer or self.tracer
        try:
            yield
        finally:
            ddtrace.tracer = saved
415
416
class DummyWriter(AgentWriter):
    """DummyWriter is a small fake writer used for tests. not thread-safe.

    Each written batch of spans is run through both the JSON and the msgpack
    encoder (so encoding errors surface in tests), then kept in memory instead
    of being sent anywhere.
    """

    def __init__(self, *args, **kwargs):
        # Default the agent URL only when the caller gave neither positional
        # arguments nor an explicit ``agent_url``.
        if not args and "agent_url" not in kwargs:
            kwargs["agent_url"] = "http://localhost:8126"

        super(DummyWriter, self).__init__(*args, **kwargs)
        self.spans = []
        self.traces = []
        self.json_encoder = JSONEncoder()
        self.msgpack_encoder = Encoder(4 << 20, 4 << 20)

    def write(self, spans=None):
        """Encode ``spans`` as a single trace with both encoders, then buffer it.

        ``None`` or an empty batch is ignored.
        """
        if not spans:
            return

        # The encoders take a list of traces, so wrap the batch as one trace,
        # mirroring the real write path.
        trace_batch = [spans]
        self.json_encoder.encode_traces(trace_batch)
        self.msgpack_encoder.put(spans)
        self.msgpack_encoder.encode()
        self.spans += spans
        self.traces += trace_batch

    def pop(self):
        # type: () -> List[Span]
        """Return every buffered span and start a new, empty span buffer."""
        popped, self.spans = self.spans, []
        return popped

    def pop_traces(self):
        # type: () -> List[List[Span]]
        """Return every buffered trace and start a new, empty trace buffer."""
        popped, self.traces = self.traces, []
        return popped
454
455
class DummyTracer(Tracer):
    """
    DummyTracer is a tracer which uses the DummyWriter by default
    """

    def __init__(self):
        super(DummyTracer, self).__init__()
        self.configure()

    def pop(self):
        # type: () -> List[Span]
        """Return and clear the spans buffered by the writer."""
        return self.writer.pop()

    def pop_traces(self):
        # type: () -> List[List[Span]]
        """Return and clear the traces buffered by the writer."""
        return self.writer.pop_traces()

    def configure(self, *args, **kwargs):
        """
        Configure the tracer while guaranteeing that its writer is a :class:`DummyWriter`.

        A ``writer`` keyword, if given, must be a :class:`DummyWriter` and is used
        as-is. Without one, a fresh :class:`DummyWriter` is passed on every call,
        so spans buffered in the previous writer are no longer returned by :meth:`pop`
        (assuming ``Tracer.configure`` installs the writer it is given).

        :raises: AssertionError if ``writer`` is not a :class:`DummyWriter`
        """
        assert "writer" not in kwargs or isinstance(
            kwargs["writer"], DummyWriter
        ), "cannot configure writer of DummyTracer"
        # Honour an explicitly supplied DummyWriter instead of silently replacing it.
        if "writer" not in kwargs:
            kwargs["writer"] = DummyWriter()
        super(DummyTracer, self).configure(*args, **kwargs)
479
480
class TestSpan(Span):
    """
    Test wrapper for a :class:`ddtrace.span.Span` that provides additional functions and assertions

    Attribute reads fall back to the wrapped span, and every attribute assignment
    is forwarded to the wrapped span.

    Example::

        span = tracer.trace('my.span')
        span = TestSpan(span)

        if span.matches(name='my.span'):
            print('matches')

        # Raises an AssertionError
        span.assert_matches(name='not.my.span', meta={'system.pid': getpid()})
    """

    def __init__(self, span):
        """
        Constructor for TestSpan

        Wrapping a ``TestSpan`` wraps its underlying span instead, so wrappers never nest.
        ``Span.__init__`` is not called; this object only delegates to ``span``.

        :param span: The :class:`ddtrace.span.Span` to wrap
        :type span: :class:`ddtrace.span.Span`
        """
        if isinstance(span, TestSpan):
            span = span._span

        # DEV: Use `object.__setattr__` to by-pass this class's `__setattr__`
        object.__setattr__(self, "_span", span)

    def __getattr__(self, key):
        """
        First look for property on the base :class:`ddtrace.span.Span` otherwise return this object's attribute

        Python only calls this after normal attribute lookup has failed.
        """
        # ``_span`` is only missing on a half-constructed instance (e.g. created via
        # ``__new__`` without ``__init__``); resolving it through ``self._span``
        # below would re-enter this method until ``RecursionError``.
        if key == "_span":
            raise AttributeError(key)

        if hasattr(self._span, key):
            return getattr(self._span, key)

        return self.__getattribute__(key)

    def __setattr__(self, key, value):
        """Pass through all assignment to the base :class:`ddtrace.span.Span`"""
        return setattr(self._span, key, value)

    def __eq__(self, other):
        """
        Custom equality code to ensure we are using the base :class:`ddtrace.span.Span.__eq__`

        :param other: The object to check equality with
        :type other: object
        :returns: True if equal, False otherwise; ``NotImplemented`` for non-span
            objects so that Python can defer to ``other``'s own ``__eq__``
        :rtype: bool
        """
        if isinstance(other, TestSpan):
            return other._span == self._span
        elif isinstance(other, Span):
            return other == self._span
        # Returning ``other == self`` here would bounce back into this method via
        # Python's reflected comparison and recurse forever for e.g. ints.
        return NotImplemented

    def __ne__(self, other):
        """Inverse of :meth:`__eq__` (Python 2 does not derive ``!=`` from ``__eq__``)."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def matches(self, **kwargs):
        """
        Helper function to check if this span's properties matches the expected.

        ``meta`` is compared as a subset via :meth:`meta_matches`; every other
        property must be present and equal.

        Example::

            span = TestSpan(span)
            span.matches(name='my.span', resource='GET /')

        :param kwargs: Property/Value pairs to evaluate on this span
        :type kwargs: dict
        :returns: True if the arguments passed match, False otherwise
        :rtype: bool
        """
        for name, value in kwargs.items():
            # Special case for `meta`: a subset match, not an equality check
            if name == "meta":
                if not self.meta_matches(value):
                    return False
                continue

            # Ensure it has the property first
            if not hasattr(self, name):
                return False

            # Ensure the values match
            if getattr(self, name) != value:
                return False

        return True

    def meta_matches(self, meta, exact=False):
        """
        Helper function to check if this span's meta matches the expected

        Example::

            span = TestSpan(span)
            span.meta_matches({'system.pid': getpid()})

        :param meta: Property/Value pairs to evaluate on this span
        :type meta: dict
        :param exact: Whether to do an exact match on the meta values or not, default: False
        :type exact: bool
        :returns: True if the arguments passed match, False otherwise
        :rtype: bool
        """
        if exact:
            return self.meta == meta

        for key, value in meta.items():
            if key not in self.meta:
                return False
            if self.meta[key] != value:
                return False
        return True

    def assert_matches(self, **kwargs):
        """
        Assertion method to ensure this span's properties match as expected

        ``meta`` and ``metrics`` are checked as subsets; other properties must be equal.

        Example::

            span = TestSpan(span)
            span.assert_matches(name='my.span')

        :param kwargs: Property/Value pairs to evaluate on this span
        :type kwargs: dict
        :raises: AssertionError
        """
        for name, value in kwargs.items():
            # Special case for `meta`
            if name == "meta":
                self.assert_meta(value)
            elif name == "metrics":
                self.assert_metrics(value)
            else:
                assert hasattr(self, name), "{0!r} does not have property {1!r}".format(self, name)
                assert getattr(self, name) == value, "{0!r} property {1}: {2!r} != {3!r}".format(
                    self, name, getattr(self, name), value
                )

    def assert_meta(self, meta, exact=False):
        """
        Assertion method to ensure this span's meta match as expected

        Example::

            span = TestSpan(span)
            span.assert_meta({'system.pid': getpid()})

        :param meta: Property/Value pairs to evaluate on this span
        :type meta: dict
        :param exact: Whether to do an exact match on the meta values or not, default: False
        :type exact: bool
        :raises: AssertionError
        """
        if exact:
            assert self.meta == meta
        else:
            for key, value in meta.items():
                assert key in self.meta, "{0} meta does not have property {1!r}".format(self, key)
                assert self.meta[key] == value, "{0} meta property {1!r}: {2!r} != {3!r}".format(
                    self, key, self.meta[key], value
                )

    def assert_metrics(self, metrics, exact=False):
        """
        Assertion method to ensure this span's metrics match as expected

        Example::

            span = TestSpan(span)
            span.assert_metrics({'_dd1.sr.eausr': 1})

        :param metrics: Property/Value pairs to evaluate on this span
        :type metrics: dict
        :param exact: Whether to do an exact match on the metrics values or not, default: False
        :type exact: bool
        :raises: AssertionError
        """
        if exact:
            assert self.metrics == metrics
        else:
            for key, value in metrics.items():
                assert key in self.metrics, "{0} metrics does not have property {1!r}".format(self, key)
                assert self.metrics[key] == value, "{0} metrics property {1!r}: {2!r} != {3!r}".format(
                    self, key, self.metrics[key], value
                )
665
666
class TracerSpanContainer(TestSpanContainer):
    """
    Adapter that lets the :class:`tests.utils.span.TestSpanContainer` helpers
    inspect the spans written by a :class:`tests.utils.tracer.DummyTracer`.
    """

    def __init__(self, tracer):
        self.tracer = tracer
        super(TracerSpanContainer, self).__init__()

    def get_spans(self):
        """
        Spans currently buffered by the wrapped tracer's writer.

        :returns: The writer's span list (not a copy)
        :rtype: list
        """
        return self.tracer.writer.spans

    def pop(self):
        # type: () -> List[Span]
        """Return and clear the tracer's buffered spans."""
        return self.tracer.pop()

    def pop_traces(self):
        # type: () -> List[List[Span]]
        """Return and clear the tracer's buffered traces."""
        return self.tracer.pop_traces()

    def reset(self):
        """Throw away the tracer's buffered spans."""
        self.tracer.pop()
695
696
class TestSpanNode(TestSpan, TestSpanContainer):
    """
    A :class:`tests.utils.span.TestSpan` that is also a container of its child nodes.

    Each node wraps one :class:`ddtrace.span.Span` and holds the nodes of the spans
    whose parent is that span, so parent/child relationships can be asserted.

    Example::

        class TestCase(TracerTestCase):
            def test_case(self):
                # TODO: Create spans

                self.assert_structure( ... )

                tree = self.get_root_span()

                # Find the first child of the root span with the matching name
                request = tree.find_span(name='requests.request')

                # Assert the parent/child relationship of this `request` span
                request.assert_structure( ... )
    """

    def __init__(self, root, children=None):
        super(TestSpanNode, self).__init__(root)
        # Stored directly on the wrapper: ``TestSpan.__setattr__`` would otherwise
        # forward the assignment to the wrapped span.
        object.__setattr__(self, "_children", [] if not children else children)

    def get_spans(self):
        """The child nodes of this span (the TestSpanContainer hook)."""
        return self._children

    def assert_structure(self, root, children=NO_CHILDREN):
        """
        Assert on the properties of this node and, recursively, on its children.

        Example::

            def test_case(self):
                # Assert the following structure
                #
                # One root_span, with two child_spans, one with a requests.request span
                #
                # |                  root_span                |
                # |       child_span       | |   child_span   |
                # | requests.request |
                self.assert_structure(
                    # Root span with two child_span spans
                    dict(name='root_span'),

                    (
                        # Child span with one child of its own
                        (
                            dict(name='child_span'),

                            # One requests.request span with no children
                            (
                                dict(name='requests.request'),
                            ),
                        ),

                        # Child span with no children
                        dict(name='child_span'),
                    ),
                )

        :param root: Properties to assert for this span, passed to
            :meth:`tests.utils.span.TestSpan.assert_matches`
        :type root: dict
        :param children: Child assertions, matched positionally against this node's
            children. ``None`` skips every check on children. The default asserts
            that there are no children. Each element is either a 2-item list/tuple
            ``(properties_dict, grandchildren)`` or a bare ``dict``, which is
            shorthand for ``(dict, NO_CHILDREN)``. Each child must also have this
            node as its parent (``parent_id``, ``trace_id`` and ``_parent``).
        :type children: list, tuple, None
        :raises: AssertionError
        """
        self.assert_matches(**root)

        # ``None`` opts out of checking children entirely.
        if children is None:
            return
        if children is NO_CHILDREN:
            children = ()

        child_nodes = self.spans
        self.assert_span_count(len(children))
        for index, expected in enumerate(children):
            if isinstance(expected, (list, tuple)):
                expected_props, expected_children = expected
            else:
                expected_props, expected_children = expected, NO_CHILDREN

            node = child_nodes[index]
            node.assert_matches(parent_id=self.span_id, trace_id=self.trace_id, _parent=self)
            node.assert_structure(expected_props, expected_children)

    def pprint(self):
        """Render this span followed by each child's rendering, separated by dashed lines."""
        lines = [super(TestSpanNode, self).pprint()]
        for child in self._children:
            lines.extend(("-" * 20, child.pprint()))
        return "\r\n".join(lines)
802
803
def assert_dict_issuperset(a, b):
    """Assert that every key/value pair of ``b`` is also present in ``a``.

    Values do not need to be hashable, so lists, dicts and matcher objects
    such as ``AnyStr()`` are supported.

    :param a: The dict expected to be the superset
    :param b: The dict whose items must all be found in ``a``
    :raises: AssertionError
    """
    # ``is`` first, then ``==``: the same identity-then-equality check that set
    # membership of item tuples performs (so e.g. the same NaN object matches).
    is_superset = all(key in a and (a[key] is value or a[key] == value) for key, value in b.items())
    assert is_superset, "{a} is not a superset of {b}".format(a=a, b=b)
806
807
@contextmanager
def override_global_tracer(tracer):
    """Helper functions that overrides the global tracer available in the
    `ddtrace` package. This is required because in some `httplib` tests we
    can't get easily the PIN object attached to the `HTTPConnection` to
    replace the used tracer with a dummy tracer.

    The original global tracer is restored on exit, including when the body raises.

    :param tracer: The tracer to install as ``ddtrace.tracer`` for the duration of the block
    """
    original_tracer = ddtrace.tracer
    ddtrace.tracer = tracer
    try:
        yield
    finally:
        # Restore even on failure so a failing test does not leak its tracer
        # into every later test.
        ddtrace.tracer = original_tracer
819
820
class SnapshotFailed(Exception):
    """Exception type for snapshot test failures; no code in this module raises it directly."""
823
824
@attr.s
class SnapshotTest(object):
    """Handle for an in-progress snapshot test session on the test agent.

    ``token`` identifies the session; ``tracer`` supplies the agent URL.
    """

    token = attr.ib(type=str)
    # NOTE: this default is evaluated once, at import time, so it is whatever
    # ``ddtrace.tracer`` was when this module was imported.
    tracer = attr.ib(type=ddtrace.Tracer, default=ddtrace.tracer)

    def clear(self):
        """Clear any traces that were sent for this snapshot.

        :raises: AssertionError if the test agent does not answer with HTTP 200
        """
        parsed = parse.urlparse(self.tracer.writer.agent_url)
        conn = httplib.HTTPConnection(parsed.hostname, parsed.port)
        try:
            conn.request("GET", "/test/session/clear?test_session_token=%s" % self.token)
            resp = conn.getresponse()
            assert resp.status == 200
        finally:
            # Always release the socket, including when the request or assertion fails.
            conn.close()
837
838
@contextmanager
def snapshot_context(token, ignores=None, tracer=None, async_mode=True, variants=None):
    """
    Run the enclosed code as a snapshot test session against the test agent.

    Sequence, as implemented below:

    1. If ``variants`` is given, exactly one entry must have a truthy value. Its key,
       when non-empty, is appended to ``token`` as ``"<token>_<key>"``.
    2. Flush the tracer's writer queue so earlier traces are not attributed to this test.
    3. In ``async_mode``, add the ``X-Datadog-Test-Session-Token`` header to the writer
       and to ``_DD_TRACE_WRITER_ADDITIONAL_HEADERS`` (for subprocesses).
    4. Start the session (``/test/session/start``) and yield a :class:`SnapshotTest`.
       Afterwards, flush again and remove the headers.
    5. Request ``/test/session/snapshot`` to compare against the stored snapshot.
       A non-200 response fails the test with the agent's message.

    :param token: Snapshot session token (``snapshot`` derives it from the test's qualified name)
    :param ignores: Keys to ignore when comparing, joined with ``,`` into the query string
    :param tracer: Tracer whose writer is used; defaults to ``ddtrace.tracer`` when falsy
    :param async_mode: Whether to tag traces with the session token via request headers
    :param variants: Optional mapping of variant id to a boolean "applies" flag
    """
    # Use variant that applies to update test token. One must apply. If none
    # apply, the test should have been marked as skipped.
    if variants:
        applicable_variant_ids = [k for (k, v) in variants.items() if v]
        assert len(applicable_variant_ids) == 1
        variant_id = applicable_variant_ids[0]
        token = "{}_{}".format(token, variant_id) if variant_id else token

    ignores = ignores or []
    if not tracer:
        tracer = ddtrace.tracer

    parsed = parse.urlparse(tracer.writer.agent_url)
    # NOTE(review): this connection is later rebound (snapshot request / except
    # handler) without being closed; only the last object bound to ``conn`` is
    # closed in the outer ``finally``.
    conn = httplib.HTTPConnection(parsed.hostname, parsed.port)
    try:
        # clear queue in case traces have been generated before test case is
        # itself run
        try:
            tracer.writer.flush_queue()
        except Exception as e:
            pytest.fail("Could not flush the queue before test case: %s" % str(e), pytrace=True)

        if async_mode:
            # Patch the tracer writer to include the test token header for all requests.
            tracer.writer._headers["X-Datadog-Test-Session-Token"] = token

            # Also add a header to the environment for subprocesses test cases that might use snapshotting.
            existing_headers = parse_tags_str(os.environ.get("_DD_TRACE_WRITER_ADDITIONAL_HEADERS", ""))
            existing_headers.update({"X-Datadog-Test-Session-Token": token})
            os.environ["_DD_TRACE_WRITER_ADDITIONAL_HEADERS"] = ",".join(
                ["%s:%s" % (k, v) for k, v in existing_headers.items()]
            )

        # NOTE(review): ``token`` (here and below) and ``ignores`` are interpolated
        # into the query string without URL-encoding.
        try:
            conn.request("GET", "/test/session/start?test_session_token=%s" % token)
        except Exception as e:
            pytest.fail("Could not connect to test agent: %s" % str(e), pytrace=False)
        else:
            r = conn.getresponse()
            if r.status != 200:
                # The test agent returns nice error messages we can forward to the user.
                pytest.fail(to_unicode(r.read()), pytrace=False)

        # The test body runs at this ``yield``; an ``Exception`` raised by it is
        # handled by the outer ``except`` below after this ``finally`` runs.
        try:
            yield SnapshotTest(
                tracer=tracer,
                token=token,
            )
        finally:
            # Force a flush so all traces are submitted.
            tracer.writer.flush_queue()
            if async_mode:
                del tracer.writer._headers["X-Datadog-Test-Session-Token"]
                # NOTE(review): this removes the variable outright, including any
                # headers it already held before this context added the token.
                del os.environ["_DD_TRACE_WRITER_ADDITIONAL_HEADERS"]

        # Query for the results of the test.
        conn = httplib.HTTPConnection(parsed.hostname, parsed.port)
        conn.request("GET", "/test/session/snapshot?ignores=%s&test_session_token=%s" % (",".join(ignores), token))
        r = conn.getresponse()
        if r.status != 200:
            pytest.fail(to_unicode(r.read()), pytrace=False)
    # NOTE(review): presumably ``pytest.fail`` outcomes are not caught here because
    # pytest's ``Failed`` derives from ``BaseException`` in recent versions;
    # verify against the pinned pytest.
    except Exception as e:
        # Even though it's unlikely any traces have been sent, make the
        # final request to the test agent so that the test case is finished.
        conn = httplib.HTTPConnection(parsed.hostname, parsed.port)
        conn.request("GET", "/test/session/snapshot?ignores=%s&test_session_token=%s" % (",".join(ignores), token))
        conn.getresponse()
        pytest.fail("Unexpected test failure during snapshot test: %s" % str(e), pytrace=True)
    finally:
        conn.close()
911
912
def snapshot(ignores=None, include_tracer=False, variants=None, async_mode=True, token_override=None):
    """Decorator that runs a test as a snapshot integration test with the testing agent.

    All traces sent to the agent while the test runs are recorded and compared
    to a snapshot created for the test case (see ``snapshot_context``).

    :param ignores: A list of keys to ignore when comparing snapshots. To refer
                    to keys in the meta or metrics maps use "meta.key" and
                    "metrics.key"
    :param include_tracer: If true, create a new ``Tracer`` for the test and pass
                           it to the test function as the ``tracer`` keyword
                           argument; otherwise the global ``ddtrace.tracer`` is used.
    :param variants: Optional mapping of variant id to a boolean flag, forwarded to
                     ``snapshot_context`` to select the token suffix.
    :param async_mode: Forwarded to ``snapshot_context``.
    :param token_override: Token to use instead of the one derived from the test's
                           module, class and function names.
    """
    ignores = ignores or []

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        # NOTE: the class name is taken from ``args[0]`` only when more than one
        # positional argument is passed; ``instance`` is not consulted. Kept as-is
        # because the resulting token names existing snapshot files.
        clsname = args[0].__class__.__name__ if len(args) > 1 else ""

        tracer = Tracer() if include_tracer else ddtrace.tracer

        module = inspect.getmodule(wrapped)

        if token_override is not None:
            token = token_override
        else:
            # Fully qualified test name: "<module>.<Class>.<func>" or "<module>.<func>".
            separator = "." if clsname else ""
            token = "{}{}{}.{}".format(module.__name__, separator, clsname, wrapped.__name__)

        with snapshot_context(token, ignores=ignores, tracer=tracer, async_mode=async_mode, variants=variants):
            if include_tracer:
                kwargs["tracer"] = tracer
            return wrapped(*args, **kwargs)

    return wrapper
956
957
class AnyStr(object):
    """Matcher that compares equal to any ``str`` instance."""

    def __eq__(self, other):
        return isinstance(other, str)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it explicitly.
        return not self.__eq__(other)

    def __repr__(self):
        # Readable form in assertion failure messages.
        return "{}()".format(self.__class__.__name__)
961
962
class AnyInt(object):
    """Matcher that compares equal to any ``int`` instance (``bool`` included, as a subclass)."""

    def __eq__(self, other):
        return isinstance(other, int)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it explicitly.
        return not self.__eq__(other)

    def __repr__(self):
        # Readable form in assertion failure messages.
        return "{}()".format(self.__class__.__name__)
966
967
class AnyFloat(object):
    """Matcher that compares equal to any ``float`` instance."""

    def __eq__(self, other):
        return isinstance(other, float)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it explicitly.
        return not self.__eq__(other)

    def __repr__(self):
        # Readable form in assertion failure messages.
        return "{}()".format(self.__class__.__name__)
971
972
def call_program(*args, **kwargs):
    """Run ``args`` as a subprocess, wait for it, and capture its output.

    Extra keyword arguments are forwarded to :class:`subprocess.Popen`.

    :returns: ``(stdout, stderr, returncode, pid)``; stdout/stderr are bytes
        unless a text mode is requested through ``kwargs``
    """
    proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True, **kwargs)
    out, err = proc.communicate()
    exit_code = proc.wait()
    return out, err, exit_code, proc.pid
977