1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2005-2021 Edgewall Software
4# All rights reserved.
5#
6# This software is licensed as described in the file COPYING, which
7# you should have received as part of this distribution. The terms
8# are also available at https://trac.edgewall.org/wiki/TracLicense.
9#
10# This software consists of voluntary contributions made by many
11# individuals. For the exact contribution history, see the revision
12# history and logs, available at https://trac.edgewall.org/log/.
13
14import copy
15import os
16import unittest
17
18from trac.config import ConfigurationError
19from trac.db.api import DatabaseManager, get_column_names, \
20                        parse_connection_uri
21from trac.db_default import (schema as default_schema,
22                             db_version as default_db_version)
23from trac.db.schema import Column, Table
24from trac.test import EnvironmentStub, get_dburi
25
26
class ParseConnectionStringTestCase(unittest.TestCase):
    """Tests for `parse_connection_uri`.

    Well-formed connection strings must parse to a `(scheme, args)`
    pair; malformed ones must raise `ConfigurationError`.
    """

    def _assert_parsed(self, scheme, args, uri):
        """Assert that `uri` parses to exactly `(scheme, args)`."""
        self.assertEqual((scheme, args), parse_connection_uri(uri))

    def _assert_rejected(self, uri):
        """Assert that parsing `uri` raises `ConfigurationError`."""
        with self.assertRaises(ConfigurationError):
            parse_connection_uri(uri)

    def test_sqlite_relative(self):
        # Default syntax for specifying DB path relative to the environment
        # directory
        self._assert_parsed('sqlite', {'path': 'db/trac.db'},
                            'sqlite:db/trac.db')

    def test_sqlite_absolute(self):
        absolute = {'path': '/var/db/trac.db'}
        # Standard syntax
        self._assert_parsed('sqlite', absolute, 'sqlite:///var/db/trac.db')
        # Legacy syntax
        self._assert_parsed('sqlite', absolute, 'sqlite:/var/db/trac.db')

    def test_sqlite_with_timeout_param(self):
        # Query string parameters are returned under the 'params' key,
        # with their values left as strings
        expected = {'path': 'db/trac.db', 'params': {'timeout': '10000'}}
        self._assert_parsed('sqlite', expected,
                            'sqlite:db/trac.db?timeout=10000')

    def test_sqlite_windows_path(self):
        # Drive letter written as "C|" while os.name reports 'nt'; the
        # real os.name is put back once the test has finished
        self.addCleanup(setattr, os, 'name', os.name)
        os.name = 'nt'
        self._assert_parsed('sqlite', {'path': 'C:/project/db/trac.db'},
                            'sqlite:C|/project/db/trac.db')

    def test_postgres_simple(self):
        self._assert_parsed('postgres',
                            {'host': 'localhost', 'path': '/trac'},
                            'postgres://localhost/trac')

    def test_postgres_with_port(self):
        expected = {'host': 'localhost', 'port': 9431, 'path': '/trac'}
        self._assert_parsed('postgres', expected,
                            'postgres://localhost:9431/trac')

    def test_postgres_with_creds(self):
        expected = {'user': 'john', 'password': 'letmein',
                    'host': 'localhost', 'port': 9431,
                    'path': '/trac'}
        self._assert_parsed('postgres', expected,
                            'postgres://john:letmein@localhost:9431/trac')

    def test_postgres_with_quoted_password(self):
        # The percent-encoded password is expected back decoded
        expected = {'user': 'john', 'password': ':@/',
                    'host': 'localhost', 'path': '/trac'}
        self._assert_parsed('postgres', expected,
                            'postgres://john:%3a%40%2f@localhost/trac')

    def test_mysql_simple(self):
        self._assert_parsed('mysql', {'host': 'localhost', 'path': '/trac'},
                            'mysql://localhost/trac')

    def test_mysql_with_creds(self):
        expected = {'user': 'john', 'password': 'letmein',
                    'host': 'localhost', 'port': 3306,
                    'path': '/trac'}
        self._assert_parsed('mysql', expected,
                            'mysql://john:letmein@localhost:3306/trac')

    def test_empty_string(self):
        self._assert_rejected('')

    def test_invalid_port(self):
        self._assert_rejected('postgres://localhost:42:42')

    def test_invalid_schema(self):
        self._assert_rejected('sqlitedb/trac.db')

    def test_no_path(self):
        self._assert_rejected('sqlite:')

    def test_invalid_query_string(self):
        self._assert_rejected('postgres://localhost/schema?name')
