# -*- coding: utf-8 -*-
#
# Copyright (C) 2005-2021 Edgewall Software
# All rights reserved.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at https://trac.edgewall.org/wiki/TracLicense.
#
# This software consists of voluntary contributions made by many
# individuals. For the exact contribution history, see the revision
# history and logs, available at https://trac.edgewall.org/log/.

import copy
import os
import unittest

from trac.config import ConfigurationError
from trac.db.api import DatabaseManager, get_column_names, \
                        parse_connection_uri
from trac.db_default import (schema as default_schema,
                             db_version as default_db_version)
from trac.db.schema import Column, Table
from trac.test import EnvironmentStub, get_dburi


class ParseConnectionStringTestCase(unittest.TestCase):
    """Tests for `parse_connection_uri`."""

    def test_sqlite_relative(self):
        # Default syntax for specifying DB path relative to the environment
        # directory
        self.assertEqual(('sqlite', {'path': 'db/trac.db'}),
                         parse_connection_uri('sqlite:db/trac.db'))

    def test_sqlite_absolute(self):
        # Standard syntax
        self.assertEqual(('sqlite', {'path': '/var/db/trac.db'}),
                         parse_connection_uri('sqlite:///var/db/trac.db'))
        # Legacy syntax
        self.assertEqual(('sqlite', {'path': '/var/db/trac.db'}),
                         parse_connection_uri('sqlite:/var/db/trac.db'))

    def test_sqlite_with_timeout_param(self):
        # Query-string arguments are returned under the 'params' key
        self.assertEqual(('sqlite', {'path': 'db/trac.db',
                                     'params': {'timeout': '10000'}}),
                         parse_connection_uri('sqlite:db/trac.db?timeout=10000'))

    def test_sqlite_windows_path(self):
        # With os.name == 'nt', a drive letter written as `C|` is
        # returned as `C:`. os.name is patched and always restored.
        os_name = os.name
        try:
            os.name = 'nt'
            self.assertEqual(('sqlite', {'path': 'C:/project/db/trac.db'}),
                             parse_connection_uri('sqlite:C|/project/db/trac.db'))
        finally:
            os.name = os_name

    def test_postgres_simple(self):
        self.assertEqual(('postgres', {'host': 'localhost', 'path': '/trac'}),
                         parse_connection_uri('postgres://localhost/trac'))

    def test_postgres_with_port(self):
        self.assertEqual(('postgres', {'host': 'localhost', 'port': 9431,
                                       'path': '/trac'}),
                         parse_connection_uri('postgres://localhost:9431/trac'))

    def test_postgres_with_creds(self):
        self.assertEqual(('postgres', {'user': 'john', 'password': 'letmein',
                                       'host': 'localhost', 'port': 9431,
                                       'path': '/trac'}),
                 parse_connection_uri('postgres://john:letmein@localhost:9431/trac'))

    def test_postgres_with_quoted_password(self):
        self.assertEqual(('postgres', {'user': 'john', 'password': ':@/',
                                       'host': 'localhost', 'path': '/trac'}),
                 parse_connection_uri('postgres://john:%3a%40%2f@localhost/trac'))

    def test_mysql_simple(self):
        self.assertEqual(('mysql', {'host': 'localhost', 'path': '/trac'}),
                         parse_connection_uri('mysql://localhost/trac'))

    def test_mysql_with_creds(self):
        self.assertEqual(('mysql', {'user': 'john', 'password': 'letmein',
                                    'host': 'localhost', 'port': 3306,
                                    'path': '/trac'}),
                 parse_connection_uri('mysql://john:letmein@localhost:3306/trac'))

    def test_empty_string(self):
        self.assertRaises(ConfigurationError, parse_connection_uri, '')

    def test_invalid_port(self):
        self.assertRaises(ConfigurationError, parse_connection_uri,
                          'postgres://localhost:42:42')

    def test_invalid_schema(self):
        self.assertRaises(ConfigurationError, parse_connection_uri,
                          'sqlitedb/trac.db')

    def test_no_path(self):
        self.assertRaises(ConfigurationError, parse_connection_uri,
                          'sqlite:')

    def test_invalid_query_string(self):
        self.assertRaises(ConfigurationError, parse_connection_uri,
                          'postgres://localhost/schema?name')


