import flask.templating
from flask import request

from aws_xray_sdk.core.models import http
from aws_xray_sdk.core.utils import stacktrace
from aws_xray_sdk.ext.util import calculate_sampling_decision, \
    calculate_segment_name, construct_xray_header, prepare_response_header
from aws_xray_sdk.core.lambda_launcher import check_in_lambda, LambdaContext


class XRayMiddleware(object):
    """Trace Flask requests with AWS X-Ray.

    Registers ``before_request``, ``after_request`` and ``teardown_request``
    hooks on ``app`` that open, annotate and close one trace entity per
    request. It also patches Flask template rendering so each render is
    recorded under a ``template_render`` subsegment.

    Normally each request gets its own segment. When ``check_in_lambda()`` is
    truthy and the recorder's context is a ``LambdaContext``, a subsegment is
    used instead. Presumably this is because Lambda owns the enclosing
    segment (TODO confirm).
    """

    def __init__(self, app, recorder):
        """Attach the middleware to ``app``, recording traces via ``recorder``.

        :param app: the Flask application whose request hooks are registered.
        :param recorder: the X-Ray recorder used to create trace entities.
        """
        self.app = app
        self.app.logger.info("initializing xray middleware")

        self._recorder = recorder
        self.app.before_request(self._before_request)
        self.app.after_request(self._after_request)
        self.app.teardown_request(self._teardown_request)
        self.in_lambda_ctx = False

        # isinstance rather than an exact type() comparison, so subclasses of
        # LambdaContext are also recognised.
        if check_in_lambda() and isinstance(self._recorder.context, LambdaContext):
            self.in_lambda_ctx = True

        _patch_render(recorder)

    def _current_entity(self):
        """Return the active trace entity for this request.

        The entity is the current subsegment in Lambda mode and the current
        segment otherwise. It may be None if the recorder has nothing active.
        """
        if self.in_lambda_ctx:
            return self._recorder.current_subsegment()
        return self._recorder.current_segment()

    def _before_request(self):
        """Open a trace entity for the incoming request and record its metadata.

        The recorded metadata is URL, method, user agent and client IP.
        Returns None so Flask continues normal request dispatch.
        """
        headers = request.headers
        xray_header = construct_xray_header(headers)
        req = request._get_current_object()

        name = calculate_segment_name(req.host, self._recorder)

        sampling_req = {
            'host': req.host,
            'method': req.method,
            'path': req.path,
            'service': name,
        }
        sampling_decision = calculate_sampling_decision(
            trace_header=xray_header,
            recorder=self._recorder,
            sampling_req=sampling_req,
        )

        if self.in_lambda_ctx:
            segment = self._recorder.begin_subsegment(name)
        else:
            segment = self._recorder.begin_segment(
                name=name,
                traceid=xray_header.root,
                parent_id=xray_header.parent,
                sampling=sampling_decision,
            )

        # The recorder may hand back None (e.g. when no parent entity is
        # active). Tracing must not turn that into an AttributeError that
        # fails the user's request, so skip annotation.
        if segment is None:
            return

        segment.save_origin_trace_header(xray_header)
        segment.put_http_meta(http.URL, req.base_url)
        segment.put_http_meta(http.METHOD, req.method)
        segment.put_http_meta(http.USER_AGENT, headers.get('User-Agent'))

        # Prefer a proxy-supplied client address when present, and flag that
        # it came from X-Forwarded-For.
        client_ip = headers.get('X-Forwarded-For') or headers.get('HTTP_X_FORWARDED_FOR')
        if client_ip:
            segment.put_http_meta(http.CLIENT_IP, client_ip)
            segment.put_http_meta(http.X_FORWARDED_FOR, True)
        else:
            segment.put_http_meta(http.CLIENT_IP, req.remote_addr)

    def _after_request(self, response):
        """Record response metadata and propagate the X-Ray trace header.

        :param response: the Flask response object.
        :returns: the same ``response``. Its X-Ray header is set whenever a
            trace entity is active.
        """
        segment = self._current_entity()
        if segment is None:
            # Nothing is being traced for this request; pass the response
            # through untouched instead of failing it.
            return response

        segment.put_http_meta(http.STATUS, response.status_code)

        origin_header = segment.get_origin_trace_header()
        resp_header_str = prepare_response_header(origin_header, segment)
        response.headers[http.XRAY_HEADER] = resp_header_str

        cont_len = response.headers.get('Content-Length')
        if cont_len:
            segment.put_http_meta(http.CONTENT_LENGTH, int(cont_len))

        return response

    def _teardown_request(self, exception):
        """Close the request's trace entity, recording ``exception`` if any.

        :param exception: the unhandled exception passed by Flask, or None.
        """
        try:
            segment = self._current_entity()
        except Exception:
            # Deliberate best effort: teardown runs for every request. If
            # looking up the active entity fails, skip tracing rather than
            # raise from teardown.
            segment = None
        if segment is None:
            return

        if exception:
            segment.put_http_meta(http.STATUS, 500)
            stack = stacktrace.get_stacktrace(limit=self._recorder._max_trace_back)
            segment.add_exception(exception, stack)

        if self.in_lambda_ctx:
            self._recorder.end_subsegment()
        else:
            self._recorder.end_segment()


def _patch_render(recorder):
    """Wrap ``flask.templating._render`` so each render is traced.

    Each call runs inside a ``template_render`` subsegment captured by
    ``recorder``. That subsegment is renamed to the template's name when it
    has one.

    :param recorder: the X-Ray recorder providing ``capture`` and
        ``current_subsegment``.
    """
    _render = flask.templating._render

    @recorder.capture('template_render')
    def _traced_render(*args, **kwargs):
        # Flask < 2.2 calls _render(template, context, app).
        # Flask >= 2.2 calls _render(app, template, context).
        # Locate the template for either order instead of assuming one.
        if 'template' in kwargs:
            template = kwargs['template']
        elif len(args) > 1 and isinstance(args[0], flask.Flask):
            template = args[1]
        elif args:
            template = args[0]
        else:
            template = None

        template_name = getattr(template, 'name', None)
        if template_name:
            subsegment = recorder.current_subsegment()
            if subsegment is not None:
                subsegment.name = template_name
        # Forward arguments exactly as received so either signature works.
        return _render(*args, **kwargs)

    flask.templating._render = _traced_render