107
108
class StringsTestCase(unittest.TestCase):
    """Tests for storing, quoting and prefix-matching strings through
    the database connection, using rows of the `system` table.
    """

    def setUp(self):
        self.env = EnvironmentStub()

    def tearDown(self):
        self.env.reset_db()

    def test_insert_unicode(self):
        """A non-ASCII value is read back unchanged."""
        text = 'ünicöde'
        with self.env.db_transaction as db:
            system = db.quote('system')
            db("INSERT INTO " + system + " (name,value) VALUES (%s,%s)",
               ('test-unicode', text))
        rows = self.env.db_query(
            "SELECT value FROM " + system + " WHERE name='test-unicode'")
        self.assertEqual([(text,)], rows)

    def test_insert_empty(self):
        """The `empty` marker is read back as an empty string."""
        from trac.util.text import empty
        with self.env.db_transaction as db:
            system = db.quote('system')
            db("INSERT INTO " + system + " (name,value) VALUES (%s,%s)",
               ('test-empty', empty))
        rows = self.env.db_query(
            "SELECT value FROM " + system + " WHERE name='test-empty'")
        self.assertEqual([('',)], rows)

    def test_insert_markup(self):
        """`Markup` values inserted one at a time or in bulk are read
        back as their plain text.
        """
        from trac.util.html import Markup
        expected = {'test-markup': '<em>märkup</em>'}
        expected.update(('test-markup.%d' % i, '<em>märkup.%d</em>' % i)
                        for i in range(3))
        with self.env.db_transaction as db:
            system = db.quote('system')
            sql = "INSERT INTO {} (name,value) VALUES (%s,%s)".format(system)
            db(sql, ('test-markup', Markup('<em>märkup</em>')))
            numbered = [('test-markup.%d' % i,
                         Markup('<em>märkup.%d</em>' % i))
                        for i in range(3)]
            db.executemany(sql, numbered)
        stored = dict(self.env.db_query(
            "SELECT name, value FROM {} WHERE name LIKE %s".format(system),
            ('test-markup%',)))
        self.assertEqual(expected, stored)

    def test_quote(self):
        """A quoted identifier full of quote characters and backslashes
        comes back verbatim as the column name.
        """
        identifier = r'alpha\`\"\'\\beta``gamma""delta'
        with self.env.db_query as db:
            cursor = db.cursor()
            cursor.execute('SELECT 1 AS %s' % db.quote(identifier))
            self.assertEqual(identifier, get_column_names(cursor)[0])

    def test_quoted_id_with_percent(self):
        """Quoted identifiers containing `%` work with and without query
        arguments, both with and without a logger set on the cursor.
        """
        name = """%?`%s"%'%%"""

        def run_statements(with_logger):
            with self.env.db_query as db:
                cursor = db.cursor()
                if with_logger:
                    cursor.log = self.env.log

                # Without arguments
                cursor.execute('SELECT 1 AS ' + db.quote(name))
                self.assertEqual(name, get_column_names(cursor)[0])
                # With arguments
                cursor.execute('SELECT %s AS ' + db.quote(name), (42,))
                self.assertEqual(name, get_column_names(cursor)[0])
                # executemany() with an empty and a non-empty batch
                stmt = """
                    UPDATE {0} SET value=%s WHERE 1=(SELECT 0 AS {1})
                    """.format(db.quote('system'), db.quote(name))
                for batch in ([], [('42',), ('43',)]):
                    cursor.executemany(stmt, batch)

        for with_logger in (False, True):
            run_statements(with_logger)

    def test_prefix_match_case_sensitive(self):
        """Prefix matching only returns names whose case matches."""
        inserted = ['blahblah', 'BlahBlah', 'BLAHBLAH', 'BlähBlah',
                    'BlahBläh']
        with self.env.db_transaction as db:
            db.executemany("""
                INSERT INTO {0} (name,value) VALUES (%s,1)
                """.format(db.quote('system')),
                [(name,) for name in inserted])

        with self.env.db_query as db:
            sql = "SELECT name FROM {0} WHERE name {1}".format(
                db.quote('system'), db.prefix_match())
            matched = sorted(
                name for name, in db(sql, (db.prefix_match_value('Blah'),)))
        # Exactly the two names starting with 'Blah', in sorted order
        self.assertEqual(['BlahBlah', 'BlahBläh'], matched)

    def test_prefix_match_metachars(self):
        """Pattern metacharacters in a prefix are matched literally."""
        def do_query(prefix):
            with self.env.db_query as db:
                return [name for name, in db("""
                    SELECT name FROM {0} WHERE name {1} ORDER BY name
                    """.format(db.quote('system'), db.prefix_match()),
                    (db.prefix_match_value(prefix),))]

        everything = 'fo*ob?ar[fo]ob%ar_fo/obar'
        values = []
        for metachar in '*?[]%_/':
            values.append('foo' + metachar + 'bar')
            values.append('foo' + metachar + 'bar!')
        values.append(everything)
        with self.env.db_transaction as db:
            db.executemany("""
                INSERT INTO {0} (name,value) VALUES (%s,1)
                """.format(db.quote('system')),
                [(value,) for value in values])

        cases = [
            ('foo*', ['foo*bar', 'foo*bar!']),
            ('foo?', ['foo?bar', 'foo?bar!']),
            ('foo[', ['foo[bar', 'foo[bar!']),
            ('foo]', ['foo]bar', 'foo]bar!']),
            ('foo%', ['foo%bar', 'foo%bar!']),
            ('foo_', ['foo_bar', 'foo_bar!']),
            ('foo/', ['foo/bar', 'foo/bar!']),
            ('fo*', [everything]),
            (everything, [everything]),
        ]
        for prefix, expected in cases:
            self.assertEqual(expected, do_query(prefix))