class StringsTestCase(unittest.TestCase):
    """Tests for storing text values and quoting identifiers through the
    database connection API.
    """

    def setUp(self):
        self.env = EnvironmentStub()

    def tearDown(self):
        self.env.reset_db()

    def test_insert_unicode(self):
        with self.env.db_transaction as db:
            quoted = db.quote('system')
            db("INSERT INTO " + quoted + " (name,value) VALUES (%s,%s)",
               ('test-unicode', 'ünicöde'))
        self.assertEqual([('ünicöde',)], self.env.db_query(
            "SELECT value FROM " + quoted + " WHERE name='test-unicode'"))

    def test_insert_empty(self):
        from trac.util.text import empty
        with self.env.db_transaction as db:
            quoted = db.quote('system')
            db("INSERT INTO " + quoted + " (name,value) VALUES (%s,%s)",
               ('test-empty', empty))
        self.assertEqual([('',)], self.env.db_query(
            "SELECT value FROM " + quoted + " WHERE name='test-empty'"))

    def test_insert_markup(self):
        from trac.util.html import Markup
        with self.env.db_transaction as db:
            quoted = db.quote('system')
            query = "INSERT INTO {} (name,value) VALUES (%s,%s)".format(quoted)
            db(query, ('test-markup', Markup('<em>märkup</em>')))
            db.executemany(query, [('test-markup.%d' % i,
                                    Markup('<em>märkup.%d</em>' % i))
                                   for i in range(3)])
        values = dict(self.env.db_query(
            "SELECT name, value FROM {} WHERE name LIKE %s".format(quoted),
            ('test-markup%',)))
        self.assertEqual({'test-markup': '<em>märkup</em>',
                          'test-markup.0': '<em>märkup.0</em>',
                          'test-markup.1': '<em>märkup.1</em>',
                          'test-markup.2': '<em>märkup.2</em>'}, values)

    def test_quote(self):
        with self.env.db_query as db:
            cursor = db.cursor()
            cursor.execute('SELECT 1 AS %s' %
                           db.quote(r'alpha\`\"\'\\beta``gamma""delta'))
            self.assertEqual(r'alpha\`\"\'\\beta``gamma""delta',
                             get_column_names(cursor)[0])

    def test_quoted_id_with_percent(self):
        # A quoted identifier containing '%' must not be mistaken for a
        # parameter placeholder, with and without cursor logging enabled.
        name = """%?`%s"%'%%"""

        def test(logging=False):
            with self.env.db_query as db:
                cursor = db.cursor()
                if logging:
                    cursor.log = self.env.log

                cursor.execute('SELECT 1 AS ' + db.quote(name))
                self.assertEqual(name, get_column_names(cursor)[0])
                cursor.execute('SELECT %s AS ' + db.quote(name), (42,))
                self.assertEqual(name, get_column_names(cursor)[0])
                stmt = """
                    UPDATE {0} SET value=%s WHERE 1=(SELECT 0 AS {1})
                    """.format(db.quote('system'), db.quote(name))
                cursor.executemany(stmt, [])
                cursor.executemany(stmt, [('42',), ('43',)])

        test()
        test(True)

    def test_prefix_match_case_sensitive(self):
        with self.env.db_transaction as db:
            db.executemany("""
                INSERT INTO {0} (name,value) VALUES (%s,1)
                """.format(db.quote('system')),
                [('blahblah',), ('BlahBlah',), ('BLAHBLAH',), ('BlähBlah',),
                 ('BlahBläh',)])

        with self.env.db_query as db:
            names = sorted(name for name, in db(
                "SELECT name FROM {0} WHERE name {1}"
                .format(db.quote('system'), db.prefix_match()),
                (db.prefix_match_value('Blah'),)))
        # Only the rows starting with exactly 'Blah' (case-sensitive) match
        self.assertEqual(['BlahBlah', 'BlahBläh'], names)

    def test_prefix_match_metachars(self):
        def do_query(prefix):
            with self.env.db_query as db:
                return [name for name, in db("""
                    SELECT name FROM {0} WHERE name {1} ORDER BY name
                    """.format(db.quote('system'), db.prefix_match()),
                    (db.prefix_match_value(prefix),))]

        values = ['foo*bar', 'foo*bar!', 'foo?bar', 'foo?bar!',
                  'foo[bar', 'foo[bar!', 'foo]bar', 'foo]bar!',
                  'foo%bar', 'foo%bar!', 'foo_bar', 'foo_bar!',
                  'foo/bar', 'foo/bar!', 'fo*ob?ar[fo]ob%ar_fo/obar']
        with self.env.db_transaction as db:
            db.executemany("""
                INSERT INTO {0} (name,value) VALUES (%s,1)
                """.format(db.quote('system')),
                [(value,) for value in values])

        self.assertEqual(['foo*bar', 'foo*bar!'], do_query('foo*'))
        self.assertEqual(['foo?bar', 'foo?bar!'], do_query('foo?'))
        self.assertEqual(['foo[bar', 'foo[bar!'], do_query('foo['))
        self.assertEqual(['foo]bar', 'foo]bar!'], do_query('foo]'))
        self.assertEqual(['foo%bar', 'foo%bar!'], do_query('foo%'))
        self.assertEqual(['foo_bar', 'foo_bar!'], do_query('foo_'))
        self.assertEqual(['foo/bar', 'foo/bar!'], do_query('foo/'))
        self.assertEqual(['fo*ob?ar[fo]ob%ar_fo/obar'], do_query('fo*'))
        self.assertEqual(['fo*ob?ar[fo]ob%ar_fo/obar'],
                         do_query('fo*ob?ar[fo]ob%ar_fo/obar'))


