1import flask.templating
2from flask import request
3
4from aws_xray_sdk.core.models import http
5from aws_xray_sdk.core.utils import stacktrace
6from aws_xray_sdk.ext.util import calculate_sampling_decision, \
7    calculate_segment_name, construct_xray_header, prepare_response_header
8from aws_xray_sdk.core.lambda_launcher import check_in_lambda, LambdaContext
9
10
class XRayMiddleware(object):
    """Flask middleware that records every request as an X-Ray trace entity.

    Three hooks are registered on the application: ``before_request`` opens
    the entity and stores request metadata on it, ``after_request`` stores
    response metadata and sets the trace header on the response, and
    ``teardown_request`` records any exception and closes the entity.

    When running inside AWS Lambda with a ``LambdaContext`` recorder context,
    a subsegment is used instead of a segment.
    """

    def __init__(self, app, recorder):
        """Register the tracing hooks on ``app`` and patch template rendering.

        :param app: Flask application to instrument.
        :param recorder: X-Ray recorder used to begin and end trace entities.
        """
        self.app = app
        self.app.logger.info("initializing xray middleware")

        self._recorder = recorder
        self.app.before_request(self._before_request)
        self.app.after_request(self._after_request)
        self.app.teardown_request(self._teardown_request)
        self.in_lambda_ctx = False

        # NOTE(review): the exact type comparison (rather than isinstance) is
        # kept as-is; presumably subclasses of LambdaContext must not switch
        # the middleware into subsegment mode -- confirm before changing.
        if check_in_lambda() and type(self._recorder.context) == LambdaContext:
            self.in_lambda_ctx = True

        _patch_render(recorder)

    def _current_entity(self):
        """Return the entity tracing the current request, or ``None``.

        Looks up the current subsegment in Lambda mode and the current
        segment otherwise. Any error raised by the recorder during the
        lookup is treated as "nothing is being traced" so that the tracing
        hooks never break request handling.
        """
        try:
            if self.in_lambda_ctx:
                return self._recorder.current_subsegment()
            return self._recorder.current_segment()
        except Exception:
            return None

    def _before_request(self):
        """Open a segment (or subsegment in Lambda) for the incoming request.

        Computes the segment name and sampling decision from the request,
        begins the entity, and stores the origin trace header, URL, method,
        user agent and client IP on it.
        """
        headers = request.headers
        xray_header = construct_xray_header(headers)
        req = request._get_current_object()

        name = calculate_segment_name(req.host, self._recorder)

        sampling_req = {
            'host': req.host,
            'method': req.method,
            'path': req.path,
            'service': name,
        }
        sampling_decision = calculate_sampling_decision(
            trace_header=xray_header,
            recorder=self._recorder,
            sampling_req=sampling_req,
        )

        if self.in_lambda_ctx:
            segment = self._recorder.begin_subsegment(name)
        else:
            segment = self._recorder.begin_segment(
                name=name,
                traceid=xray_header.root,
                parent_id=xray_header.parent,
                sampling=sampling_decision,
            )

        # Defensive: if the recorder did not start an entity there is nothing
        # to annotate; skip instead of failing the request.
        if segment is None:
            return

        segment.save_origin_trace_header(xray_header)
        segment.put_http_meta(http.URL, req.base_url)
        segment.put_http_meta(http.METHOD, req.method)
        segment.put_http_meta(http.USER_AGENT, headers.get('User-Agent'))

        client_ip = headers.get('X-Forwarded-For') or headers.get('HTTP_X_FORWARDED_FOR')
        if client_ip:
            # The address came from a forwarding header, so flag it as such.
            segment.put_http_meta(http.CLIENT_IP, client_ip)
            segment.put_http_meta(http.X_FORWARDED_FOR, True)
        else:
            segment.put_http_meta(http.CLIENT_IP, req.remote_addr)

    def _after_request(self, response):
        """Record response metadata and attach the trace header.

        :param response: the Flask response about to be sent.
        :return: the same ``response`` object; it gains the X-Ray header
            when an entity is active and is returned untouched otherwise.
        """
        segment = self._current_entity()
        if not segment:
            # No entity was opened for this request (mirrors the guard in
            # _teardown_request); do not let tracing break the response.
            return response

        segment.put_http_meta(http.STATUS, response.status_code)

        origin_header = segment.get_origin_trace_header()
        resp_header_str = prepare_response_header(origin_header, segment)
        response.headers[http.XRAY_HEADER] = resp_header_str

        cont_len = response.headers.get('Content-Length')
        if cont_len:
            segment.put_http_meta(http.CONTENT_LENGTH, int(cont_len))

        return response

    def _teardown_request(self, exception):
        """Record any request exception and close the current entity.

        :param exception: the exception that ended the request, or ``None``.
        """
        segment = self._current_entity()
        if not segment:
            return

        if exception:
            # An exception reaching teardown is reported as a 500 response.
            segment.put_http_meta(http.STATUS, 500)
            stack = stacktrace.get_stacktrace(limit=self._recorder._max_trace_back)
            segment.add_exception(exception, stack)

        if self.in_lambda_ctx:
            self._recorder.end_subsegment()
        else:
            self._recorder.end_segment()
108
def _patch_render(recorder):
    """Replace ``flask.templating._render`` with a traced wrapper.

    The wrapper is decorated with ``recorder.capture('template_render')``
    and, when the template has a name, renames the current subsegment to
    that name before delegating to the original ``_render``.

    :param recorder: X-Ray recorder used to capture the render call.
    """
    # Keep a reference to the function being replaced so the wrapper can
    # delegate to it.
    _render = flask.templating._render

    # NOTE(review): the wrapper assumes the positional order
    # (template, context, app) of flask's private ``_render`` -- confirm
    # against the Flask version in use.
    @recorder.capture('template_render')
    def _traced_render(template, context, app):
        if template.name:
            subsegment = recorder.current_subsegment()
            # current_subsegment() may yield None when no trace entity is
            # active (e.g. rendering outside a traced request); renaming is
            # best-effort and must not break the render itself.
            if subsegment is not None:
                subsegment.name = template.name
        return _render(template, context, app)

    flask.templating._render = _traced_render
120