226
227
class ConnectionTestCase(unittest.TestCase):
    """Tests for connection-level operations.

    `setUp` creates two extra tables, `HOURS` (upper-case identifiers)
    and `blog`, which `tearDown` drops again.
    """

    def setUp(self):
        self.env = EnvironmentStub()
        hours = Table('HOURS', key='ID')[
            Column('ID', auto_increment=True),
            Column('AUTHOR')
        ]
        blog = Table('blog', key='bid')[
            Column('bid', auto_increment=True),
            Column('author'),
            Column('comment')
        ]
        self.schema = [hours, blog]
        self.dbm = DatabaseManager(self.env)
        # Drop before creating; presumably guards against tables left
        # over from an earlier run
        self.dbm.drop_tables(self.schema)
        self.dbm.create_tables(self.schema)

    def tearDown(self):
        DatabaseManager(self.env).drop_tables(self.schema)
        self.env.reset_db()

    def _insert_blog_rows(self):
        """Insert two rows with `author` and `comment` into `blog`."""
        self.dbm.insert_into_tables([
            ('blog', ('author', 'comment'),
             (('author1', 'comment one'),
              ('author2', 'comment two'))),
        ])

    def _assert_bid_42_absent(self):
        """Fail if a `blog` row with `bid` 42 exists."""
        leftovers = list(self.env.db_query("""
                SELECT author FROM blog WHERE bid=42
                """))
        if leftovers:
            self.fail("Transaction was not rolled back")

    def test_drop_column(self):
        """Data is preserved when column is dropped."""
        self._insert_blog_rows()

        with self.env.db_transaction as db:
            db.drop_column('blog', 'comment')

        rows = list(self.env.db_query("SELECT * FROM blog"))
        self.assertEqual([(1, 'author1'), (2, 'author2')], rows[:2])

    def test_drop_column_no_exists(self):
        """Error is not raised when dropping non-existent column."""
        self._insert_blog_rows()

        with self.env.db_transaction as db:
            db.drop_column('blog', 'tags')

        rows = list(self.env.db_query("SELECT * FROM blog"))
        self.assertEqual([(1, 'author1', 'comment one'),
                          (2, 'author2', 'comment two')], rows[:2])

    def test_rollback_transaction_on_exception(self):
        """Transaction is rolled back when an exception occurs in the
        transaction context manager.
        """
        insert_sql = "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')"
        try:
            with self.env.db_transaction as db:
                # The second insert reuses the same key value
                for _ in range(2):
                    db(insert_sql)
        except self.env.db_exc.IntegrityError:
            pass

        self._assert_bid_42_absent()

    def test_rollback_nested_transaction_on_exception(self):
        """Transaction is rolled back when an exception occurs in the
        inner transaction context manager.
        """
        insert_sql = "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')"
        try:
            with self.env.db_transaction as outer:
                outer(insert_sql)
                with self.env.db_transaction as inner:
                    # Same key value as the outer insert
                    inner(insert_sql)
        except self.env.db_exc.IntegrityError:
            pass

        self._assert_bid_42_absent()

    def test_get_last_id(self):
        """`get_last_id` is correct both before and after `commit`."""
        insert_sql = "INSERT INTO report (author) VALUES ('anonymous')"
        with self.env.db_transaction as db:
            cursor = db.cursor()
            cursor.execute(insert_sql)
            # Row ID correct before...
            first_id = db.get_last_id(cursor, 'report')
            db.commit()
            cursor.execute(insert_sql)
            # ... and after commit()
            db.commit()
            second_id = db.get_last_id(cursor, 'report')

        self.assertNotEqual(0, first_id)
        self.assertEqual(first_id + 1, second_id)

    def test_update_sequence_default_column_name(self):
        """After `update_sequence`, the next generated `report.id`
        follows the explicitly inserted 42.
        """
        with self.env.db_transaction as db:
            db("INSERT INTO report (id, author) VALUES (42, 'anonymous')")
            cursor = db.cursor()
            db.update_sequence(cursor, 'report')

        self.env.db_transaction(
            "INSERT INTO report (author) VALUES ('next-id')")

        rows = self.env.db_query(
            "SELECT id FROM report WHERE author='next-id'")
        self.assertEqual(43, rows[0][0])

    def test_update_sequence_nondefault_column_name(self):
        """Same as above for `blog`, whose key column is named `bid`."""
        with self.env.db_transaction as db:
            cursor = db.cursor()
            cursor.execute(
                "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')")
            db.update_sequence(cursor, 'blog', 'bid')

        self.env.db_transaction(
            "INSERT INTO blog (author) VALUES ('next-id')")

        rows = self.env.db_query(
            "SELECT bid FROM blog WHERE author='next-id'")
        self.assertEqual(43, rows[0][0])

    def test_identifiers_need_quoting(self):
        """Test for regression described in comment:4:ticket:11512."""
        with self.env.db_transaction as db:
            table, id_col, author_col = [db.quote(identifier)
                                         for identifier
                                         in ('HOURS', 'ID', 'AUTHOR')]
            db("INSERT INTO %s (%s, %s) VALUES (42, 'anonymous')"
               % (table, id_col, author_col))
            cursor = db.cursor()
            db.update_sequence(cursor, 'HOURS', 'ID')

        with self.env.db_transaction as db:
            cursor = db.cursor()
            table, author_col = db.quote('HOURS'), db.quote('AUTHOR')
            cursor.execute(
                "INSERT INTO %s (%s) VALUES ('next-id')"
                % (table, author_col))
            last_id = db.get_last_id(cursor, 'HOURS', 'ID')

        self.assertEqual(43, last_id)

    def test_get_table_names(self):
        """The connection reports the default tables plus the extra
        ones, compared case-insensitively.
        """
        expected = sorted(table.name.lower()
                          for table in default_schema + self.schema)
        with self.env.db_query as db:
            # Some DB (e.g. MariaDB) normalize the table names to lower case
            actual = sorted(name.lower() for name in db.get_table_names())
            self.assertEqual(expected, actual)

    def test_get_column_names(self):
        """Column names are reported in schema order for every table."""
        with self.env.db_query as db:
            for table in default_schema + self.schema:
                expected = [column.name for column in table.columns]
                self.assertEqual(expected, db.get_column_names(table.name))

    def test_get_column_names_non_existent_table(self):
        """An unknown table raises `OperationalError` naming the table."""
        accepted_messages = ('Table "blah" not found',
                             'Table `blah` not found')
        with self.assertRaises(self.env.db_exc.OperationalError) as cm:
            self.dbm.get_column_names('blah')
        self.assertIn(str(cm.exception), accepted_messages)
