import logging
import sys

if sys.version_info >= (3, 0, 0):
    from urllib.parse import urlparse, uses_netloc
else:
    from urlparse import urlparse, uses_netloc

import wrapt

from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core.patcher import _PATCHED_MODULES
from aws_xray_sdk.core.utils import stacktrace
from aws_xray_sdk.ext.util import unwrap

from sqlalchemy.sql.expression import ClauseElement

log = logging.getLogger(__name__)


def _sql_meta(engine_instance, args):
    """Build the X-Ray SQL metadata for a statement about to be executed.

    :param engine_instance: object exposing ``.engine`` (with ``url``,
        ``name`` and ``driver``) and ``.dialect`` -- the ``Connection``
        passed by the execute wrapper, or a ``Session``'s ``bind``.
    :param args: positional arguments of the intercepted ``execute`` call;
        ``args[0]`` (the statement) is read only when
        ``xray_recorder.stream_sql`` is set.
    :returns: ``(name, metadata)``. ``name`` is the host part of the
        connection URL with credentials removed; ``metadata`` is the dict
        later passed to ``subsegment.set_sql``. If anything fails, both are
        ``None`` and the error is logged instead of raised, so tracing never
        breaks the query itself.
    """
    try:
        metadata = {}
        url = urlparse(str(engine_instance.engine.url))
        # geturl() re-emits '//' for an empty netloc (e.g. sqlite:///file.db)
        # only when the scheme is listed in uses_netloc. That list is
        # module-global and this function runs on every query, so register
        # each scheme once instead of appending to it forever.
        if url.scheme not in uses_netloc:
            uses_netloc.append(url.scheme)
        if url.password is None:
            metadata['url'] = url.geturl()
            name = url.netloc
        else:
            # Strip the password from the URL. rpartition ensures an '@'
            # inside the password does not end up in the host part.
            host_info = url.netloc.rpartition('@')[-1]
            parts = url._replace(netloc='{}@{}'.format(url.username, host_info))
            metadata['url'] = parts.geturl()
            name = host_info
        metadata['user'] = url.username
        metadata['database_type'] = engine_instance.engine.name
        driver = engine_instance.engine.driver
        try:
            # Some dialects expose a '<driver>_version' sequence.
            version = getattr(engine_instance.dialect, '{}_version'.format(driver))
            version_str = '.'.join(map(str, version))
            metadata['driver_version'] = "{}-{}".format(driver, version_str)
        except (AttributeError, TypeError):
            # Attribute missing, or a non-iterable value such as None:
            # fall back to the bare driver name rather than discarding
            # all of the metadata collected so far.
            metadata['driver_version'] = driver
        server_version = engine_instance.dialect.server_version_info
        if server_version is not None:
            metadata['database_version'] = '.'.join(map(str, server_version))
        if xray_recorder.stream_sql:
            try:
                if isinstance(args[0], ClauseElement):
                    metadata['sanitized_query'] = str(args[0].compile(engine_instance.engine))
                else:
                    metadata['sanitized_query'] = str(args[0])
            except Exception:
                # Best effort: a missing query text must not drop the
                # rest of the metadata.
                log.exception('Error getting the sanitized query')
    except Exception:
        metadata = None
        name = None
        log.exception('Error parsing sql metadata.')
    return name, metadata


def _xray_traced_sqlalchemy_execute(wrapped, instance, args, kwargs):
    """wrapt wrapper for ``Connection.execute``; ``instance`` is the Connection."""
    return _process_request(wrapped, instance, args, kwargs)


def _xray_traced_sqlalchemy_session(wrapped, instance, args, kwargs):
    """wrapt wrapper for ``Session.execute``; metadata comes from ``instance.bind``.

    If ``instance.bind`` is None, the attribute lookups in ``_sql_meta``
    fail, the error is logged, and the call runs untraced.
    """
    return _process_request(wrapped, instance.bind, args, kwargs)


def _process_request(wrapped, engine_instance, args, kwargs):
    """Run ``wrapped(*args, **kwargs)`` inside an X-Ray 'remote' subsegment.

    If SQL metadata could not be collected, no subsegment is opened and the
    call runs untraced. An exception raised by the call is recorded on the
    subsegment and then re-raised unchanged.
    """
    name, sql = _sql_meta(engine_instance, args)
    if sql is not None:
        subsegment = xray_recorder.begin_subsegment(name, namespace='remote')
    else:
        subsegment = None
    try:
        res = wrapped(*args, **kwargs)
    except Exception as exception:
        if subsegment is not None:
            stack = stacktrace.get_stacktrace(limit=xray_recorder._max_trace_back)
            subsegment.add_exception(exception, stack)
        raise
    finally:
        # Close the subsegment on both the success and the failure path.
        if subsegment is not None:
            subsegment.set_sql(sql)
            xray_recorder.end_subsegment()
    return res


def patch():
    """Wrap ``Connection.execute`` and ``Session.execute`` with X-Ray tracing."""
    wrapt.wrap_function_wrapper(
        'sqlalchemy.engine.base',
        'Connection.execute',
        _xray_traced_sqlalchemy_execute
    )

    wrapt.wrap_function_wrapper(
        'sqlalchemy.orm.session',
        'Session.execute',
        _xray_traced_sqlalchemy_session
    )


def unpatch():
    """
    Unpatch any previously patched modules.
    This operation is idempotent.
    """
    _PATCHED_MODULES.discard('sqlalchemy_core')
    # Import the submodules explicitly. A bare ``import sqlalchemy`` may not
    # load ``sqlalchemy.orm``, e.g. when unpatch() runs before patch() has
    # ever imported it, and the attribute access below would then fail.
    import sqlalchemy.engine.base
    import sqlalchemy.orm.session
    unwrap(sqlalchemy.engine.base.Connection, 'execute')
    unwrap(sqlalchemy.orm.session.Session, 'execute')