1import contextlib
2import datetime
3import inspect
4import re
5
6import six
7import sqlalchemy as sa
8
9
def create_mock_engine(bind, stream=None):
    """Build a SQLAlchemy mock engine for the given engine or bind URL.

    The returned engine never connects anywhere; every statement handed to
    it goes to an executor. When ``stream`` is given, that executor compiles
    the statement with bound parameters inlined as literals and writes the
    resulting SQL (terminated by ``;``) to ``stream``; otherwise statements
    are silently discarded.

    :param bind: A SQLAlchemy engine or bind URL to mock.
    :param stream: Render all DDL operations to the stream.
    :returns: The mock engine.
    """

    # A string is used verbatim as the URL; anything else is expected to
    # expose a ``url`` attribute (e.g. an Engine).
    if isinstance(bind, six.string_types):
        bind_url = bind
    else:
        bind_url = str(bind.url)

    def _discard(*args, **kw):
        # No stream to render into: drop the statement.
        return None

    def _render(sql, *args, **kwargs):
        dialect = mock.dialect
        # Subclass whatever compiler this statement normally uses (DDL or
        # statement compiler) so only literal rendering is customised.
        base_compiler = type(sql._compiler(dialect))

        class LiteralCompiler(base_compiler):

            def visit_bindparam(self, bindparam, *args, **kwargs):
                # Inline the parameter value instead of emitting a placeholder.
                return self.render_literal_value(
                    bindparam.value, bindparam.type)

            def render_literal_value(self, value, type_):
                if isinstance(value, six.integer_types):
                    return str(value)
                # ``datetime.datetime`` is a subclass of ``datetime.date``,
                # so this covers both.
                if isinstance(value, datetime.date):
                    return "'%s'" % value
                return super(LiteralCompiler, self).render_literal_value(
                    value, type_)

        compiled = str(LiteralCompiler(dialect, sql).process(sql))
        # Collapse blank lines and trim surrounding whitespace/newlines.
        collapsed = re.sub(r'\n+', '\n', compiled).strip()
        stream.write('\n%s;' % collapsed)

    executor = _discard if stream is None else _render

    try:
        mock = sa.create_mock_engine(bind_url, executor=executor)
    except AttributeError:  # SQLAlchemy <1.4
        mock = sa.create_engine(bind_url, strategy='mock', executor=executor)
    return mock
58
59
@contextlib.contextmanager
def mock_engine(engine, stream=None):
    """Mock out the engine specified by the passed bind expression.

    On entry, the object that the expression ``engine`` evaluates to is
    looked up in the nearest calling frame where the expression can be
    evaluated. That name is then rebound to a mock engine from
    ``create_mock_engine(target, stream)``. On exit, the original object is
    put back. This also happens when the ``with`` block raises.

    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.

    NOTE(review): rebinding happens by exec'ing an assignment against
    ``frame.f_locals``. That is reliably visible for module-level names and
    attribute targets (e.g. ``"module.engine"``). A plain local variable of
    a calling function may not actually be rebound on older CPython
    versions — confirm for your use case.

    :param engine: A python expression that represents the engine to mock.
    :param stream: Render all DDL operations to the stream. A fresh
        ``six.moves.cStringIO()`` buffer is created when omitted.
    :returns: (yields) the stream that rendered SQL is written to.
    :raises ValueError: if ``engine`` cannot be evaluated in any calling
        frame.
    """

    # Create a stream if not present.
    if stream is None:
        stream = six.moves.cStringIO()

    # Navigate outwards from our caller and find the first frame in which
    # the expression can be executed. Following ``f_back`` visits the same
    # frames as ``inspect.stack()[1:]``. It avoids building FrameInfo
    # records and reading source lines for every frame on the stack.
    expression = '__target = %s' % engine
    frame = inspect.currentframe().f_back
    while frame is not None:
        try:
            six.exec_(expression, frame.f_globals, frame.f_locals)
            target = frame.f_locals['__target']
            break
        except Exception:
            # Deliberately broad: any failure only means the expression does
            # not resolve in this frame, so keep searching outwards.
            frame = frame.f_back
    else:
        raise ValueError('Not a valid python expression', engine)

    # Stash the mock where the exec'd assignment can see it.
    frame.f_locals['__mock'] = create_mock_engine(target, stream)

    # Replace the target with our mock.
    six.exec_('%s = __mock' % engine, frame.f_globals, frame.f_locals)

    try:
        # Give control back to the ``with`` block.
        yield stream
    finally:
        # Put the target engine back even when the ``with`` block raised.
        # contextlib re-raises the exception at the ``yield``. Without this
        # ``finally``, the mock would stay installed and the temporary names
        # would leak into the caller's namespace.
        frame.f_locals['__target'] = target
        six.exec_('%s = __target' % engine, frame.f_globals, frame.f_locals)
        six.exec_('del __target', frame.f_globals, frame.f_locals)
        six.exec_('del __mock', frame.f_globals, frame.f_locals)
113