396
397
class DatabaseManagerTestCase(unittest.TestCase):
    """Tests for `DatabaseManager` against an environment stub created
    with `default_data=True`.
    """

    def setUp(self):
        self.env = EnvironmentStub(default_data=True)
        self.dbm = DatabaseManager(self.env)

    def tearDown(self):
        self.env.reset_db()

    def test_destroy_db(self):
        """Database doesn't exist after calling destroy_db."""
        with self.env.db_query as db:
            db("SELECT name FROM " + db.quote('system'))
        self.assertIsNotNone(self.dbm._cnx_pool)
        self.dbm.destroy_db()
        self.assertIsNone(self.dbm._cnx_pool)  # No connection pool
        scheme, params = parse_connection_uri(get_dburi())
        if scheme != 'postgres' or params.get('schema', 'public') != 'public':
            self.assertFalse(self.dbm.db_exists())
        else:
            # For PostgreSQL with the 'public' schema only the absence
            # of tables is checked
            self.assertEqual([], self.dbm.get_table_names())

    def test_get_column_names(self):
        """Get column names for the default database."""
        for table in default_schema:
            column_names = [col.name for col in table.columns]
            self.assertEqual(column_names,
                             self.dbm.get_column_names(table.name))

    def test_get_default_database_version(self):
        """Get database version for the default entry named
        `database_version`.
        """
        self.assertEqual(default_db_version, self.dbm.get_database_version())

    def test_get_table_names(self):
        """Get table names for the default database."""
        self.assertEqual(sorted(table.name for table in default_schema),
                         sorted(self.dbm.get_table_names()))

    def test_has_table(self):
        """`has_table` returns exactly `True` or `False`."""
        self.assertIs(True, self.dbm.has_table('system'))
        self.assertIs(True, self.dbm.has_table('wiki'))
        self.assertIs(False, self.dbm.has_table('trac'))
        self.assertIs(False, self.dbm.has_table('blah.blah'))

    def test_no_database_version(self):
        """A false value is returned when the entry doesn't exist."""
        self.assertFalse(self.dbm.get_database_version('trac_plugin_version'))

    def test_set_default_database_version(self):
        """Set database version for the default entry named
        `database_version`.
        """
        new_db_version = default_db_version + 1
        # Build the expected message from the schema constants so that
        # the test keeps passing when `db_version` is bumped
        expected_log = [('INFO', 'Upgraded database_version from %d to %d'
                                 % (default_db_version, new_db_version))]
        try:
            self.dbm.set_database_version(new_db_version)
            self.assertEqual(new_db_version, self.dbm.get_database_version())
            self.assertEqual(expected_log, self.env.log_messages)
        finally:
            # Restore the previous version to avoid destroying the database
            # on teardown, even when one of the assertions above fails
            self.dbm.set_database_version(default_db_version)
        self.assertEqual(default_db_version, self.dbm.get_database_version())

    def test_set_get_plugin_database_version(self):
        """Get and set database version for an entry with an
        arbitrary name.
        """
        name = 'trac_plugin_version'
        db_ver = 1

        self.dbm.set_database_version(db_ver, name)
        self.assertEqual([], self.env.log_messages)
        self.assertEqual(db_ver, self.dbm.get_database_version(name))
        # DB update will be skipped when new value equals database version
        self.dbm.set_database_version(db_ver, name)
        self.assertEqual([], self.env.log_messages)

    def test_get_sequence_names(self):
        """Sequence names are only expected for PostgreSQL: one per
        default table with an auto-increment `id` column.
        """
        sequence_names = []
        if self.dbm.connection_uri.startswith('postgres'):
            sequence_names = sorted(
                table.name
                for table in default_schema
                for column in table.columns
                if column.name == 'id' and column.auto_increment)

        self.assertEqual(sequence_names, self.dbm.get_sequence_names())
