import datetime
import os
import re

from dateutil import tz
import sqlalchemy as sa
from sqlalchemy import inspect

from alembic import autogenerate
from alembic import command
from alembic import util
from alembic.environment import EnvironmentContext
from alembic.operations import ops
from alembic.script import ScriptDirectory
from alembic.testing import assert_raises_message
from alembic.testing import assertions
from alembic.testing import eq_
from alembic.testing import is_
from alembic.testing import mock
from alembic.testing import ne_
from alembic.testing.env import _get_staging_directory
from alembic.testing.env import _multi_dir_testing_config
from alembic.testing.env import _multidb_testing_config
from alembic.testing.env import _no_sql_testing_config
from alembic.testing.env import _sqlite_file_db
from alembic.testing.env import _sqlite_testing_config
from alembic.testing.env import _testing_config
from alembic.testing.env import clear_staging_env
from alembic.testing.env import env_file_fixture
from alembic.testing.env import script_file_fixture
from alembic.testing.env import staging_env
from alembic.testing.env import three_rev_fixture
from alembic.testing.env import write_script
from alembic.testing.fixtures import TestBase
from alembic.util import CommandError

env, abc, def_ = None, None, None


class GeneralOrderedTests(TestBase):
    def setUp(self):
        global env
        env = staging_env()

    def tearDown(self):
        clear_staging_env()

    def test_steps(self):
        self._test_001_environment()
        self._test_002_rev_ids()
        self._test_003_api_methods_clean()
        self._test_004_rev()
        self._test_005_nextrev()
        self._test_006_from_clean_env()
        self._test_007_long_name()
        self._test_008_long_name_configurable()

    def _test_001_environment(self):
        assert_set = set(["env.py", "script.py.mako", "README"])
        eq_(assert_set.intersection(os.listdir(env.dir)), assert_set)

    def _test_002_rev_ids(self):
        global abc, def_
        abc = util.rev_id()
        def_ = util.rev_id()
        ne_(abc, def_)

    def _test_003_api_methods_clean(self):
        eq_(env.get_heads(), [])

        eq_(env.get_base(), None)

    def _test_004_rev(self):
        script = env.generate_revision(abc, "this is a message", refresh=True)
        eq_(script.doc, "this is a message")
        eq_(script.revision, abc)
        eq_(script.down_revision, None)
        assert os.access(
            os.path.join(env.dir, "versions", "%s_this_is_a_message.py" % abc),
            os.F_OK,
        )
        assert callable(script.module.upgrade)
        eq_(env.get_heads(), [abc])
        eq_(env.get_base(), abc)

    def _test_005_nextrev(self):
        script = env.generate_revision(
            def_, "this is the next rev", refresh=True
        )
        assert os.access(
            os.path.join(
                env.dir, "versions", "%s_this_is_the_next_rev.py" % def_
            ),
            os.F_OK,
        )
        eq_(script.revision, def_)
        eq_(script.down_revision, abc)
        eq_(env.get_revision(abc).nextrev, set([def_]))
        assert script.module.down_revision == abc
        assert callable(script.module.upgrade)
        assert callable(script.module.downgrade)
        eq_(env.get_heads(), [def_])
        eq_(env.get_base(), abc)

    def _test_006_from_clean_env(self):
        # test the environment so far with a
        # new ScriptDirectory instance.
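        # (create=False below re-uses the staging directory already
        # built by setUp rather than initializing a new one)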

        env = staging_env(create=False)
        abc_rev = env.get_revision(abc)
        def_rev = env.get_revision(def_)
        eq_(abc_rev.nextrev, set([def_]))
        eq_(abc_rev.revision, abc)
        eq_(def_rev.down_revision, abc)
        eq_(env.get_heads(), [def_])
        eq_(env.get_base(), abc)

    def _test_007_long_name(self):
        rid = util.rev_id()
        env.generate_revision(
            rid,
            "this is a really long name with "
            "lots of characters and also "
            "I'd like it to\nhave\nnewlines",
        )
        assert os.access(
            os.path.join(
                env.dir,
                "versions",
                "%s_this_is_a_really_long_name_with_lots_of_.py" % rid,
            ),
            os.F_OK,
        )

    def _test_008_long_name_configurable(self):
        env.truncate_slug_length = 60
        rid = util.rev_id()
        env.generate_revision(
            rid,
            "this is a really long name with "
            "lots of characters and also "
            "I'd like it to\nhave\nnewlines",
        )
        assert os.access(
            os.path.join(
                env.dir,
                "versions",
                "%s_this_is_a_really_long_name_with_lots_"
                "of_characters_and_also_.py" % rid,
            ),
            os.F_OK,
        )


class ScriptNamingTest(TestBase):
    @classmethod
    def setup_class(cls):
        _testing_config()

    @classmethod
    def teardown_class(cls):
        clear_staging_env()

    def test_args(self):
        script = ScriptDirectory(
            _get_staging_directory(),
            file_template="%(rev)s_%(slug)s_"
            "%(year)s_%(month)s_"
            "%(day)s_%(hour)s_"
            "%(minute)s_%(second)s",
        )
        create_date = datetime.datetime(2012, 7, 25, 15, 8, 5)
        eq_(
            script._rev_path(
                script.versions, "12345", "this is a message", create_date
            ),
            os.path.abspath(
                "%s/versions/12345_this_is_a_"
                "message_2012_7_25_15_8_5.py" % _get_staging_directory()
            ),
        )

    def _test_tz(self, timezone_arg, given, expected):
        script = ScriptDirectory(
            _get_staging_directory(),
            file_template="%(rev)s_%(slug)s_"
            "%(year)s_%(month)s_"
            "%(day)s_%(hour)s_"
            "%(minute)s_%(second)s",
            timezone=timezone_arg,
        )

        with mock.patch(
            "alembic.script.base.datetime",
            mock.Mock(
                datetime=mock.Mock(utcnow=lambda: given, now=lambda: given)
            ),
        ):
            create_date = script._generate_create_date()
        eq_(create_date, expected)

    def test_custom_tz(self):
        self._test_tz(
            "EST5EDT",
            datetime.datetime(2012, 7, 25, 15, 8, 5),
            datetime.datetime(
                2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
            ),
        )

    def test_custom_tz_lowercase(self):
        self._test_tz(
            "est5edt",
            datetime.datetime(2012, 7, 25, 15, 8, 5),
            datetime.datetime(
                2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
            ),
        )

    def test_custom_tz_utc(self):
        self._test_tz(
            "utc",
            datetime.datetime(2012, 7, 25, 15, 8, 5),
            datetime.datetime(2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz("UTC")),
        )

    def test_custom_tzdata_tz(self):
        self._test_tz(
            "Europe/Berlin",
            datetime.datetime(2012, 7, 25, 15, 8, 5),
            datetime.datetime(
                2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz("Europe/Berlin")
            ),
        )

    def test_default_tz(self):
        self._test_tz(
            None,
            datetime.datetime(2012, 7, 25, 15, 8, 5),
            datetime.datetime(2012, 7, 25, 15, 8, 5),
        )

    def test_tz_cant_locate(self):
        assert_raises_message(
            CommandError,
            "Can't locate timezone: fake",
            self._test_tz,
            "fake",
            datetime.datetime(2012, 7, 25, 15, 8, 5),
            datetime.datetime(2012, 7, 25, 15, 8, 5),
        )


class RevisionCommandTest(TestBase):
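    """test the revision command against a three-revision fixture"""
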
    def setUp(self):
        self.env = staging_env()
        self.cfg = _sqlite_testing_config()
        self.a, self.b, self.c = three_rev_fixture(self.cfg)

    def tearDown(self):
        clear_staging_env()

    def test_create_script_basic(self):
        rev = command.revision(self.cfg, message="some message")
        script = ScriptDirectory.from_config(self.cfg)
        rev = script.get_revision(rev.revision)
        eq_(rev.down_revision, self.c)
        assert "some message" in rev.doc

    def test_create_script_splice(self):
        rev = command.revision(
            self.cfg, message="some message", head=self.b, splice=True
        )
        script = ScriptDirectory.from_config(self.cfg)
        rev = script.get_revision(rev.revision)
        eq_(rev.down_revision, self.b)
        assert "some message" in rev.doc
        eq_(set(script.get_heads()), set([rev.revision, self.c]))

    def test_create_script_missing_splice(self):
        assert_raises_message(
            util.CommandError,
            "Revision %s is not a head revision; please specify --splice "
            "to create a new branch from this revision" % self.b,
            command.revision,
            self.cfg,
            message="some message",
            head=self.b,
        )

    def test_illegal_revision_chars(self):
        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '-' not allowed in "
            "revision identifier 'no-dashes'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no-dashes",
        )

        assert not os.path.exists(
            os.path.join(self.env.dir, "versions", "no-dashes_some_message.py")
        )

        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '@' not allowed in "
            "revision identifier 'no@atsigns'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no@atsigns",
        )

        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '-, @' not allowed in revision "
            "identifier 'no@atsigns-ordashes'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no@atsigns-ordashes",
        )

        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '\+' not allowed in revision "
            r"identifier 'no\+plussignseither'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no+plussignseither",
        )

    def test_create_script_branches(self):
        rev = command.revision(
            self.cfg, message="some message", branch_label="foobar"
        )
        script = ScriptDirectory.from_config(self.cfg)
        rev = script.get_revision(rev.revision)
        eq_(script.get_revision("foobar"), rev)

    def test_create_script_branches_old_template(self):
        script = ScriptDirectory.from_config(self.cfg)
        with open(os.path.join(script.dir, "script.py.mako"), "w") as file_:
            file_.write(
                "<%text>#</%text> ${message}\n"
                "revision = ${repr(up_revision)}\n"
                "down_revision = ${repr(down_revision)}\n\n"
                "def upgrade():\n"
                "    ${upgrades if upgrades else 'pass'}\n\n"
                "def downgrade():\n"
                "    ${downgrade if downgrades else 'pass'}\n\n"
            )

        # works OK if no branch names
        command.revision(self.cfg, message="some message")

        assert_raises_message(
            util.CommandError,
            r"Version \w+ specified branch_labels foobar, "
            r"however the migration file .+?\b does not have them; have you "
            "upgraded your script.py.mako to include the 'branch_labels' "
            r"section\?",
            command.revision,
            self.cfg,
            message="some message",
            branch_label="foobar",
        )


class CustomizeRevisionTest(TestBase):
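    """test the process_revision_directives hook with multiple
    version directories"""
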
    def setUp(self):
        self.env = staging_env()
        self.cfg = _multi_dir_testing_config()
        self.cfg.set_main_option("revision_environment", "true")

        script = ScriptDirectory.from_config(self.cfg)
        self.model1 = util.rev_id()
        self.model2 = util.rev_id()
        self.model3 = util.rev_id()
        for model, name in [
            (self.model1, "model1"),
            (self.model2, "model2"),
            (self.model3, "model3"),
        ]:
            script.generate_revision(
                model,
                name,
                refresh=True,
                version_path=os.path.join(_get_staging_directory(), name),
                head="base",
            )

            write_script(
                script,
                model,
                """\
"%s"
revision = '%s'
down_revision = None
branch_labels = ['%s']

from alembic import op


def upgrade():
    pass


def downgrade():
    pass

"""
                % (name, model, name),
            )

    def tearDown(self):
        clear_staging_env()

    def _env_fixture(self, fn, target_metadata):
        self.engine = engine = _sqlite_file_db()

        def run_env(self):
            from alembic import context

            with engine.connect() as connection:
                context.configure(
                    connection=connection,
                    target_metadata=target_metadata,
                    process_revision_directives=fn,
                )
                with context.begin_transaction():
                    context.run_migrations()

        return mock.patch(
            "alembic.script.base.ScriptDirectory.run_env", run_env
        )

    def test_new_locations_no_autogen(self):
        m = sa.MetaData()

        def process_revision_directives(context, rev, generate_revisions):
            generate_revisions[:] = [
                ops.MigrationScript(
                    util.rev_id(),
                    ops.UpgradeOps(),
                    ops.DowngradeOps(),
                    version_path=os.path.join(
                        _get_staging_directory(), "model1"
                    ),
                    head="model1@head",
                ),
                ops.MigrationScript(
                    util.rev_id(),
                    ops.UpgradeOps(),
                    ops.DowngradeOps(),
                    version_path=os.path.join(
                        _get_staging_directory(), "model2"
                    ),
                    head="model2@head",
                ),
                ops.MigrationScript(
                    util.rev_id(),
                    ops.UpgradeOps(),
                    ops.DowngradeOps(),
                    version_path=os.path.join(
                        _get_staging_directory(), "model3"
                    ),
                    head="model3@head",
                ),
            ]

        with self._env_fixture(process_revision_directives, m):
            revs = command.revision(self.cfg, message="some message")

        script = ScriptDirectory.from_config(self.cfg)

        for rev, model in [
            (revs[0], "model1"),
            (revs[1], "model2"),
            (revs[2], "model3"),
        ]:
            rev_script = script.get_revision(rev.revision)
            eq_(
                rev_script.path,
                os.path.abspath(
                    os.path.join(
                        _get_staging_directory(),
                        model,
                        "%s_.py" % (rev_script.revision,),
                    )
                ),
            )
            assert os.path.exists(rev_script.path)

    def test_renders_added_directives_no_autogen(self):
        m = sa.MetaData()

        def process_revision_directives(context, rev, generate_revisions):
            generate_revisions[0].upgrade_ops.ops.append(
                ops.CreateIndexOp("some_index", "some_table", ["a", "b"])
            )

        with self._env_fixture(process_revision_directives, m):
            rev = command.revision(
                self.cfg, message="some message", head="model1@head", sql=True
            )

        with mock.patch.object(rev.module, "op") as op_mock:
            rev.module.upgrade()
        eq_(
            op_mock.mock_calls,
            [
                mock.call.create_index(
                    "some_index", "some_table", ["a", "b"], unique=False
                )
            ],
        )

    def test_autogen(self):
        m = sa.MetaData()
        sa.Table("t", m, sa.Column("x", sa.Integer))

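        # split the autogenerated diff across two version directories
        # via the process_revision_directives hook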
        def process_revision_directives(context, rev, generate_revisions):
            existing_upgrades = generate_revisions[0].upgrade_ops
            existing_downgrades = generate_revisions[0].downgrade_ops

            # model1 will run the upgrades, e.g. create the table,
            # model2 will run the downgrades as upgrades, e.g. drop
            # the table again

            generate_revisions[:] = [
                ops.MigrationScript(
                    util.rev_id(),
                    existing_upgrades,
                    ops.DowngradeOps(),
                    version_path=os.path.join(
                        _get_staging_directory(), "model1"
                    ),
                    head="model1@head",
                ),
                ops.MigrationScript(
                    util.rev_id(),
                    ops.UpgradeOps(ops=existing_downgrades.ops),
                    ops.DowngradeOps(),
                    version_path=os.path.join(
                        _get_staging_directory(), "model2"
                    ),
                    head="model2@head",
                ),
            ]

        with self._env_fixture(process_revision_directives, m):
            command.upgrade(self.cfg, "heads")

            eq_(inspect(self.engine).get_table_names(), ["alembic_version"])

            command.revision(
                self.cfg, message="some message", autogenerate=True
            )

            command.upgrade(self.cfg, "model1@head")

            eq_(
                inspect(self.engine).get_table_names(),
                ["alembic_version", "t"],
            )

            command.upgrade(self.cfg, "model2@head")

            eq_(inspect(self.engine).get_table_names(), ["alembic_version"])

    def test_programmatic_command_option(self):
        def process_revision_directives(context, rev, generate_revisions):
            generate_revisions[0].message = "test programatic"
            generate_revisions[0].upgrade_ops = ops.UpgradeOps(
                ops=[
                    ops.CreateTableOp(
                        "test_table",
                        [
                            sa.Column("id", sa.Integer(), primary_key=True),
                            sa.Column("name", sa.String(50), nullable=False),
                        ],
                    )
                ]
            )
            generate_revisions[0].downgrade_ops = ops.DowngradeOps(
                ops=[ops.DropTableOp("test_table")]
            )

        with self._env_fixture(None, None):
            rev = command.revision(
                self.cfg,
                head="model1@head",
                process_revision_directives=process_revision_directives,
            )

        with open(rev.path) as handle:
            result = handle.read()
        assert (
            (
                """
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('test_table',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=50), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###
"""
            )
            in result
        )


class ScriptAccessorTest(TestBase):
    def test_upgrade_downgrade_ops_list_accessors(self):
        u1 = ops.UpgradeOps(ops=[])
        d1 = ops.DowngradeOps(ops=[])
        m1 = ops.MigrationScript("somerev", u1, d1)
        is_(m1.upgrade_ops, u1)
        is_(m1.downgrade_ops, d1)
        u2 = ops.UpgradeOps(ops=[])
        d2 = ops.DowngradeOps(ops=[])
        m1._upgrade_ops.append(u2)
        m1._downgrade_ops.append(d2)

        assert_raises_message(
            ValueError,
            "This MigrationScript instance has a multiple-entry list for "
            "UpgradeOps; please use the upgrade_ops_list attribute.",
            getattr,
            m1,
            "upgrade_ops",
        )
        assert_raises_message(
            ValueError,
            "This MigrationScript instance has a multiple-entry list for "
            "DowngradeOps; please use the downgrade_ops_list attribute.",
            getattr,
            m1,
            "downgrade_ops",
        )
        eq_(m1.upgrade_ops_list, [u1, u2])
        eq_(m1.downgrade_ops_list, [d1, d2])


class ImportsTest(TestBase):
    def setUp(self):
        self.env = staging_env()
        self.cfg = _sqlite_testing_config()

    def tearDown(self):
        clear_staging_env()

    def _env_fixture(self, target_metadata, **kw):
        self.engine = engine = _sqlite_file_db()

        def run_env(self):
            from alembic import context

            with engine.connect() as connection:
                context.configure(
                    connection=connection,
                    target_metadata=target_metadata,
                    **kw
                )
                with context.begin_transaction():
                    context.run_migrations()

        return mock.patch(
            "alembic.script.base.ScriptDirectory.run_env", run_env
        )

    def test_imports_in_script(self):
        from sqlalchemy import MetaData, Table, Column
        from sqlalchemy.dialects.mysql import VARCHAR

        type_ = VARCHAR(20, charset="utf8", national=True)

        m = MetaData()

        Table("t", m, Column("x", type_))

        def process_revision_directives(context, rev, generate_revisions):
            generate_revisions[0].imports.add(
                "from sqlalchemy.dialects.mysql import TINYINT"
            )

        with self._env_fixture(
            m, process_revision_directives=process_revision_directives
        ):
            rev = command.revision(
                self.cfg, message="some message", autogenerate=True
            )

        with open(rev.path) as file_:
            contents = file_.read()
        assert "from sqlalchemy.dialects import mysql" in contents
        assert "from sqlalchemy.dialects.mysql import TINYINT" in contents


class MultiContextTest(TestBase):
    """test the multidb template for autogenerate front-to-back"""

    def setUp(self):
        self.engine1 = _sqlite_file_db(tempname="eng1.db")
        self.engine2 = _sqlite_file_db(tempname="eng2.db")
        self.engine3 = _sqlite_file_db(tempname="eng3.db")

        self.env = staging_env(template="multidb")
        self.cfg = _multidb_testing_config(
            {
                "engine1": self.engine1,
                "engine2": self.engine2,
                "engine3": self.engine3,
            }
        )

    def _write_metadata(self, meta):
        path = os.path.join(_get_staging_directory(), "scripts", "env.py")
        with open(path) as env_:
            existing_env = env_.read()
        existing_env = existing_env.replace("target_metadata = {}", meta)
        with open(path, "w") as env_:
            env_.write(existing_env)

    def tearDown(self):
        clear_staging_env()

    def test_autogen(self):
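        # write one MetaData per engine into env.py, then check that a
        # single autogenerated revision gets per-engine upgrade/downgrade
        # functions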
        self._write_metadata(
            """
import sqlalchemy as sa

m1 = sa.MetaData()
m2 = sa.MetaData()
m3 = sa.MetaData()
target_metadata = {"engine1": m1, "engine2": m2, "engine3": m3}

sa.Table('e1t1', m1, sa.Column('x', sa.Integer))
sa.Table('e2t1', m2, sa.Column('y', sa.Integer))
sa.Table('e3t1', m3, sa.Column('z', sa.Integer))

"""
        )

        rev = command.revision(
            self.cfg, message="some message", autogenerate=True
        )
        with mock.patch.object(rev.module, "op") as op_mock:
            rev.module.upgrade_engine1()
            eq_(
                op_mock.mock_calls[-1],
                mock.call.create_table("e1t1", mock.ANY),
            )
            rev.module.upgrade_engine2()
            eq_(
                op_mock.mock_calls[-1],
                mock.call.create_table("e2t1", mock.ANY),
            )
            rev.module.upgrade_engine3()
            eq_(
                op_mock.mock_calls[-1],
                mock.call.create_table("e3t1", mock.ANY),
            )
            rev.module.downgrade_engine1()
            eq_(op_mock.mock_calls[-1], mock.call.drop_table("e1t1"))
            rev.module.downgrade_engine2()
            eq_(op_mock.mock_calls[-1], mock.call.drop_table("e2t1"))
            rev.module.downgrade_engine3()
            eq_(op_mock.mock_calls[-1], mock.call.drop_table("e3t1"))


class RewriterTest(TestBase):
    def test_all_traverse(self):
        writer = autogenerate.Rewriter()

        mocker = mock.Mock(side_effect=lambda context, revision, op: op)
        writer.rewrites(ops.MigrateOperation)(mocker)

        addcolop = ops.AddColumnOp("t1", sa.Column("x", sa.Integer()))

        directives = [
            ops.MigrationScript(
                util.rev_id(),
                ops.UpgradeOps(ops=[ops.ModifyTableOps("t1", ops=[addcolop])]),
                ops.DowngradeOps(ops=[]),
            )
        ]

        ctx, rev = mock.Mock(), mock.Mock()
        writer(ctx, rev, directives)
        eq_(
            mocker.mock_calls,
            [
                mock.call(ctx, rev, directives[0]),
                mock.call(ctx, rev, directives[0].upgrade_ops),
                mock.call(ctx, rev, directives[0].upgrade_ops.ops[0]),
                mock.call(ctx, rev, addcolop),
                mock.call(ctx, rev, directives[0].downgrade_ops),
            ],
        )

    def test_double_migrate_table(self):
        writer = autogenerate.Rewriter()

        idx_ops = []

        @writer.rewrites(ops.ModifyTableOps)
        def second_table(context, revision, op):
            return [
                op,
                ops.ModifyTableOps(
                    "t2",
                    ops=[ops.AddColumnOp("t2", sa.Column("x", sa.Integer()))],
                ),
            ]

        @writer.rewrites(ops.AddColumnOp)
        def add_column(context, revision, op):
            idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
            idx_ops.append(idx_op)
            return [op, idx_op]

        directives = [
            ops.MigrationScript(
                util.rev_id(),
                ops.UpgradeOps(
                    ops=[
                        ops.ModifyTableOps(
                            "t1",
                            ops=[
                                ops.AddColumnOp(
                                    "t1", sa.Column("x", sa.Integer())
                                )
                            ],
                        )
                    ]
                ),
                ops.DowngradeOps(ops=[]),
            )
        ]

        ctx, rev = mock.Mock(), mock.Mock()
        writer(ctx, rev, directives)
        eq_(
            [d.table_name for d in directives[0].upgrade_ops.ops], ["t1", "t2"]
        )
        is_(directives[0].upgrade_ops.ops[0].ops[1], idx_ops[0])
        is_(directives[0].upgrade_ops.ops[1].ops[1], idx_ops[1])

    def test_chained_ops(self):
        writer1 = autogenerate.Rewriter()
        writer2 = autogenerate.Rewriter()

        @writer1.rewrites(ops.AddColumnOp)
        def add_column_nullable(context, revision, op):
            if op.column.nullable:
                return op
            else:
                op.column.nullable = True
                return [
                    op,
                    ops.AlterColumnOp(
                        op.table_name,
                        op.column.name,
                        modify_nullable=False,
                        existing_type=op.column.type,
                    ),
                ]

        @writer2.rewrites(ops.AddColumnOp)
        def add_column_idx(context, revision, op):
            idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
            return [op, idx_op]

        directives = [
            ops.MigrationScript(
                util.rev_id(),
                ops.UpgradeOps(
                    ops=[
                        ops.ModifyTableOps(
                            "t1",
                            ops=[
                                ops.AddColumnOp(
                                    "t1",
                                    sa.Column(
                                        "x", sa.Integer(), nullable=False
                                    ),
                                )
                            ],
                        )
                    ]
                ),
                ops.DowngradeOps(ops=[]),
            )
        ]

        ctx, rev = mock.Mock(), mock.Mock()
        writer1.chain(writer2)(ctx, rev, directives)

        eq_(
            autogenerate.render_python_code(directives[0].upgrade_ops),
            "# ### commands auto generated by Alembic - please adjust! ###\n"
            "    op.add_column('t1', "
            "sa.Column('x', sa.Integer(), nullable=True))\n"
            "    op.create_index('ixt', 't1', ['x'], unique=False)\n"
            "    op.alter_column('t1', 'x',\n"
            "               existing_type=sa.Integer(),\n"
            "               nullable=False)\n"
            "    # ### end Alembic commands ###",
        )

    def test_no_needless_pass(self):
        writer1 = autogenerate.Rewriter()

        @writer1.rewrites(ops.AlterColumnOp)
        def rewrite_alter_column(context, revision, op):
            return []

        directives = [
            ops.MigrationScript(
                util.rev_id(),
                ops.UpgradeOps(
                    ops=[
                        ops.ModifyTableOps(
                            "t1",
                            ops=[
                                ops.AlterColumnOp(
                                    "foo",
                                    "bar",
                                    modify_nullable=False,
                                    existing_type=sa.Integer(),
                                ),
                                ops.AlterColumnOp(
                                    "foo",
                                    "bar",
                                    modify_nullable=False,
                                    existing_type=sa.Integer(),
                                ),
                            ],
                        ),
                        ops.ModifyTableOps(
                            "t1",
                            ops=[
                                ops.AlterColumnOp(
                                    "foo",
                                    "bar",
                                    modify_nullable=False,
                                    existing_type=sa.Integer(),
                                )
                            ],
                        ),
                    ]
                ),
                ops.DowngradeOps(ops=[]),
            )
        ]
        ctx, rev = mock.Mock(), mock.Mock()
        writer1(ctx, rev, directives)

        eq_(
            autogenerate.render_python_code(directives[0].upgrade_ops),
            "# ### commands auto generated by Alembic - please adjust! ###\n"
            "    pass\n"
            "    # ### end Alembic commands ###",
        )

    def test_multiple_passes_with_mutations(self):
        writer1 = autogenerate.Rewriter()

        @writer1.rewrites(ops.CreateTableOp)
        def rewrite_alter_column(context, revision, op):
            op.table_name += "_pass"
            return op

        directives = [
            ops.MigrationScript(
                util.rev_id(),
                ops.UpgradeOps(
                    ops=[
                        ops.CreateTableOp(
                            "test_table",
                            [sa.Column("id", sa.Integer(), primary_key=True)],
                        )
                    ]
                ),
                ops.DowngradeOps(ops=[]),
            )
        ]
        ctx, rev = mock.Mock(), mock.Mock()
        writer1(ctx, rev, directives)

        directives[0].upgrade_ops_list.extend(
            [
                ops.UpgradeOps(
                    ops=[
                        ops.CreateTableOp(
                            "another_test_table",
                            [sa.Column("id", sa.Integer(), primary_key=True)],
                        )
                    ]
                ),
                ops.UpgradeOps(
                    ops=[
                        ops.CreateTableOp(
                            "third_test_table",
                            [sa.Column("id", sa.Integer(), primary_key=True)],
                        )
                    ]
                ),
            ]
        )

        writer1(ctx, rev, directives)

        eq_(
            autogenerate.render_python_code(directives[0].upgrade_ops_list[0]),
###\n" 1024 " op.create_table('test_table_pass',\n" 1025 " sa.Column('id', sa.Integer(), nullable=False),\n" 1026 " sa.PrimaryKeyConstraint('id')\n" 1027 " )\n" 1028 " # ### end Alembic commands ###", 1029 ) 1030 eq_( 1031 autogenerate.render_python_code(directives[0].upgrade_ops_list[1]), 1032 "# ### commands auto generated by Alembic - please adjust! ###\n" 1033 " op.create_table('another_test_table_pass',\n" 1034 " sa.Column('id', sa.Integer(), nullable=False),\n" 1035 " sa.PrimaryKeyConstraint('id')\n" 1036 " )\n" 1037 " # ### end Alembic commands ###", 1038 ) 1039 eq_( 1040 autogenerate.render_python_code(directives[0].upgrade_ops_list[2]), 1041 "# ### commands auto generated by Alembic - please adjust! ###\n" 1042 " op.create_table('third_test_table_pass',\n" 1043 " sa.Column('id', sa.Integer(), nullable=False),\n" 1044 " sa.PrimaryKeyConstraint('id')\n" 1045 " )\n" 1046 " # ### end Alembic commands ###", 1047 ) 1048 1049 1050class MultiDirRevisionCommandTest(TestBase): 1051 def setUp(self): 1052 self.env = staging_env() 1053 self.cfg = _multi_dir_testing_config() 1054 1055 def tearDown(self): 1056 clear_staging_env() 1057 1058 def test_multiple_dir_no_bases(self): 1059 assert_raises_message( 1060 util.CommandError, 1061 "Multiple version locations present, please specify " 1062 "--version-path", 1063 command.revision, 1064 self.cfg, 1065 message="some message", 1066 ) 1067 1068 def test_multiple_dir_no_bases_invalid_version_path(self): 1069 assert_raises_message( 1070 util.CommandError, 1071 "Path foo/bar/ is not represented in current version locations", 1072 command.revision, 1073 self.cfg, 1074 message="x", 1075 version_path=os.path.join("foo/bar/"), 1076 ) 1077 1078 def test_multiple_dir_no_bases_version_path(self): 1079 script = command.revision( 1080 self.cfg, 1081 message="x", 1082 version_path=os.path.join(_get_staging_directory(), "model1"), 1083 ) 1084 assert os.access(script.path, os.F_OK) 1085 1086 def test_multiple_dir_chooses_base(self): 1087 command.revision( 1088 self.cfg, 1089 message="x", 1090 head="base", 1091 version_path=os.path.join(_get_staging_directory(), "model1"), 1092 ) 1093 1094 script2 = command.revision( 1095 self.cfg, 1096 message="y", 1097 head="base", 1098 version_path=os.path.join(_get_staging_directory(), "model2"), 1099 ) 1100 1101 script3 = command.revision( 1102 self.cfg, message="y2", head=script2.revision 1103 ) 1104 1105 eq_( 1106 os.path.dirname(script3.path), 1107 os.path.abspath(os.path.join(_get_staging_directory(), "model2")), 1108 ) 1109 assert os.access(script3.path, os.F_OK) 1110 1111 1112class TemplateArgsTest(TestBase): 1113 def setUp(self): 1114 staging_env() 1115 self.cfg = _no_sql_testing_config( 1116 directives="\nrevision_environment=true\n" 1117 ) 1118 1119 def tearDown(self): 1120 clear_staging_env() 1121 1122 def test_args_propagate(self): 1123 config = _no_sql_testing_config() 1124 script = ScriptDirectory.from_config(config) 1125 template_args = {"x": "x1", "y": "y1", "z": "z1"} 1126 env = EnvironmentContext(config, script, template_args=template_args) 1127 env.configure( 1128 dialect_name="sqlite", template_args={"y": "y2", "q": "q1"} 1129 ) 1130 eq_(template_args, {"x": "x1", "y": "y2", "z": "z1", "q": "q1"}) 1131 1132 def test_tmpl_args_revision(self): 1133 env_file_fixture( 1134 """ 1135context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"}) 1136""" 1137 ) 1138 script_file_fixture( 1139 """ 1140# somearg: ${somearg} 1141revision = ${repr(up_revision)} 1142down_revision = ${repr(down_revision)} 
1143""" 1144 ) 1145 1146 command.revision(self.cfg, message="some rev") 1147 script = ScriptDirectory.from_config(self.cfg) 1148 1149 rev = script.get_revision("head") 1150 with open(rev.path) as f: 1151 text = f.read() 1152 assert "somearg: somevalue" in text 1153 1154 def test_bad_render(self): 1155 env_file_fixture( 1156 """ 1157context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"}) 1158""" 1159 ) 1160 script_file_fixture( 1161 """ 1162 <% z = x + y %> 1163""" 1164 ) 1165 1166 try: 1167 command.revision(self.cfg, message="some rev") 1168 except CommandError as ce: 1169 m = re.match( 1170 r"^Template rendering failed; see (.+?) " 1171 "for a template-oriented", 1172 str(ce), 1173 ) 1174 assert m, "Command error did not produce a file" 1175 with open(m.group(1)) as handle: 1176 contents = handle.read() 1177 os.remove(m.group(1)) 1178 assert "<% z = x + y %>" in contents 1179 1180 1181class DuplicateVersionLocationsTest(TestBase): 1182 def setUp(self): 1183 self.env = staging_env() 1184 self.cfg = _multi_dir_testing_config( 1185 # this is a duplicate of one of the paths 1186 # already present in this fixture 1187 extra_version_location="%(here)s/model1" 1188 ) 1189 1190 script = ScriptDirectory.from_config(self.cfg) 1191 self.model1 = util.rev_id() 1192 self.model2 = util.rev_id() 1193 self.model3 = util.rev_id() 1194 for model, name in [ 1195 (self.model1, "model1"), 1196 (self.model2, "model2"), 1197 (self.model3, "model3"), 1198 ]: 1199 script.generate_revision( 1200 model, 1201 name, 1202 refresh=True, 1203 version_path=os.path.join(_get_staging_directory(), name), 1204 head="base", 1205 ) 1206 write_script( 1207 script, 1208 model, 1209 """\ 1210"%s" 1211revision = '%s' 1212down_revision = None 1213branch_labels = ['%s'] 1214 1215from alembic import op 1216 1217 1218def upgrade(): 1219 pass 1220 1221 1222def downgrade(): 1223 pass 1224 1225""" 1226 % (name, model, name), 1227 ) 1228 1229 def tearDown(self): 1230 clear_staging_env() 1231 1232 def test_env_emits_warning(self): 1233 msg = ( 1234 "File %s loaded twice! ignoring. " 1235 "Please ensure version_locations is unique." 
            % (
                os.path.realpath(
                    os.path.join(
                        _get_staging_directory(),
                        "model1",
                        "%s_model1.py" % self.model1,
                    )
                )
            )
        )
        with assertions.expect_warnings(msg, regex=False):
            script = ScriptDirectory.from_config(self.cfg)
            script.revision_map.heads
        eq_(
            [rev.revision for rev in script.walk_revisions()],
            [self.model1, self.model2, self.model3],
        )


class NormPathTest(TestBase):
    def setUp(self):
        self.env = staging_env()

    def tearDown(self):
        clear_staging_env()

    def test_script_location(self):
        config = _no_sql_testing_config()

        script = ScriptDirectory.from_config(config)

        def normpath(path):
            return path.replace("/", ":NORM:")

        normpath = mock.Mock(side_effect=normpath)

        with mock.patch("os.path.normpath", normpath):
            eq_(
                script._version_locations,
                (
                    os.path.abspath(
                        os.path.join(
                            _get_staging_directory(), "scripts", "versions"
                        )
                    ).replace("/", ":NORM:"),
                ),
            )

            eq_(
                script.versions,
                os.path.abspath(
                    os.path.join(
                        _get_staging_directory(), "scripts", "versions"
                    )
                ).replace("/", ":NORM:"),
            )

    def test_script_location_muliple(self):
        config = _multi_dir_testing_config()

        script = ScriptDirectory.from_config(config)

        def normpath(path):
            return path.replace("/", ":NORM:")

        normpath = mock.Mock(side_effect=normpath)

        with mock.patch("os.path.normpath", normpath):
            eq_(
                script._version_locations,
                [
                    os.path.abspath(
                        os.path.join(_get_staging_directory(), "model1/")
                    ).replace("/", ":NORM:"),
                    os.path.abspath(
                        os.path.join(_get_staging_directory(), "model2/")
                    ).replace("/", ":NORM:"),
                    os.path.abspath(
                        os.path.join(_get_staging_directory(), "model3/")
                    ).replace("/", ":NORM:"),
                ],
            )