1import logging
2import sys
3
4if sys.version_info >= (3, 0, 0):
5    from urllib.parse import urlparse, uses_netloc
6else:
7    from urlparse import urlparse, uses_netloc
8
9import wrapt
10
11from aws_xray_sdk.core import xray_recorder
12from aws_xray_sdk.core.patcher import _PATCHED_MODULES
13from aws_xray_sdk.core.utils import stacktrace
14from aws_xray_sdk.ext.util import unwrap
15
16from sqlalchemy.sql.expression import ClauseElement
17
18
def _sql_meta(engine_instance, args):
    """Build the X-Ray subsegment name and SQL metadata for an ``execute`` call.

    :param engine_instance: object exposing ``engine`` (with ``url``, ``name``
        and ``driver``) and ``dialect`` (with ``server_version_info``). In this
        module it is the patched ``Connection`` or a ``Session``'s ``bind``.
    :param args: positional arguments passed to ``execute``. When SQL
        streaming is enabled, ``args[0]`` is treated as the statement.
    :returns: ``(name, metadata)``. ``name`` is the URL's host part with any
        credentials removed, and ``metadata`` is a dict of SQL attributes.
        Returns ``(None, None)`` if metadata could not be collected; in that
        case the error is logged, never raised.
    """
    try:
        metadata = {}
        url = urlparse(str(engine_instance.engine.url))
        # Add the scheme to uses_netloc or '//' may be missing from geturl().
        # uses_netloc is a process-wide list in urllib.parse, so register each
        # scheme only once instead of appending on every traced query.
        if url.scheme not in uses_netloc:
            uses_netloc.append(url.scheme)
        if url.password is None:
            metadata['url'] = url.geturl()
            name = url.netloc
        else:
            # Strip the password from the URL, keeping user@host[:port].
            host_info = url.netloc.rpartition('@')[-1]
            parts = url._replace(netloc='{}@{}'.format(url.username, host_info))
            metadata['url'] = parts.geturl()
            name = host_info
        metadata['user'] = url.username
        metadata['database_type'] = engine_instance.engine.name
        try:
            # Look up ``<driver>_version`` on the dialect. If it is absent,
            # report only the bare driver name.
            version = getattr(engine_instance.dialect, '{}_version'.format(engine_instance.engine.driver))
            version_str = '.'.join(map(str, version))
            metadata['driver_version'] = "{}-{}".format(engine_instance.engine.driver, version_str)
        except AttributeError:
            metadata['driver_version'] = engine_instance.engine.driver
        if engine_instance.dialect.server_version_info is not None:
            metadata['database_version'] = '.'.join(map(str, engine_instance.dialect.server_version_info))
        # The statement may be passed by keyword, which leaves ``args`` empty.
        # Skip the query text instead of logging an IndexError on every call.
        if xray_recorder.stream_sql and args:
            try:
                if isinstance(args[0], ClauseElement):
                    metadata['sanitized_query'] = str(args[0].compile(engine_instance.engine))
                else:
                    metadata['sanitized_query'] = str(args[0])
            except Exception:
                logging.getLogger(__name__).exception('Error getting the sanitized query')
    except Exception:
        # Best effort: tracing metadata must never break the actual query.
        metadata = None
        name = None
        logging.getLogger(__name__).exception('Error parsing sql metadata.')
    return name, metadata
57
58
def _xray_traced_sqlalchemy_execute(wrapped, instance, args, kwargs):
    """wrapt wrapper installed on ``Connection.execute``.

    The connection being executed on is itself the source of the engine and
    dialect metadata used to describe the traced call.
    """
    connection = instance
    return _process_request(wrapped, connection, args, kwargs)
61
62
def _xray_traced_sqlalchemy_session(wrapped, instance, args, kwargs):
    """wrapt wrapper installed on ``Session.execute``.

    The session's ``bind`` attribute is used as the source of the engine and
    dialect metadata for the traced call.
    """
    bind = instance.bind
    return _process_request(wrapped, bind, args, kwargs)
65
66
def _process_request(wrapped, engine_instance, args, kwargs):
    """Run ``wrapped`` inside an X-Ray ``remote`` subsegment describing the SQL call.

    If ``_sql_meta`` cannot produce metadata (returns ``None``), no subsegment
    is opened and ``wrapped`` runs untraced.

    :param wrapped: the original, unpatched ``execute`` callable.
    :param engine_instance: object passed to ``_sql_meta`` for metadata.
    :param args: positional arguments forwarded to ``wrapped``.
    :param kwargs: keyword arguments forwarded to ``wrapped``.
    :returns: whatever ``wrapped`` returns.
    :raises: re-raises any exception from ``wrapped``, after recording it on
        the subsegment when one is open.
    """
    name, sql = _sql_meta(engine_instance, args)

    subsegment = None
    if sql is not None:
        subsegment = xray_recorder.begin_subsegment(name, namespace='remote')

    try:
        return wrapped(*args, **kwargs)
    except Exception as exc:
        if subsegment is not None:
            stack = stacktrace.get_stacktrace(limit=xray_recorder._max_trace_back)
            subsegment.add_exception(exc, stack)
        raise
    finally:
        # Always attach the SQL metadata and close the subsegment, on success or failure.
        if subsegment is not None:
            subsegment.set_sql(sql)
            xray_recorder.end_subsegment()
86
87
def patch():
    """Wrap SQLAlchemy's ``Connection.execute`` and ``Session.execute`` for X-Ray tracing."""
    targets = (
        ('sqlalchemy.engine.base', 'Connection.execute', _xray_traced_sqlalchemy_execute),
        ('sqlalchemy.orm.session', 'Session.execute', _xray_traced_sqlalchemy_session),
    )
    for module_name, attribute_path, wrapper in targets:
        wrapt.wrap_function_wrapper(module_name, attribute_path, wrapper)
100
101
def unpatch():
    """
    Unpatch any previously patched modules.
    This operation is idempotent.

    Removes ``'sqlalchemy_core'`` from the patched-module registry and unwraps
    ``Connection.execute`` and ``Session.execute``.
    """
    _PATCHED_MODULES.discard('sqlalchemy_core')
    # Import the submodules explicitly. A bare ``import sqlalchemy`` does not
    # guarantee that ``sqlalchemy.orm.session`` is loaded, for example when
    # unpatch() runs without a prior patch(). Relying on that made the
    # attribute lookups below fragile.
    import sqlalchemy.engine.base
    import sqlalchemy.orm.session
    unwrap(sqlalchemy.engine.base.Connection, 'execute')
    unwrap(sqlalchemy.orm.session.Session, 'execute')
111