487
488
class ModifyTableTestCase(unittest.TestCase):
    """Tests for `DatabaseManager.drop_columns` and `upgrade_tables`.

    `setUp` creates `table1`, `table2` and `table3`, and prepares
    `self.new_schema`: `table1` without `col2`, `table3` with an extra
    `col4`, and a new `table4`.
    """

    def setUp(self):
        self.env = EnvironmentStub()
        self.dbm = DatabaseManager(self.env)
        table1 = Table('table1', key='col1')[
            Column('col1', auto_increment=True),
            Column('col2'),
            Column('col3'),
        ]
        table2 = Table('table2', key='col1')[
            Column('col1'),
            Column('col2'),
        ]
        table3 = Table('table3', key='col2')[
            Column('col1'),
            Column('col2', type='int'),
            Column('col3')
        ]
        self.schema = [table1, table2, table3]
        self.dbm.create_tables(self.schema)

        # Modify copies so that the created schema stays untouched
        new_table1, new_table3 = copy.deepcopy([table1, table3])
        new_table1.remove_columns(('col2',))
        new_table3.columns.append(Column('col4'))
        new_table4 = Table('table4')[
            Column('col1'),
        ]
        self.new_schema = [new_table1, new_table3, new_table4]

    def tearDown(self):
        self.dbm.drop_tables(['table1', 'table2', 'table3', 'table4'])
        self.env.reset_db()

    def _insert_data(self):
        """Insert two rows into each of the three tables."""
        self.dbm.insert_into_tables([
            ('table1', ('col2', 'col3'),
             (('data1', 'data2'),
              ('data3', 'data4'))),
            ('table2', ('col1', 'col2'),
             (('data5', 'data6'),
              ('data7', 'data8'))),
            ('table3', ('col1', 'col2', 'col3'),
             (('data9', 10, 'data11'),
              ('data12', 13, 'data14'))),
        ])

    def _first_two_rows(self, table_name):
        """Return the first two rows of `SELECT *` on `table_name`."""
        return list(self.env.db_query("SELECT * FROM " + table_name))[:2]

    def test_drop_columns(self):
        """Data is preserved when column is dropped."""
        self._insert_data()

        self.dbm.drop_columns('table1', ('col2',))

        self.assertEqual(['col1', 'col3'], self.dbm.get_column_names('table1'))
        self.assertEqual([(1, 'data2'), (2, 'data4')],
                         self._first_two_rows('table1'))

    def test_drop_columns_multiple_columns(self):
        """Data is preserved when columns are dropped."""
        self._insert_data()

        self.dbm.drop_columns('table3', ('col1', 'col3'))

        self.assertEqual(['col2'], self.dbm.get_column_names('table3'))
        self.assertEqual([(10,), (13,)], self._first_two_rows('table3'))

    def test_drop_columns_non_existent_table(self):
        """An unknown table raises `OperationalError` naming the table."""
        accepted_messages = ('Table "blah" not found',
                             'Table `blah` not found')
        with self.assertRaises(self.env.db_exc.OperationalError) as cm:
            self.dbm.drop_columns('blah', ('col1',))
        self.assertIn(str(cm.exception), accepted_messages)

    def test_upgrade_tables_have_new_schema(self):
        """The upgraded tables have the new schema."""
        self.dbm.upgrade_tables(self.new_schema)

        for table in self.new_schema:
            expected = [column.name for column in table.columns]
            self.assertEqual(expected, self.dbm.get_column_names(table.name))

    def test_upgrade_tables_data_is_migrated(self):
        """The data is migrated to the upgraded tables."""
        self._insert_data()

        self.dbm.upgrade_tables(self.new_schema)
        # A row inserted after the upgrade must get the next id
        self.env.db_transaction("""
                INSERT INTO table1 (col3) VALUES ('data12')
                """)

        self.assertEqual([(1, 'data2'), (2, 'data4')],
                         self._first_two_rows('table1'))
        new_row = self.env.db_query("""
                SELECT col1 FROM table1 WHERE col3='data12'""")
        self.assertEqual(3, new_row[0][0])
        # table2 is not part of the new schema and keeps its rows
        self.assertEqual([('data5', 'data6'), ('data7', 'data8')],
                         self._first_two_rows('table2'))
        # The column added to table3 is NULL for the migrated rows
        self.assertEqual([('data9', 10, 'data11', None),
                          ('data12', 13, 'data14', None)],
                         self._first_two_rows('table3'))

    def test_upgrade_tables_no_common_columns(self):
        """Upgrading to a schema that shares no columns with the old
        one yields the new columns and no rows.
        """
        replacement = Table('table1', key='id')[
            Column('id', auto_increment=True),
            Column('name'),
            Column('value'),
        ]
        self.dbm.upgrade_tables([replacement])
        self.assertEqual(['id', 'name', 'value'],
                         self.dbm.get_column_names('table1'))
        self.assertEqual([], list(self.env.db_query("SELECT * FROM table1")))
607
608
def test_suite():
    """Return a suite containing every test case class of this module.

    The tests are loaded with `TestLoader.loadTestsFromTestCase`
    instead of `unittest.makeSuite`, which is deprecated since
    Python 3.11 and removed in Python 3.13.
    """
    load_tests_from = unittest.defaultTestLoader.loadTestsFromTestCase
    suite = unittest.TestSuite()
    for test_case in (ParseConnectionStringTestCase,
                      StringsTestCase,
                      ConnectionTestCase,
                      DatabaseManagerTestCase,
                      ModifyTableTestCase):
        suite.addTest(load_tests_from(test_case))
    return suite
617
618
# When executed as a script, run the suite built by test_suite().
if __name__ == '__main__':
    unittest.main(defaultTest='test_suite')
621