1import re
2
3from aws_xray_sdk.core.models.trace_header import TraceHeader
4from aws_xray_sdk.core.models import http
5
6import wrapt
7import sys
8
9if sys.version_info.major >= 3:  # Python 3 and above
10    from urllib.parse import urlparse
11else:  # Python 2 and below
12    from urlparse import urlparse
13
14
# Matches any character followed by a capitalized word (an upper-case letter
# followed by one or more lower-case letters); first pass of ``to_snake_case``.
first_cap_re = re.compile('(.)([A-Z][a-z]+)')
# Matches a lower-case letter or digit immediately followed by an upper-case
# letter; second pass of ``to_snake_case``.
all_cap_re = re.compile('([a-z0-9])([A-Z])')
# Host name returned by ``get_hostname`` when no host can be determined.
UNKNOWN_HOSTNAME = "UNKNOWN HOST"
18
19
def inject_trace_header(headers, entity):
    """
    Write an X-Ray trace header describing ``entity`` into ``headers``.

    The header value carries the entity's trace id as the root, the entity's
    id as the parent and the entity's sampling decision. Any ``data`` found
    on the origin trace header is carried over; for a subsegment the origin
    header is read from its parent segment. Nothing is written when
    ``entity`` is falsy.

    :param dict headers: http headers to inject into; updated in place.
    :param Entity entity: trace entity that the trace header
        value is generated from.
    """
    if not entity:
        return

    # Subsegments look up the origin trace header on their parent segment.
    is_subsegment = getattr(entity, 'type', None) == 'subsegment'
    origin_source = entity.parent_segment if is_subsegment else entity
    origin_header = origin_source.get_origin_trace_header()

    headers[http.XRAY_HEADER] = TraceHeader(
        root=entity.trace_id,
        parent=entity.id,
        sampled=entity.sampled,
        data=origin_header.data if origin_header else None,
    ).to_header_str()
49
50
def calculate_sampling_decision(trace_header, recorder, sampling_req):
    """
    Decide whether the current request should be sampled.

    Precedence, highest first:

    1. A decision already present in ``trace_header`` (any value other than
       ``None`` or ``'?'``) is returned unchanged.
    2. If sampling is not enabled on ``recorder``, the request is sampled
       and 1 is returned.
    3. Otherwise the user-defined rules decide via
       ``recorder.sampler.should_trace(sampling_req)``. A truthy result
       (e.g. the matched rule name) is returned as-is, and a falsy one
       becomes 0.
    """
    upstream = trace_header.sampled
    if upstream is not None and upstream != '?':
        return upstream

    if not recorder.sampling:
        return 1

    return recorder.sampler.should_trace(sampling_req) or 0
67
68
def construct_xray_header(headers):
    """
    Build a ``TraceHeader`` from the incoming request's header dictionary.

    ``http.XRAY_HEADER`` is consulted first, then ``http.ALT_XRAY_HEADER``.
    When neither yields a non-empty value, an empty ``TraceHeader()`` is
    returned. The result is therefore always a ``TraceHeader``, whether or
    not the request carried a tracing header.
    """
    raw_value = headers.get(http.XRAY_HEADER) or headers.get(http.ALT_XRAY_HEADER)
    if not raw_value:
        return TraceHeader()
    return TraceHeader.from_header_str(raw_value)
81
82
def calculate_segment_name(host_name, recorder):
    """
    Pick the segment name for a request.

    If ``recorder.dynamic_naming`` is set, the name is derived from
    ``host_name`` via its ``get_name`` method. Otherwise the recorder's static
    ``service`` name is used. This is generally called from web framework
    middleware, where the host name comes from the incoming request's headers.
    """
    naming = recorder.dynamic_naming
    return naming.get_name(host_name) if naming else recorder.service
93
94
def prepare_response_header(origin_header, segment):
    """
    Build the trace header string to attach to a response.

    The header always carries the request segment's trace id as its root.
    The segment's sampling decision is included only when the incoming
    ``origin_header`` had ``sampled == '?'``, i.e. no decision had been made
    upstream.
    """
    header_fields = {'root': segment.trace_id}
    if origin_header and origin_header.sampled == '?':
        # Report back the decision that was made for this request.
        header_fields['sampled'] = segment.sampled
    return TraceHeader(**header_fields).to_header_str()
107
108
def to_snake_case(name):
    """
    Convert a CamelCase / mixedCase string to snake_case.

    For example ``"getHTTPResponse"`` becomes ``"get_http_response"``.
    """
    template = r'\1_\2'
    # Pass 1 inserts "_" before every capitalized word:
    #   "getHTTPResponse" -> "getHTTP_Response".
    # Pass 2 inserts "_" between a lower-case letter/digit and a following
    # capital, which splits off acronyms:
    #   "getHTTP_Response" -> "get_HTTP_Response".
    return all_cap_re.sub(template, first_cap_re.sub(template, name)).lower()
116
117
# Drop the first '?' and everything after it (the query string) so it does not end up in the segment name.
def strip_url(url):
    """
    Return ``url`` without its query string, for use as a segment name.

    Everything from the first ``'?'`` onward is removed. Falsy input
    (``None`` or ``''``) is returned unchanged.

    :param url: url to strip
    :return: the url with any query string removed
    """
    if not url:
        return url
    return url.split('?', 1)[0]
126
127
def get_hostname(url):
    """
    Return the host name of ``url``, e.g. for naming a subsegment.

    :param url: the url to extract the host name from; may be None.
    :return: the parsed (lower-cased) host name. ``UNKNOWN_HOSTNAME`` is
        returned when ``url`` is None, has no host component, or cannot be
        parsed at all. ``url`` itself is returned if the parser reports an
        empty (rather than None) host name.
    """
    if url is None:
        return UNKNOWN_HOSTNAME
    try:
        hostname = urlparse(url).hostname
    except ValueError:
        # urlparse rejects some malformed urls (e.g. an unterminated IPv6
        # bracket such as "http://[::1/"). A tracing helper must not raise
        # into the instrumented call, so report an unknown host instead.
        return UNKNOWN_HOSTNAME
    if hostname is None:
        return UNKNOWN_HOSTNAME
    # Defensive: an empty-string host name (possible with some urlparse
    # implementations) falls back to the raw url, as before.
    return hostname if hostname else url
136
137
def unwrap(obj, attr):
    """
    Will unwrap a `wrapt` attribute, restoring the object it wraps.

    ``obj.attr`` is replaced by its ``__wrapped__`` object only when it is a
    ``wrapt.ObjectProxy`` exposing ``__wrapped__``. If the attribute is
    missing or is not such a proxy, ``obj`` is left untouched.

    :param obj: base object
    :param attr: attribute on `obj` to unwrap
    """
    proxy = getattr(obj, attr, None)
    # Deliberately no truth test on ``proxy``: wrapt's ObjectProxy forwards
    # bool() to the wrapped object, so a falsy wrapped object would never be
    # unwrapped, and one that cannot be truth-tested would raise. The
    # isinstance check already rejects the ``None`` default above.
    if isinstance(proxy, wrapt.ObjectProxy) and hasattr(proxy, '__wrapped__'):
        setattr(obj, attr, proxy.__wrapped__)
147