class ConnectionTestCase(unittest.TestCase):
    """Tests for connection-level operations (column dropping,
    transaction rollback, sequences, schema introspection) using two
    extra tables created in `setUp`.
    """

    def setUp(self):
        self.env = EnvironmentStub()
        self.schema = [
            Table('HOURS', key='ID')[
                Column('ID', auto_increment=True),
                Column('AUTHOR')
            ],
            Table('blog', key='bid')[
                Column('bid', auto_increment=True),
                Column('author'),
                Column('comment')
            ]
        ]
        self.dbm = DatabaseManager(self.env)
        self.dbm.drop_tables(self.schema)
        self.dbm.create_tables(self.schema)

    def tearDown(self):
        DatabaseManager(self.env).drop_tables(self.schema)
        self.env.reset_db()

    def test_drop_column(self):
        """Data is preserved when column is dropped."""
        table_data = [
            ('blog', ('author', 'comment'),
             (('author1', 'comment one'),
              ('author2', 'comment two'))),
        ]
        self.dbm.insert_into_tables(table_data)

        with self.env.db_transaction as db:
            db.drop_column('blog', 'comment')

        data = list(self.env.db_query("SELECT * FROM blog"))
        self.assertEqual((1, 'author1'), data[0])
        self.assertEqual((2, 'author2'), data[1])

    def test_drop_column_no_exists(self):
        """Error is not raised when dropping non-existent column."""
        table_data = [
            ('blog', ('author', 'comment'),
             (('author1', 'comment one'),
              ('author2', 'comment two'))),
        ]
        self.dbm.insert_into_tables(table_data)

        with self.env.db_transaction as db:
            db.drop_column('blog', 'tags')

        data = list(self.env.db_query("SELECT * FROM blog"))
        self.assertEqual((1, 'author1', 'comment one'), data[0])
        self.assertEqual((2, 'author2', 'comment two'), data[1])

    def test_rollback_transaction_on_exception(self):
        """Transaction is rolled back when an exception occurs in the
        transaction context manager.
        """
        insert_sql = "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')"
        try:
            with self.env.db_transaction as db:
                db(insert_sql)
                # Duplicate primary key raises IntegrityError
                db(insert_sql)
        except self.env.db_exc.IntegrityError:
            pass

        for _, in self.env.db_query("""
                SELECT author FROM blog WHERE bid=42
                """):
            self.fail("Transaction was not rolled back")

    def test_rollback_nested_transaction_on_exception(self):
        """Transaction is rolled back when an exception occurs in the
        inner transaction context manager.
        """
        sql = "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')"
        try:
            with self.env.db_transaction as db_outer:
                db_outer(sql)
                with self.env.db_transaction as db_inner:
                    db_inner(sql)
        except self.env.db_exc.IntegrityError:
            pass

        for _, in self.env.db_query("""
                SELECT author FROM blog WHERE bid=42
                """):
            self.fail("Transaction was not rolled back")

    def test_get_last_id(self):
        q = "INSERT INTO report (author) VALUES ('anonymous')"
        with self.env.db_transaction as db:
            cursor = db.cursor()
            cursor.execute(q)
            # Row ID correct before...
            id1 = db.get_last_id(cursor, 'report')
            db.commit()
            cursor.execute(q)
            # ... and after commit()
            db.commit()
            id2 = db.get_last_id(cursor, 'report')

        self.assertNotEqual(0, id1)
        self.assertEqual(id1 + 1, id2)

    def test_update_sequence_default_column_name(self):
        with self.env.db_transaction as db:
            db("INSERT INTO report (id, author) VALUES (42, 'anonymous')")
            cursor = db.cursor()
            db.update_sequence(cursor, 'report')

        self.env.db_transaction(
            "INSERT INTO report (author) VALUES ('next-id')")

        self.assertEqual(43, self.env.db_query(
            "SELECT id FROM report WHERE author='next-id'")[0][0])

    def test_update_sequence_nondefault_column_name(self):
        with self.env.db_transaction as db:
            cursor = db.cursor()
            cursor.execute(
                "INSERT INTO blog (bid, author) VALUES (42, 'anonymous')")
            db.update_sequence(cursor, 'blog', 'bid')

        self.env.db_transaction(
            "INSERT INTO blog (author) VALUES ('next-id')")

        self.assertEqual(43, self.env.db_query(
            "SELECT bid FROM blog WHERE author='next-id'")[0][0])

    def test_identifiers_need_quoting(self):
        """Test for regression described in comment:4:ticket:11512."""
        with self.env.db_transaction as db:
            db("INSERT INTO %s (%s, %s) VALUES (42, 'anonymous')"
               % (db.quote('HOURS'), db.quote('ID'), db.quote('AUTHOR')))
            cursor = db.cursor()
            db.update_sequence(cursor, 'HOURS', 'ID')

        with self.env.db_transaction as db:
            cursor = db.cursor()
            cursor.execute(
                "INSERT INTO %s (%s) VALUES ('next-id')"
                % (db.quote('HOURS'), db.quote('AUTHOR')))
            last_id = db.get_last_id(cursor, 'HOURS', 'ID')

        self.assertEqual(43, last_id)

    def test_get_table_names(self):
        schema = default_schema + self.schema
        with self.env.db_query as db:
            # Some DB (e.g. MariaDB) normalize the table names to lower case
            self.assertEqual(
                sorted(table.name.lower() for table in schema),
                sorted(name.lower() for name in db.get_table_names()))

    def test_get_column_names(self):
        schema = default_schema + self.schema
        with self.env.db_query as db:
            for table in schema:
                column_names = [col.name for col in table.columns]
                self.assertEqual(column_names,
                                 db.get_column_names(table.name))

    def test_get_column_names_non_existent_table(self):
        with self.assertRaises(self.env.db_exc.OperationalError) as cm:
            self.dbm.get_column_names('blah')
        self.assertIn(str(cm.exception), ('Table "blah" not found',
                                          'Table `blah` not found'))


