1from __future__ import print_function
2
3import os
4import time
5import traceback
6
7import eventlet
8from eventlet import event
9try:
10    from eventlet.green import MySQLdb
11except ImportError:
12    MySQLdb = False
13import tests
14from tests import skip_unless, using_pyevent, get_database_auth
15
16
def mysql_requirement(_f):
    """We want to skip tests if using pyevent, MySQLdb is not installed, or if
    there is no database running on the localhost that the auth file grants
    us access to.

    This errs on the side of skipping tests if everything is not right, but
    it's better than a million tests failing when you don't care about mysql
    support.

    :param _f: the function being guarded; forwarded to ``using_pyevent``.
    :returns: True when the MySQL tests can run, False otherwise.
    """
    if using_pyevent(_f):
        return False
    if MySQLdb is False:
        print("Skipping mysql tests, MySQLdb not importable")
        return False
    auth = get_database_auth()['MySQLdb'].copy()
    try:
        probe = MySQLdb.connect(**auth)
    except MySQLdb.OperationalError:
        print("Skipping mysql tests, error when connecting:")
        traceback.print_exc()
        return False
    # The connection only proves the server is reachable. Close it so that
    # each evaluation of this guard does not leak an open connection.
    probe.close()
    return True
38
39
class TestMySQLdb(tests.LimitedTestCase):
    """Tests for the green ``MySQLdb`` module against a scratch database.

    Every test gets a freshly created database (named from the pid and the
    current time) containing a ``gargleblatz`` table; tearDown drops it.
    """
    TEST_TIMEOUT = 5

    def setUp(self):
        # Work on a private copy: create_db() adds a 'db' key to this dict.
        self._auth = get_database_auth()['MySQLdb'].copy()
        # create_db() is decorated with skip_unless(mysql_requirement); that
        # is the hook that skips these tests when MySQL is not usable.
        self.create_db()
        self.connection = MySQLdb.connect(**self._auth)
        cursor = self.connection.cursor()
        cursor.execute("""CREATE TABLE gargleblatz
        (
        a INTEGER
        );""")
        self.connection.commit()
        cursor.close()

        super(TestMySQLdb, self).setUp()

    def tearDown(self):
        # A failure while closing the connection must not prevent dropping
        # the scratch database, and neither may prevent the base-class
        # tearDown from running.
        try:
            try:
                if self.connection is not None:
                    self.connection.close()
            finally:
                self.drop_db()
        finally:
            super(TestMySQLdb, self).tearDown()

    @skip_unless(mysql_requirement)
    def create_db(self):
        """Create a uniquely named database and point ``self._auth`` at it."""
        auth = self._auth.copy()
        # Best effort: drop whatever database self._auth currently names.
        # Normally there is no 'db' key yet, so drop_db() fails and the error
        # is ignored.
        # NOTE(review): if the configured auth already carries a 'db' key,
        # that configured database gets dropped here -- confirm intended.
        try:
            self.drop_db()
        except Exception:
            pass
        dbname = 'test_%d_%d' % (os.getpid(), int(time.time() * 1000))
        conn = MySQLdb.connect(**auth)
        try:
            cursor = conn.cursor()
            cursor.execute("create database " + dbname)
            cursor.close()
        finally:
            # Close explicitly instead of relying on garbage collection.
            conn.close()
        self._auth['db'] = dbname

    def drop_db(self):
        """Drop the database named by ``self._auth['db']``.

        Raises KeyError when ``self._auth`` has no 'db' key.
        """
        conn = MySQLdb.connect(**self._auth)
        try:
            cursor = conn.cursor()
            cursor.execute("drop database " + self._auth['db'])
            cursor.close()
        finally:
            conn.close()

    def set_up_dummy_table(self, connection=None):
        """Create ``test_table`` (see ``dummy_table_sql``) on a connection.

        :param connection: connection to create the table on. When omitted,
            ``self.connection`` is used, or a new connection is opened if
            there is none. Only a connection opened here is closed here.
        """
        # NOTE(review): the table is TEMPORARY, so it vanishes when the
        # session ends; a connection opened (and closed) here loses it
        # immediately. Callers should pass the connection they will query.
        close_connection = False
        if connection is None:
            if self.connection is None:
                connection = MySQLdb.connect(**self._auth)
                close_connection = True
            else:
                # Borrowed shared connection: must not be closed here, or
                # tearDown would close it a second time.
                connection = self.connection

        cursor = connection.cursor()
        try:
            cursor.execute(self.dummy_table_sql)
            connection.commit()
        finally:
            cursor.close()
            if close_connection:
                connection.close()

    dummy_table_sql = """CREATE TEMPORARY TABLE test_table
        (
        row_id INTEGER PRIMARY KEY AUTO_INCREMENT,
        value_int INTEGER,
        value_float FLOAT,
        value_string VARCHAR(200),
        value_uuid CHAR(36),
        value_binary BLOB,
        value_binary_string VARCHAR(200) BINARY,
        value_enum ENUM('Y','N'),
        created TIMESTAMP
        ) ENGINE=InnoDB;"""

    def _assert_select_one(self, cursor):
        # Run "select 1" on *cursor* and check for exactly one 1x1 row of 1.
        cursor.execute("select 1")
        rows = cursor.fetchall()
        self.assertEqual(len(rows), 1)
        self.assertEqual(len(rows[0]), 1)
        self.assertEqual(rows[0][0], 1)

    def assert_cursor_yields(self, curs):
        """Assert that running a query on *curs* lets other greenthreads run.

        A background greenthread bumps a counter every time it is scheduled;
        if the query cooperatively yields, the counter is non-zero afterwards.
        """
        counter = [0]

        def tick():
            while True:
                counter[0] += 1
                eventlet.sleep()
        gt = eventlet.spawn(tick)
        try:
            self._assert_select_one(curs)
            self.assertGreater(counter[0], 0)
        finally:
            # Always stop the ticker; a failed assertion would otherwise leave
            # an endlessly rescheduling greenthread behind.
            gt.kill()

    def assert_cursor_works(self, cursor):
        """Assert that *cursor* runs a trivial query and yields while doing so."""
        self._assert_select_one(cursor)
        self.assert_cursor_yields(cursor)

    def assert_connection_works(self, conn):
        """Assert that a fresh cursor on *conn* works; the cursor is closed."""
        curs = conn.cursor()
        try:
            self.assert_cursor_works(curs)
        finally:
            curs.close()

    def test_module_attributes(self):
        import MySQLdb as orig
        # Module metadata that the green wrapper is not expected to mirror.
        ignored = ('__author__', '__path__', '__revision__',
                   '__version__', '__loader__')
        for key in dir(orig):
            if key not in ignored and not hasattr(MySQLdb, key):
                self.fail("%s %s" % (key, getattr(orig, key)))

    def test_connecting(self):
        self.assertIsNotNone(self.connection)

    def test_connecting_annoyingly(self):
        # Every spelling of the connect entry point must produce a working,
        # cooperative connection.
        for connect in (MySQLdb.Connect, MySQLdb.Connection,
                        MySQLdb.connections.Connection):
            conn = connect(**self._auth)
            try:
                self.assert_connection_works(conn)
            finally:
                conn.close()

    def test_create_cursor(self):
        cursor = self.connection.cursor()
        cursor.close()

    def test_run_query(self):
        cursor = self.connection.cursor()
        try:
            self.assert_cursor_works(cursor)
        finally:
            cursor.close()

    def test_run_bad_query(self):
        cursor = self.connection.cursor()
        try:
            # Any driver error is acceptable; a bad statement must not
            # silently succeed.
            self.assertRaises(Exception, cursor.execute, "garbage blah blah")
        finally:
            cursor.close()

    def fill_up_table(self, conn):
        """Insert 1000 rows into ``test_table`` on *conn* and commit."""
        curs = conn.cursor()
        try:
            for i in range(1000):
                # Let the driver do the parameter substitution.
                curs.execute(
                    'insert into test_table (value_int) values (%s)', (i,))
            conn.commit()
        finally:
            curs.close()

    def test_yields(self):
        conn = self.connection
        self.set_up_dummy_table(conn)
        self.fill_up_table(conn)
        curs = conn.cursor()
        try:
            results = []
            SHORT_QUERY = "select * from test_table"
            evt = event.Event()

            def a_query():
                try:
                    self.assert_cursor_works(curs)
                    curs.execute(SHORT_QUERY)
                    results.append(2)
                except Exception as exc:
                    # Hand the failure to the waiter; otherwise evt.wait()
                    # below would block until the test timeout fires.
                    evt.send_exception(exc)
                else:
                    evt.send()
            eventlet.spawn(a_query)
            # The test expects the spawned query not to have run yet here.
            results.append(1)
            self.assertEqual([1], results)
            evt.wait()
            self.assertEqual([1, 2], results)
        finally:
            curs.close()

    def test_visibility_from_other_connections(self):
        selection_query = "select * from gargleblatz"
        conn = MySQLdb.connect(**self._auth)
        conn2 = MySQLdb.connect(**self._auth)
        curs = conn.cursor()
        try:
            try:
                curs2 = conn2.cursor()
                curs2.execute(
                    "insert into gargleblatz (a) values (%s)", (314159,))
                self.assertEqual(curs2.rowcount, 1)
                conn2.commit()
                curs2.execute(selection_query)
                self.assertEqual(curs2.rowcount, 1)
            finally:
                # The inserting connection is finished with at this point.
                conn2.close()
            # create a new connection, it should see the addition
            conn3 = MySQLdb.connect(**self._auth)
            try:
                curs3 = conn3.cursor()
                curs3.execute(selection_query)
                self.assertEqual(curs3.rowcount, 1)
            finally:
                conn3.close()
            # now, does the already-open connection see it?
            curs.execute(selection_query)
            self.assertEqual(curs.rowcount, 1)
        finally:
            # clean up my litter
            try:
                curs.execute("delete from gargleblatz where a=314159")
                conn.commit()
            finally:
                conn.close()
229
230
class TestMonkeyPatch(tests.LimitedTestCase):
    """Runs the MySQLdb monkey-patching check script via tests.run_isolated."""

    @skip_unless(mysql_requirement)
    def test_monkey_patching(self):
        # Delegate the whole check to a standalone script.
        script_name = 'mysqldb_monkey_patch.py'
        tests.run_isolated(script_name)
235