1from importlib import import_module
2
3from ddtrace import config
4from ddtrace.contrib.trace_utils import ext_service
5from ddtrace.vendor.wrapt import wrap_function_wrapper as _w
6
7from ...constants import ANALYTICS_SAMPLE_RATE_KEY
8from ...constants import SPAN_MEASURED_KEY
9from ...ext import SpanTypes
10from ...ext import elasticsearch as metadata
11from ...ext import http
12from ...internal.compat import urlencode
13from ...internal.utils.wrappers import unwrap as _u
14from ...pin import Pin
15from .quantize import quantize
16
17
# Register integration settings under ``config.elasticsearch``.
# "_default_service" is presumably the fallback service name that
# ext_service() uses when none is configured — confirm in ddtrace.settings.
config._add("elasticsearch", dict(_default_service="elasticsearch"))
24
25
def _es_modules():
    """Yield each importable elasticsearch package, in a fixed order.

    The unversioned ``elasticsearch`` package is tried first, followed by the
    versioned ``elasticsearch1``, ``elasticsearch2``, ``elasticsearch5``,
    ``elasticsearch6`` and ``elasticsearch7`` packages. Any name that raises
    ``ImportError`` is silently skipped. Imports happen lazily, one per
    iteration step of the returned generator.
    """
    suffixes = ("", "1", "2", "5", "6", "7")
    for suffix in suffixes:
        candidate = "elasticsearch" + suffix
        try:
            module = import_module(candidate)
        except ImportError:
            continue
        yield module
40
41
def patch():
    """Instrument every importable elasticsearch distribution.

    For each package yielded by ``_es_modules()``, ``_patch`` wraps
    ``transport.Transport.perform_request``. Packages already marked as
    patched are skipped by ``_patch``, so calling this repeatedly is safe.
    """
    for es_module in _es_modules():
        _patch(es_module)
46
47
def _patch(elasticsearch):
    """Wrap ``Transport.perform_request`` of one elasticsearch package and attach a Pin.

    The ``_datadog_patch`` attribute on the package module acts as an
    idempotence flag: a package already marked as patched is left untouched.

    :param elasticsearch: an imported elasticsearch package (e.g. ``elasticsearch``
        or ``elasticsearch7``).
    """
    if not getattr(elasticsearch, "_datadog_patch", False):
        elasticsearch._datadog_patch = True
        transport = elasticsearch.transport
        _w(transport, "Transport.perform_request", _get_perform_request(elasticsearch))
        # The Pin is attached to the Transport class; the wrapper reads it back
        # from the Transport instance via Pin.get_from().
        Pin().onto(transport.Transport)
54
55
def unpatch():
    """Undo ``patch()`` for every importable elasticsearch distribution.

    Packages that were never patched are ignored by ``_unpatch``.
    """
    for es_module in _es_modules():
        _unpatch(es_module)
59
60
def _unpatch(elasticsearch):
    """Remove the ``Transport.perform_request`` wrapper from one elasticsearch package.

    A no-op unless the package's ``_datadog_patch`` flag is set. The flag is
    cleared before unwrapping.

    :param elasticsearch: an imported elasticsearch package previously passed to ``_patch``.
    """
    if not getattr(elasticsearch, "_datadog_patch", False):
        return
    elasticsearch._datadog_patch = False
    _u(elasticsearch.transport.Transport, "perform_request")
65
66
def _extract_method_and_url(args, kwargs):
    """Return ``(method, url)`` for a ``Transport.perform_request`` call.

    The client usually calls ``perform_request(method, url, ...)`` positionally.
    Keyword arguments are used as a fallback, and any extra positional arguments
    are ignored, so an unusual call shape never makes the tracer itself raise.
    A value that cannot be found is returned as ``None``.
    """
    method = args[0] if len(args) > 0 else kwargs.get("method")
    url = args[1] if len(args) > 1 else kwargs.get("url")
    return method, url


def _get_perform_request(elasticsearch):
    """Build the ``Transport.perform_request`` wrapper for one elasticsearch package.

    :param elasticsearch: the imported elasticsearch package (``elasticsearch``,
        ``elasticsearch7``, ...). Its ``exceptions.TransportError`` is caught to
        record the HTTP status of failed requests.
    :returns: a wrapt-style wrapper ``(func, instance, args, kwargs)`` that runs
        each call inside an ``elasticsearch.query`` span.
    """

    def _perform_request(func, instance, args, kwargs):
        pin = Pin.get_from(instance)
        if not pin or not pin.enabled():
            return func(*args, **kwargs)

        with pin.tracer.trace(
            "elasticsearch.query", service=ext_service(pin, config.elasticsearch), span_type=SpanTypes.ELASTICSEARCH
        ) as span:
            span.set_tag(SPAN_MEASURED_KEY)

            # Don't instrument if the trace is not sampled
            if not span.sampled:
                return func(*args, **kwargs)

            # Tolerate keyword / extra positional arguments instead of a strict
            # two-element unpack, which would raise inside the user's call.
            method, url = _extract_method_and_url(args, kwargs)
            params = kwargs.get("params") or {}
            encoded_params = urlencode(params)
            body = kwargs.get("body")

            span.set_tag(metadata.METHOD, method)
            span.set_tag(metadata.URL, url)
            span.set_tag(metadata.PARAMS, encoded_params)
            if config.elasticsearch.trace_query_string:
                span.set_tag(http.QUERY_STRING, encoded_params)

            # Only read-style requests have their body recorded.
            if method in ("GET", "POST"):
                span.set_tag(metadata.BODY, instance.serializer.dumps(body))
            status = None

            # set analytics sample rate
            span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.elasticsearch.get_analytics_sample_rate())

            # quantize() (from .quantize) returns the span to keep using.
            span = quantize(span)

            try:
                result = func(*args, **kwargs)
            except elasticsearch.exceptions.TransportError as e:
                span.set_tag(http.STATUS_CODE, getattr(e, "status_code", 500))
                span.error = 1
                raise

            try:
                # Optional metadata extraction with soft fail.
                if isinstance(result, tuple) and len(result) == 2:
                    # elasticsearch<2.4; it returns both the status and the body
                    status, data = result
                else:
                    # elasticsearch>=2.4; internal change for ``Transport.perform_request``
                    # that just returns the body
                    data = result

                took = data.get("took")
                # ``took`` can legitimately be 0 (ms); only skip it when absent.
                if took is not None:
                    span.set_metric(metadata.TOOK, int(took))
            except Exception:
                # Deliberate best-effort: a body without the expected shape
                # (non-dict, non-numeric ``took``) must not break the request.
                pass

            if status:
                span.set_tag(http.STATUS_CODE, status)

            return result

    return _perform_request
131