class DatabaseManagerTestCase(unittest.TestCase):
    """Tests for `DatabaseManager` against an environment created with
    `default_data=True`.
    """

    def setUp(self):
        self.env = EnvironmentStub(default_data=True)
        self.dbm = DatabaseManager(self.env)

    def tearDown(self):
        self.env.reset_db()

    def test_destroy_db(self):
        """Database doesn't exist after calling destroy_db."""
        with self.env.db_query as db:
            db("SELECT name FROM " + db.quote('system'))
        self.assertIsNotNone(self.dbm._cnx_pool)
        self.dbm.destroy_db()
        self.assertIsNone(self.dbm._cnx_pool)  # No connection pool
        scheme, params = parse_connection_uri(get_dburi())
        if scheme != 'postgres' or params.get('schema', 'public') != 'public':
            self.assertFalse(self.dbm.db_exists())
        else:
            # PostgreSQL using the 'public' schema: only check that no
            # tables remain (presumably the database itself is kept).
            self.assertEqual([], self.dbm.get_table_names())

    def test_get_column_names(self):
        """Get column names for the default database."""
        for table in default_schema:
            column_names = [col.name for col in table.columns]
            self.assertEqual(column_names,
                             self.dbm.get_column_names(table.name))

    def test_get_default_database_version(self):
        """Get database version for the default entry named
        `database_version`.
        """
        self.assertEqual(default_db_version, self.dbm.get_database_version())

    def test_get_table_names(self):
        """Get table names for the default database."""
        self.assertEqual(sorted(table.name for table in default_schema),
                         sorted(self.dbm.get_table_names()))

    def test_has_table(self):
        self.assertIs(True, self.dbm.has_table('system'))
        self.assertIs(True, self.dbm.has_table('wiki'))
        self.assertIs(False, self.dbm.has_table('trac'))
        self.assertIs(False, self.dbm.has_table('blah.blah'))

    def test_no_database_version(self):
        """False is returned when entry doesn't exist"""
        self.assertFalse(self.dbm.get_database_version('trac_plugin_version'))

    def test_set_default_database_version(self):
        """Set database version for the default entry named
        `database_version`.
        """
        new_db_version = default_db_version + 1
        try:
            self.dbm.set_database_version(new_db_version)
            self.assertEqual(new_db_version, self.dbm.get_database_version())
            # Derive the expected message from the versions rather than
            # hard-coding them, so the test survives schema bumps.
            self.assertEqual(
                [('INFO', 'Upgraded database_version from %s to %s'
                          % (default_db_version, new_db_version))],
                self.env.log_messages)
        finally:
            # Restore the previous version to avoid destroying the
            # database on teardown, even if an assertion above failed
            self.dbm.set_database_version(default_db_version)
        self.assertEqual(default_db_version, self.dbm.get_database_version())

    def test_set_get_plugin_database_version(self):
        """Get and set database version for an entry with an
        arbitrary name.
        """
        name = 'trac_plugin_version'
        db_ver = 1

        self.dbm.set_database_version(db_ver, name)
        self.assertEqual([], self.env.log_messages)
        self.assertEqual(db_ver, self.dbm.get_database_version(name))
        # DB update will be skipped when new value equals database version
        self.dbm.set_database_version(db_ver, name)
        self.assertEqual([], self.env.log_messages)

    def test_get_sequence_names(self):
        # Only PostgreSQL is expected to report sequences: one per table
        # whose auto-increment column is named 'id'.
        sequence_names = []
        if self.dbm.connection_uri.startswith('postgres'):
            for table in default_schema:
                for column in table.columns:
                    if column.name == 'id' and column.auto_increment:
                        sequence_names.append(table.name)
        sequence_names.sort()

        self.assertEqual(sequence_names, self.dbm.get_sequence_names())


class ModifyTableTestCase(unittest.TestCase):
    """Tests for `DatabaseManager.drop_columns` and `upgrade_tables`.

    `setUp` creates table1..table3 and builds `self.new_schema`:
    table1 without col2, table3 with an extra col4, and a new table4.
    """

    def setUp(self):
        self.env = EnvironmentStub()
        self.dbm = DatabaseManager(self.env)
        self.schema = [
            Table('table1', key='col1')[
                Column('col1', auto_increment=True),
                Column('col2'),
                Column('col3'),
            ],
            Table('table2', key='col1')[
                Column('col1'),
                Column('col2'),
            ],
            Table('table3', key='col2')[
                Column('col1'),
                Column('col2', type='int'),
                Column('col3')
            ]
        ]
        self.dbm.create_tables(self.schema)
        self.new_schema = copy.deepcopy([self.schema[0], self.schema[2]])
        self.new_schema[0].remove_columns(('col2',))
        self.new_schema[1].columns.append(Column('col4'))
        self.new_schema.append(
            Table('table4')[
                Column('col1'),
            ]
        )

    def tearDown(self):
        self.dbm.drop_tables(['table1', 'table2', 'table3', 'table4'])
        self.env.reset_db()

    def _insert_data(self):
        """Insert two rows into each of table1, table2 and table3."""
        table_data = [
            ('table1', ('col2', 'col3'),
             (('data1', 'data2'),
              ('data3', 'data4'))),
            ('table2', ('col1', 'col2'),
             (('data5', 'data6'),
              ('data7', 'data8'))),
            ('table3', ('col1', 'col2', 'col3'),
             (('data9', 10, 'data11'),
              ('data12', 13, 'data14'))),
        ]
        self.dbm.insert_into_tables(table_data)

    def test_drop_columns(self):
        """Data is preserved when column is dropped."""
        self._insert_data()

        self.dbm.drop_columns('table1', ('col2',))

        self.assertEqual(['col1', 'col3'], self.dbm.get_column_names('table1'))
        data = list(self.env.db_query("SELECT * FROM table1"))
        self.assertEqual((1, 'data2'), data[0])
        self.assertEqual((2, 'data4'), data[1])

    def test_drop_columns_multiple_columns(self):
        """Data is preserved when columns are dropped."""
        self._insert_data()

        self.dbm.drop_columns('table3', ('col1', 'col3'))

        self.assertEqual(['col2'], self.dbm.get_column_names('table3'))
        data = list(self.env.db_query("SELECT * FROM table3"))
        self.assertEqual((10,), data[0])
        self.assertEqual((13,), data[1])

    def test_drop_columns_non_existent_table(self):
        with self.assertRaises(self.env.db_exc.OperationalError) as cm:
            self.dbm.drop_columns('blah', ('col1',))
        self.assertIn(str(cm.exception), ('Table "blah" not found',
                                          'Table `blah` not found'))

    def test_upgrade_tables_have_new_schema(self):
        """The upgraded tables have the new schema."""
        self.dbm.upgrade_tables(self.new_schema)

        for table in self.new_schema:
            self.assertEqual([col.name for col in table.columns],
                             self.dbm.get_column_names(table.name))

    def test_upgrade_tables_data_is_migrated(self):
        """The data is migrated to the upgraded tables."""
        self._insert_data()

        self.dbm.upgrade_tables(self.new_schema)
        self.env.db_transaction("""
            INSERT INTO table1 (col3) VALUES ('data12')
            """)

        data = list(self.env.db_query("SELECT * FROM table1"))
        self.assertEqual((1, 'data2'), data[0])
        self.assertEqual((2, 'data4'), data[1])
        # The auto-increment counter continues after the migrated rows
        self.assertEqual(3, self.env.db_query("""
            SELECT col1 FROM table1 WHERE col3='data12'""")[0][0])
        data = list(self.env.db_query("SELECT * FROM table2"))
        self.assertEqual(('data5', 'data6'), data[0])
        self.assertEqual(('data7', 'data8'), data[1])
        data = list(self.env.db_query("SELECT * FROM table3"))
        self.assertEqual(('data9', 10, 'data11', None), data[0])
        self.assertEqual(('data12', 13, 'data14', None), data[1])

    def test_upgrade_tables_no_common_columns(self):
        schema = [
            Table('table1', key='id')[
                Column('id', auto_increment=True),
                Column('name'),
                Column('value'),
            ],
        ]
        self.dbm.upgrade_tables(schema)
        self.assertEqual(['id', 'name', 'value'],
                         self.dbm.get_column_names('table1'))
        self.assertEqual([], list(self.env.db_query("SELECT * FROM table1")))


def test_suite():
    """Return a suite containing every test case defined in this module.

    Uses `TestLoader.loadTestsFromTestCase` because `unittest.makeSuite`
    is deprecated since Python 3.11 and removed in 3.13.
    """
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    for case in (ParseConnectionStringTestCase,
                 StringsTestCase,
                 ConnectionTestCase,
                 DatabaseManagerTestCase,
                 ModifyTableTestCase):
        suite.addTest(loader.loadTestsFromTestCase(case))
    return suite


if __name__ == '__main__':
    unittest.main(defaultTest='test_suite')