1import datetime
2import os
3import re
4
5from dateutil import tz
6import sqlalchemy as sa
7from sqlalchemy import inspect
8
9from alembic import autogenerate
10from alembic import command
11from alembic import util
12from alembic.environment import EnvironmentContext
13from alembic.operations import ops
14from alembic.script import ScriptDirectory
15from alembic.testing import assert_raises_message
16from alembic.testing import assertions
17from alembic.testing import eq_
18from alembic.testing import is_
19from alembic.testing import mock
20from alembic.testing import ne_
21from alembic.testing.env import _get_staging_directory
22from alembic.testing.env import _multi_dir_testing_config
23from alembic.testing.env import _multidb_testing_config
24from alembic.testing.env import _no_sql_testing_config
25from alembic.testing.env import _sqlite_file_db
26from alembic.testing.env import _sqlite_testing_config
27from alembic.testing.env import _testing_config
28from alembic.testing.env import clear_staging_env
29from alembic.testing.env import env_file_fixture
30from alembic.testing.env import script_file_fixture
31from alembic.testing.env import staging_env
32from alembic.testing.env import three_rev_fixture
33from alembic.testing.env import write_script
34from alembic.testing.fixtures import TestBase
35from alembic.util import CommandError
36
# Module-level state shared between the ordered steps of
# GeneralOrderedTests: ``env`` holds the object returned by staging_env(),
# while ``abc`` and ``def_`` hold the two revision ids generated by the
# rev-id step and reused by the later steps.
env = abc = def_ = None
38
39
class GeneralOrderedTests(TestBase):
    """Walk a freshly staged script directory through revision creation.

    The private ``_test_NNN`` methods are not independent tests: they hand
    state to one another through the module-level ``env``, ``abc`` and
    ``def_`` globals, so :meth:`test_steps` runs them in a fixed order.
    """

    def setUp(self):
        global env
        env = staging_env()

    def tearDown(self):
        clear_staging_env()

    def test_steps(self):
        # Order matters: later steps rely on revisions made by earlier ones.
        ordered_steps = (
            self._test_001_environment,
            self._test_002_rev_ids,
            self._test_003_api_methods_clean,
            self._test_004_rev,
            self._test_005_nextrev,
            self._test_006_from_clean_env,
            self._test_007_long_name,
            self._test_008_long_name_configurable,
        )
        for step in ordered_steps:
            step()

    def _test_001_environment(self):
        # The staged directory must contain at least these files.
        required = {"env.py", "script.py.mako", "README"}
        found = required.intersection(os.listdir(env.dir))
        eq_(found, required)

    def _test_002_rev_ids(self):
        # Generate the two revision ids used by the remaining steps.
        global abc, def_
        abc, def_ = util.rev_id(), util.rev_id()
        ne_(abc, def_)

    def _test_003_api_methods_clean(self):
        # Nothing has been generated yet: no heads and no base.
        eq_(env.get_heads(), [])

        eq_(env.get_base(), None)

    def _test_004_rev(self):
        # The first revision becomes both the only head and the base.
        message = "this is a message"
        script = env.generate_revision(abc, message, refresh=True)
        eq_(script.doc, message)
        eq_(script.revision, abc)
        eq_(script.down_revision, None)
        expected_file = os.path.join(
            env.dir, "versions", "%s_this_is_a_message.py" % abc
        )
        assert os.access(expected_file, os.F_OK)
        assert callable(script.module.upgrade)
        eq_(env.get_heads(), [abc])
        eq_(env.get_base(), abc)

    def _test_005_nextrev(self):
        # The second revision stacks on top of the first one.
        script = env.generate_revision(
            def_, "this is the next rev", refresh=True
        )
        expected_file = os.path.join(
            env.dir, "versions", "%s_this_is_the_next_rev.py" % def_
        )
        assert os.access(expected_file, os.F_OK)
        eq_(script.revision, def_)
        eq_(script.down_revision, abc)
        eq_(env.get_revision(abc).nextrev, {def_})
        module = script.module
        assert module.down_revision == abc
        assert callable(module.upgrade)
        assert callable(module.downgrade)
        eq_(env.get_heads(), [def_])
        eq_(env.get_base(), abc)

    def _test_006_from_clean_env(self):
        # Re-check the history built so far through a brand new
        # ScriptDirectory instance rather than the shared global one.
        fresh = staging_env(create=False)
        first = fresh.get_revision(abc)
        second = fresh.get_revision(def_)
        eq_(first.nextrev, {def_})
        eq_(first.revision, abc)
        eq_(second.down_revision, abc)
        eq_(fresh.get_heads(), [def_])
        eq_(fresh.get_base(), abc)

    def _test_007_long_name(self):
        # Before truncate_slug_length is raised (step 008), the slug taken
        # from a long message is cut short in the file name.
        rev_id = util.rev_id()
        env.generate_revision(
            rev_id,
            "this is a really long name with "
            "lots of characters and also "
            "I'd like it to\nhave\nnewlines",
        )
        expected_file = os.path.join(
            env.dir,
            "versions",
            "%s_this_is_a_really_long_name_with_lots_of_.py" % rev_id,
        )
        assert os.access(expected_file, os.F_OK)

    def _test_008_long_name_configurable(self):
        # Raising truncate_slug_length keeps more of the message in the slug.
        env.truncate_slug_length = 60
        rev_id = util.rev_id()
        env.generate_revision(
            rev_id,
            "this is a really long name with "
            "lots of characters and also "
            "I'd like it to\nhave\nnewlines",
        )
        expected_file = os.path.join(
            env.dir,
            "versions",
            "%s_this_is_a_really_long_name_with_lots_"
            "of_characters_and_also_.py" % rev_id,
        )
        assert os.access(expected_file, os.F_OK)
153
154
class ScriptNamingTest(TestBase):
    """Check revision file naming and create-date timezone handling."""

    @classmethod
    def setup_class(cls):
        _testing_config()

    @classmethod
    def teardown_class(cls):
        clear_staging_env()

    def test_args(self):
        # Every date/time token of the file template is rendered into the
        # generated revision path.
        script = ScriptDirectory(
            _get_staging_directory(),
            file_template=(
                "%(rev)s_%(slug)s_%(year)s_%(month)s_"
                "%(day)s_%(hour)s_%(minute)s_%(second)s"
            ),
        )
        created = datetime.datetime(2012, 7, 25, 15, 8, 5)
        actual = script._rev_path(
            script.versions, "12345", "this is a message", created
        )
        expected = os.path.abspath(
            "%s/versions/12345_this_is_a_message_2012_7_25_15_8_5.py"
            % _get_staging_directory()
        )
        eq_(actual, expected)

    def _test_tz(self, timezone_arg, given, expected):
        """Assert that, with the clock frozen at *given*, a script directory
        configured with *timezone_arg* produces the create date *expected*.
        """
        script = ScriptDirectory(
            _get_staging_directory(),
            file_template=(
                "%(rev)s_%(slug)s_%(year)s_%(month)s_"
                "%(day)s_%(hour)s_%(minute)s_%(second)s"
            ),
            timezone=timezone_arg,
        )

        # Freeze both utcnow() and now() on the datetime class seen by
        # alembic.script.base.
        frozen_clock = mock.Mock(utcnow=lambda: given, now=lambda: given)
        with mock.patch(
            "alembic.script.base.datetime", mock.Mock(datetime=frozen_clock)
        ):
            create_date = script._generate_create_date()
        eq_(create_date, expected)

    def test_custom_tz(self):
        # The frozen 15:08 reading is treated as UTC and shifted to EST5EDT.
        given = datetime.datetime(2012, 7, 25, 15, 8, 5)
        expected = datetime.datetime(
            2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
        )
        self._test_tz("EST5EDT", given, expected)

    def test_custom_tz_lowercase(self):
        # The timezone name is matched regardless of case.
        given = datetime.datetime(2012, 7, 25, 15, 8, 5)
        expected = datetime.datetime(
            2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
        )
        self._test_tz("est5edt", given, expected)

    def test_custom_tz_utc(self):
        given = datetime.datetime(2012, 7, 25, 15, 8, 5)
        expected = datetime.datetime(
            2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz("UTC")
        )
        self._test_tz("utc", given, expected)

    def test_custom_tzdata_tz(self):
        given = datetime.datetime(2012, 7, 25, 15, 8, 5)
        expected = datetime.datetime(
            2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz("Europe/Berlin")
        )
        self._test_tz("Europe/Berlin", given, expected)

    def test_default_tz(self):
        # With no timezone configured, the frozen value is returned as-is.
        given = datetime.datetime(2012, 7, 25, 15, 8, 5)
        expected = datetime.datetime(2012, 7, 25, 15, 8, 5)
        self._test_tz(None, given, expected)

    def test_tz_cant_locate(self):
        # An unknown timezone name is reported as a CommandError.
        given = datetime.datetime(2012, 7, 25, 15, 8, 5)
        expected = datetime.datetime(2012, 7, 25, 15, 8, 5)
        assert_raises_message(
            CommandError,
            "Can't locate timezone: fake",
            self._test_tz,
            "fake",
            given,
            expected,
        )
252
253
class RevisionCommandTest(TestBase):
    """Tests for ``command.revision`` on top of a three-revision history."""

    def setUp(self):
        self.env = staging_env()
        self.cfg = _sqlite_testing_config()
        # Ids of the three fixture revisions; the tests below rely on
        # self.c being the only head and self.b being a non-head revision.
        self.a, self.b, self.c = three_rev_fixture(self.cfg)

    def tearDown(self):
        clear_staging_env()

    def test_create_script_basic(self):
        # A plain new revision is stacked on the current head.
        rev = command.revision(self.cfg, message="some message")
        script = ScriptDirectory.from_config(self.cfg)
        rev = script.get_revision(rev.revision)
        eq_(rev.down_revision, self.c)
        assert "some message" in rev.doc

    def test_create_script_splice(self):
        # splice=True allows branching off a non-head revision, which
        # leaves two heads behind.
        rev = command.revision(
            self.cfg, message="some message", head=self.b, splice=True
        )
        script = ScriptDirectory.from_config(self.cfg)
        rev = script.get_revision(rev.revision)
        eq_(rev.down_revision, self.b)
        assert "some message" in rev.doc
        eq_(set(script.get_heads()), {rev.revision, self.c})

    def test_create_script_missing_splice(self):
        # Without splice=True, targeting a non-head revision is rejected.
        assert_raises_message(
            util.CommandError,
            "Revision %s is not a head revision; please specify --splice "
            "to create a new branch from this revision" % self.b,
            command.revision,
            self.cfg,
            message="some message",
            head=self.b,
        )

    def test_illegal_revision_chars(self):
        # '-', '@' and '+' are rejected in revision identifiers; for the
        # first case also check that no revision file was written.
        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '-' not allowed in "
            "revision identifier 'no-dashes'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no-dashes",
        )

        assert not os.path.exists(
            os.path.join(self.env.dir, "versions", "no-dashes_some_message.py")
        )

        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '@' not allowed in "
            "revision identifier 'no@atsigns'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no@atsigns",
        )

        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '-, @' not allowed in revision "
            "identifier 'no@atsigns-ordashes'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no@atsigns-ordashes",
        )

        assert_raises_message(
            util.CommandError,
            r"Character\(s\) '\+' not allowed in revision "
            r"identifier 'no\+plussignseither'",
            command.revision,
            self.cfg,
            message="some message",
            rev_id="no+plussignseither",
        )

    def test_create_script_branches(self):
        # A branch label can be used to look the new revision up.
        rev = command.revision(
            self.cfg, message="some message", branch_label="foobar"
        )
        script = ScriptDirectory.from_config(self.cfg)
        rev = script.get_revision(rev.revision)
        eq_(script.get_revision("foobar"), rev)

    def test_create_script_branches_old_template(self):
        # Replace script.py.mako with a template lacking a branch_labels
        # section.  Both bodies reference the template variables actually
        # provided (``upgrades`` / ``downgrades``); the downgrade line
        # previously referenced an undefined ``downgrade`` name, which only
        # went unnoticed because ``downgrades`` is empty here.
        script = ScriptDirectory.from_config(self.cfg)
        with open(os.path.join(script.dir, "script.py.mako"), "w") as file_:
            file_.write(
                "<%text>#</%text> ${message}\n"
                "revision = ${repr(up_revision)}\n"
                "down_revision = ${repr(down_revision)}\n\n"
                "def upgrade():\n"
                "    ${upgrades if upgrades else 'pass'}\n\n"
                "def downgrade():\n"
                "    ${downgrades if downgrades else 'pass'}\n\n"
            )

        # works OK if no branch names
        command.revision(self.cfg, message="some message")

        # ...but a branch label cannot be rendered into the old template
        assert_raises_message(
            util.CommandError,
            r"Version \w+ specified branch_labels foobar, "
            r"however the migration file .+?\b does not have them; have you "
            "upgraded your script.py.mako to include the 'branch_labels' "
            r"section\?",
            command.revision,
            self.cfg,
            message="some message",
            branch_label="foobar",
        )
371
372
class CustomizeRevisionTest(TestBase):
    """Exercise ``process_revision_directives`` against three separate
    version locations named ``model1``, ``model2`` and ``model3``."""

    def setUp(self):
        self.env = staging_env()
        self.cfg = _multi_dir_testing_config()
        self.cfg.set_main_option("revision_environment", "true")

        script = ScriptDirectory.from_config(self.cfg)
        self.model1 = util.rev_id()
        self.model2 = util.rev_id()
        self.model3 = util.rev_id()

        # Body written over each base revision: a docstring, the revision
        # id and a branch label equal to the location's name.
        script_body = """\
"%s"
revision = '%s'
down_revision = None
branch_labels = ['%s']

from alembic import op


def upgrade():
    pass


def downgrade():
    pass

"""
        locations = (
            ("model1", self.model1),
            ("model2", self.model2),
            ("model3", self.model3),
        )
        for label, rev in locations:
            script.generate_revision(
                rev,
                label,
                refresh=True,
                version_path=os.path.join(_get_staging_directory(), label),
                head="base",
            )
            write_script(script, rev, script_body % (label, rev, label))

    def tearDown(self):
        clear_staging_env()

    def _env_fixture(self, fn, target_metadata):
        """Return a patcher replacing ``ScriptDirectory.run_env``.

        The replacement connects with the engine from ``_sqlite_file_db()``
        (also stored on ``self.engine``) and runs migrations with *fn* as
        the ``process_revision_directives`` hook.
        """
        engine = _sqlite_file_db()
        self.engine = engine

        def run_env(_script_directory):
            from alembic import context

            with engine.connect() as connection:
                context.configure(
                    connection=connection,
                    target_metadata=target_metadata,
                    process_revision_directives=fn,
                )
                with context.begin_transaction():
                    context.run_migrations()

        return mock.patch(
            "alembic.script.base.ScriptDirectory.run_env", run_env
        )

    def test_new_locations_no_autogen(self):
        m = sa.MetaData()
        model_names = ("model1", "model2", "model3")

        def process_revision_directives(context, rev, generate_revisions):
            # Replace the default directive with one empty revision per
            # version location.
            generate_revisions[:] = [
                ops.MigrationScript(
                    util.rev_id(),
                    ops.UpgradeOps(),
                    ops.DowngradeOps(),
                    version_path=os.path.join(_get_staging_directory(), name),
                    head="%s@head" % name,
                )
                for name in model_names
            ]

        with self._env_fixture(process_revision_directives, m):
            revs = command.revision(self.cfg, message="some message")

        script = ScriptDirectory.from_config(self.cfg)

        for index, model in enumerate(model_names):
            rev_script = script.get_revision(revs[index].revision)
            expected_path = os.path.abspath(
                os.path.join(
                    _get_staging_directory(),
                    model,
                    "%s_.py" % (rev_script.revision,),
                )
            )
            eq_(rev_script.path, expected_path)
            assert os.path.exists(rev_script.path)

    def test_renders_added_directives_no_autogen(self):
        m = sa.MetaData()

        def process_revision_directives(context, rev, generate_revisions):
            # Add an index creation to the otherwise empty upgrade.
            new_index = ops.CreateIndexOp(
                "some_index", "some_table", ["a", "b"]
            )
            generate_revisions[0].upgrade_ops.ops.append(new_index)

        with self._env_fixture(process_revision_directives, m):
            rev = command.revision(
                self.cfg, message="some message", head="model1@head", sql=True
            )

        with mock.patch.object(rev.module, "op") as op_mock:
            rev.module.upgrade()
        expected_calls = [
            mock.call.create_index(
                "some_index", "some_table", ["a", "b"], unique=False
            )
        ]
        eq_(op_mock.mock_calls, expected_calls)

    def test_autogen(self):
        m = sa.MetaData()
        sa.Table("t", m, sa.Column("x", sa.Integer))

        def process_revision_directives(context, rev, generate_revisions):
            autogenerated = generate_revisions[0]
            upgrades = autogenerated.upgrade_ops
            downgrades = autogenerated.downgrade_ops

            # model1 will run the upgrades, e.g. create the table,
            # model2 will run the downgrades as upgrades, e.g. drop
            # the table again
            generate_revisions[:] = [
                ops.MigrationScript(
                    util.rev_id(),
                    upgrades,
                    ops.DowngradeOps(),
                    version_path=os.path.join(
                        _get_staging_directory(), "model1"
                    ),
                    head="model1@head",
                ),
                ops.MigrationScript(
                    util.rev_id(),
                    ops.UpgradeOps(ops=downgrades.ops),
                    ops.DowngradeOps(),
                    version_path=os.path.join(
                        _get_staging_directory(), "model2"
                    ),
                    head="model2@head",
                ),
            ]

        def table_names():
            # self.engine is assigned when the fixture below is created.
            return inspect(self.engine).get_table_names()

        with self._env_fixture(process_revision_directives, m):
            command.upgrade(self.cfg, "heads")
            eq_(table_names(), ["alembic_version"])

            command.revision(
                self.cfg, message="some message", autogenerate=True
            )

            command.upgrade(self.cfg, "model1@head")
            eq_(table_names(), ["alembic_version", "t"])

            command.upgrade(self.cfg, "model2@head")
            eq_(table_names(), ["alembic_version"])

    def test_programmatic_command_option(self):
        def process_revision_directives(context, rev, generate_revisions):
            # Overwrite message and operations of the generated revision.
            target = generate_revisions[0]
            target.message = "test programatic"
            target.upgrade_ops = ops.UpgradeOps(
                ops=[
                    ops.CreateTableOp(
                        "test_table",
                        [
                            sa.Column("id", sa.Integer(), primary_key=True),
                            sa.Column("name", sa.String(50), nullable=False),
                        ],
                    )
                ]
            )
            target.downgrade_ops = ops.DowngradeOps(
                ops=[ops.DropTableOp("test_table")]
            )

        with self._env_fixture(None, None):
            rev = command.revision(
                self.cfg,
                head="model1@head",
                process_revision_directives=process_revision_directives,
            )

        expected_upgrade = """
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('test_table',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=50), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###
"""
        with open(rev.path) as handle:
            rendered = handle.read()
        assert expected_upgrade in rendered
617
618
class ScriptAccessorTest(TestBase):
    def test_upgrade_downgrade_ops_list_accessors(self):
        """``upgrade_ops`` / ``downgrade_ops`` only work while exactly one
        entry is present; the ``*_ops_list`` accessors expose all entries."""
        first_up = ops.UpgradeOps(ops=[])
        first_down = ops.DowngradeOps(ops=[])
        script = ops.MigrationScript("somerev", first_up, first_down)
        is_(script.upgrade_ops, first_up)
        is_(script.downgrade_ops, first_down)

        second_up = ops.UpgradeOps(ops=[])
        second_down = ops.DowngradeOps(ops=[])
        # Append straight to the private lists to create multiple entries.
        script._upgrade_ops.append(second_up)
        script._downgrade_ops.append(second_down)

        expected_errors = (
            (
                "upgrade_ops",
                "This MigrationScript instance has a multiple-entry list for "
                "UpgradeOps; please use the upgrade_ops_list attribute.",
            ),
            (
                "downgrade_ops",
                "This MigrationScript instance has a multiple-entry list for "
                "DowngradeOps; please use the downgrade_ops_list attribute.",
            ),
        )
        for attr_name, message in expected_errors:
            assert_raises_message(
                ValueError, message, getattr, script, attr_name
            )

        eq_(script.upgrade_ops_list, [first_up, second_up])
        eq_(script.downgrade_ops_list, [first_down, second_down])
649
650
class ImportsTest(TestBase):
    """Check the import lines written into an autogenerated revision."""

    def setUp(self):
        self.env = staging_env()
        self.cfg = _sqlite_testing_config()

    def tearDown(self):
        clear_staging_env()

    def _env_fixture(self, target_metadata, **kw):
        """Return a patcher replacing ``ScriptDirectory.run_env``.

        The replacement connects with the engine from ``_sqlite_file_db()``
        (also stored on ``self.engine``) and forwards *target_metadata* plus
        any extra *kw* to ``context.configure``.
        """
        engine = _sqlite_file_db()
        self.engine = engine

        def run_env(_script_directory):
            from alembic import context

            with engine.connect() as connection:
                context.configure(
                    connection=connection,
                    target_metadata=target_metadata,
                    **kw
                )
                with context.begin_transaction():
                    context.run_migrations()

        return mock.patch(
            "alembic.script.base.ScriptDirectory.run_env", run_env
        )

    def test_imports_in_script(self):
        from sqlalchemy import MetaData, Table, Column
        from sqlalchemy.dialects.mysql import VARCHAR

        mysql_varchar = VARCHAR(20, charset="utf8", national=True)
        metadata = MetaData()
        Table("t", metadata, Column("x", mysql_varchar))

        def process_revision_directives(context, rev, generate_revisions):
            # Request an extra import that the table above does not use.
            generate_revisions[0].imports.add(
                "from sqlalchemy.dialects.mysql import TINYINT"
            )

        with self._env_fixture(
            metadata, process_revision_directives=process_revision_directives
        ):
            rev = command.revision(
                self.cfg, message="some message", autogenerate=True
            )

        with open(rev.path) as file_:
            contents = file_.read()
        assert "from sqlalchemy.dialects import mysql" in contents
        assert "from sqlalchemy.dialects.mysql import TINYINT" in contents
704
705
class MultiContextTest(TestBase):
    """test the multidb template for autogenerate front-to-back"""

    def setUp(self):
        self.engine1 = _sqlite_file_db(tempname="eng1.db")
        self.engine2 = _sqlite_file_db(tempname="eng2.db")
        self.engine3 = _sqlite_file_db(tempname="eng3.db")

        self.env = staging_env(template="multidb")
        engines_by_name = {
            "engine1": self.engine1,
            "engine2": self.engine2,
            "engine3": self.engine3,
        }
        self.cfg = _multidb_testing_config(engines_by_name)

    def _write_metadata(self, meta):
        """Substitute *meta* for the ``target_metadata = {}`` placeholder in
        the staged ``scripts/env.py``."""
        env_path = os.path.join(_get_staging_directory(), "scripts", "env.py")
        with open(env_path) as env_file:
            original_source = env_file.read()
        patched_source = original_source.replace("target_metadata = {}", meta)
        with open(env_path, "w") as env_file:
            env_file.write(patched_source)

    def tearDown(self):
        clear_staging_env()

    def test_autogen(self):
        self._write_metadata(
            """
import sqlalchemy as sa

m1 = sa.MetaData()
m2 = sa.MetaData()
m3 = sa.MetaData()
target_metadata = {"engine1": m1, "engine2": m2, "engine3": m3}

sa.Table('e1t1', m1, sa.Column('x', sa.Integer))
sa.Table('e2t1', m2, sa.Column('y', sa.Integer))
sa.Table('e3t1', m3, sa.Column('z', sa.Integer))

"""
        )

        rev = command.revision(
            self.cfg, message="some message", autogenerate=True
        )

        # Each engine gets its own upgrade_/downgrade_ function touching
        # only that engine's table.
        engine_tables = (
            ("engine1", "e1t1"),
            ("engine2", "e2t1"),
            ("engine3", "e3t1"),
        )
        with mock.patch.object(rev.module, "op") as op_mock:
            for engine_name, table in engine_tables:
                getattr(rev.module, "upgrade_%s" % engine_name)()
                eq_(
                    op_mock.mock_calls[-1],
                    mock.call.create_table(table, mock.ANY),
                )
            for engine_name, table in engine_tables:
                getattr(rev.module, "downgrade_%s" % engine_name)()
                eq_(op_mock.mock_calls[-1], mock.call.drop_table(table))
776
777
778class RewriterTest(TestBase):
779    def test_all_traverse(self):
780        writer = autogenerate.Rewriter()
781
782        mocker = mock.Mock(side_effect=lambda context, revision, op: op)
783        writer.rewrites(ops.MigrateOperation)(mocker)
784
785        addcolop = ops.AddColumnOp("t1", sa.Column("x", sa.Integer()))
786
787        directives = [
788            ops.MigrationScript(
789                util.rev_id(),
790                ops.UpgradeOps(ops=[ops.ModifyTableOps("t1", ops=[addcolop])]),
791                ops.DowngradeOps(ops=[]),
792            )
793        ]
794
795        ctx, rev = mock.Mock(), mock.Mock()
796        writer(ctx, rev, directives)
797        eq_(
798            mocker.mock_calls,
799            [
800                mock.call(ctx, rev, directives[0]),
801                mock.call(ctx, rev, directives[0].upgrade_ops),
802                mock.call(ctx, rev, directives[0].upgrade_ops.ops[0]),
803                mock.call(ctx, rev, addcolop),
804                mock.call(ctx, rev, directives[0].downgrade_ops),
805            ],
806        )
807
808    def test_double_migrate_table(self):
809        writer = autogenerate.Rewriter()
810
811        idx_ops = []
812
813        @writer.rewrites(ops.ModifyTableOps)
814        def second_table(context, revision, op):
815            return [
816                op,
817                ops.ModifyTableOps(
818                    "t2",
819                    ops=[ops.AddColumnOp("t2", sa.Column("x", sa.Integer()))],
820                ),
821            ]
822
823        @writer.rewrites(ops.AddColumnOp)
824        def add_column(context, revision, op):
825            idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
826            idx_ops.append(idx_op)
827            return [op, idx_op]
828
829        directives = [
830            ops.MigrationScript(
831                util.rev_id(),
832                ops.UpgradeOps(
833                    ops=[
834                        ops.ModifyTableOps(
835                            "t1",
836                            ops=[
837                                ops.AddColumnOp(
838                                    "t1", sa.Column("x", sa.Integer())
839                                )
840                            ],
841                        )
842                    ]
843                ),
844                ops.DowngradeOps(ops=[]),
845            )
846        ]
847
848        ctx, rev = mock.Mock(), mock.Mock()
849        writer(ctx, rev, directives)
850        eq_(
851            [d.table_name for d in directives[0].upgrade_ops.ops], ["t1", "t2"]
852        )
853        is_(directives[0].upgrade_ops.ops[0].ops[1], idx_ops[0])
854        is_(directives[0].upgrade_ops.ops[1].ops[1], idx_ops[1])
855
856    def test_chained_ops(self):
857        writer1 = autogenerate.Rewriter()
858        writer2 = autogenerate.Rewriter()
859
860        @writer1.rewrites(ops.AddColumnOp)
861        def add_column_nullable(context, revision, op):
862            if op.column.nullable:
863                return op
864            else:
865                op.column.nullable = True
866                return [
867                    op,
868                    ops.AlterColumnOp(
869                        op.table_name,
870                        op.column.name,
871                        modify_nullable=False,
872                        existing_type=op.column.type,
873                    ),
874                ]
875
876        @writer2.rewrites(ops.AddColumnOp)
877        def add_column_idx(context, revision, op):
878            idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
879            return [op, idx_op]
880
881        directives = [
882            ops.MigrationScript(
883                util.rev_id(),
884                ops.UpgradeOps(
885                    ops=[
886                        ops.ModifyTableOps(
887                            "t1",
888                            ops=[
889                                ops.AddColumnOp(
890                                    "t1",
891                                    sa.Column(
892                                        "x", sa.Integer(), nullable=False
893                                    ),
894                                )
895                            ],
896                        )
897                    ]
898                ),
899                ops.DowngradeOps(ops=[]),
900            )
901        ]
902
903        ctx, rev = mock.Mock(), mock.Mock()
904        writer1.chain(writer2)(ctx, rev, directives)
905
906        eq_(
907            autogenerate.render_python_code(directives[0].upgrade_ops),
908            "# ### commands auto generated by Alembic - please adjust! ###\n"
909            "    op.add_column('t1', "
910            "sa.Column('x', sa.Integer(), nullable=True))\n"
911            "    op.create_index('ixt', 't1', ['x'], unique=False)\n"
912            "    op.alter_column('t1', 'x',\n"
913            "               existing_type=sa.Integer(),\n"
914            "               nullable=False)\n"
915            "    # ### end Alembic commands ###",
916        )
917
918    def test_no_needless_pass(self):
919        writer1 = autogenerate.Rewriter()
920
921        @writer1.rewrites(ops.AlterColumnOp)
922        def rewrite_alter_column(context, revision, op):
923            return []
924
925        directives = [
926            ops.MigrationScript(
927                util.rev_id(),
928                ops.UpgradeOps(
929                    ops=[
930                        ops.ModifyTableOps(
931                            "t1",
932                            ops=[
933                                ops.AlterColumnOp(
934                                    "foo",
935                                    "bar",
936                                    modify_nullable=False,
937                                    existing_type=sa.Integer(),
938                                ),
939                                ops.AlterColumnOp(
940                                    "foo",
941                                    "bar",
942                                    modify_nullable=False,
943                                    existing_type=sa.Integer(),
944                                ),
945                            ],
946                        ),
947                        ops.ModifyTableOps(
948                            "t1",
949                            ops=[
950                                ops.AlterColumnOp(
951                                    "foo",
952                                    "bar",
953                                    modify_nullable=False,
954                                    existing_type=sa.Integer(),
955                                )
956                            ],
957                        ),
958                    ]
959                ),
960                ops.DowngradeOps(ops=[]),
961            )
962        ]
963        ctx, rev = mock.Mock(), mock.Mock()
964        writer1(ctx, rev, directives)
965
966        eq_(
967            autogenerate.render_python_code(directives[0].upgrade_ops),
968            "# ### commands auto generated by Alembic - please adjust! ###\n"
969            "    pass\n"
970            "    # ### end Alembic commands ###",
971        )
972
973    def test_multiple_passes_with_mutations(self):
974        writer1 = autogenerate.Rewriter()
975
976        @writer1.rewrites(ops.CreateTableOp)
977        def rewrite_alter_column(context, revision, op):
978            op.table_name += "_pass"
979            return op
980
981        directives = [
982            ops.MigrationScript(
983                util.rev_id(),
984                ops.UpgradeOps(
985                    ops=[
986                        ops.CreateTableOp(
987                            "test_table",
988                            [sa.Column("id", sa.Integer(), primary_key=True)],
989                        )
990                    ]
991                ),
992                ops.DowngradeOps(ops=[]),
993            )
994        ]
995        ctx, rev = mock.Mock(), mock.Mock()
996        writer1(ctx, rev, directives)
997
998        directives[0].upgrade_ops_list.extend(
999            [
1000                ops.UpgradeOps(
1001                    ops=[
1002                        ops.CreateTableOp(
1003                            "another_test_table",
1004                            [sa.Column("id", sa.Integer(), primary_key=True)],
1005                        )
1006                    ]
1007                ),
1008                ops.UpgradeOps(
1009                    ops=[
1010                        ops.CreateTableOp(
1011                            "third_test_table",
1012                            [sa.Column("id", sa.Integer(), primary_key=True)],
1013                        )
1014                    ]
1015                ),
1016            ]
1017        )
1018
1019        writer1(ctx, rev, directives)
1020
1021        eq_(
1022            autogenerate.render_python_code(directives[0].upgrade_ops_list[0]),
1023            "# ### commands auto generated by Alembic - please adjust! ###\n"
1024            "    op.create_table('test_table_pass',\n"
1025            "    sa.Column('id', sa.Integer(), nullable=False),\n"
1026            "    sa.PrimaryKeyConstraint('id')\n"
1027            "    )\n"
1028            "    # ### end Alembic commands ###",
1029        )
1030        eq_(
1031            autogenerate.render_python_code(directives[0].upgrade_ops_list[1]),
1032            "# ### commands auto generated by Alembic - please adjust! ###\n"
1033            "    op.create_table('another_test_table_pass',\n"
1034            "    sa.Column('id', sa.Integer(), nullable=False),\n"
1035            "    sa.PrimaryKeyConstraint('id')\n"
1036            "    )\n"
1037            "    # ### end Alembic commands ###",
1038        )
1039        eq_(
1040            autogenerate.render_python_code(directives[0].upgrade_ops_list[2]),
1041            "# ### commands auto generated by Alembic - please adjust! ###\n"
1042            "    op.create_table('third_test_table_pass',\n"
1043            "    sa.Column('id', sa.Integer(), nullable=False),\n"
1044            "    sa.PrimaryKeyConstraint('id')\n"
1045            "    )\n"
1046            "    # ### end Alembic commands ###",
1047        )
1048
1049
class MultiDirRevisionCommandTest(TestBase):
    """``command.revision`` against a config with several version locations."""

    def setUp(self):
        self.env = staging_env()
        self.cfg = _multi_dir_testing_config()

    def tearDown(self):
        clear_staging_env()

    @staticmethod
    def _model_dir(name):
        """Return the path of version directory *name* under staging."""
        return os.path.join(_get_staging_directory(), name)

    def test_multiple_dir_no_bases(self):
        # Several candidate directories and no --version-path: the command
        # must refuse rather than pick one.
        assert_raises_message(
            util.CommandError,
            "Multiple version locations present, please specify "
            "--version-path",
            command.revision,
            self.cfg,
            message="some message",
        )

    def test_multiple_dir_no_bases_invalid_version_path(self):
        # A version_path that is not one of the configured locations is
        # rejected.
        assert_raises_message(
            util.CommandError,
            "Path foo/bar/ is not represented in current version locations",
            command.revision,
            self.cfg,
            message="x",
            version_path="foo/bar/",
        )

    def test_multiple_dir_no_bases_version_path(self):
        script = command.revision(
            self.cfg,
            message="x",
            version_path=self._model_dir("model1"),
        )
        assert os.access(script.path, os.F_OK)

    def test_multiple_dir_chooses_base(self):
        # Two independent bases: one in model1, one in model2.
        command.revision(
            self.cfg,
            message="x",
            head="base",
            version_path=self._model_dir("model1"),
        )

        script2 = command.revision(
            self.cfg,
            message="y",
            head="base",
            version_path=self._model_dir("model2"),
        )

        # No version_path here: the new revision must be written into the
        # directory of its head (model2).
        script3 = command.revision(
            self.cfg, message="y2", head=script2.revision
        )

        eq_(
            os.path.dirname(script3.path),
            os.path.abspath(self._model_dir("model2")),
        )
        assert os.access(script3.path, os.F_OK)
1110
1111
class TemplateArgsTest(TestBase):
    """``template_args`` reach the revision script template."""

    def setUp(self):
        staging_env()
        # revision_environment=true makes ``command.revision`` run env.py,
        # so template_args passed to context.configure() there are
        # available when script.py.mako is rendered.
        self.cfg = _no_sql_testing_config(
            directives="\nrevision_environment=true\n"
        )

    def tearDown(self):
        clear_staging_env()

    def test_args_propagate(self):
        config = _no_sql_testing_config()
        script = ScriptDirectory.from_config(config)
        template_args = {"x": "x1", "y": "y1", "z": "z1"}
        # Named env_context so it does not shadow the module-level ``env``.
        env_context = EnvironmentContext(
            config, script, template_args=template_args
        )
        env_context.configure(
            dialect_name="sqlite", template_args={"y": "y2", "q": "q1"}
        )
        # configure() merges into the very dict given to the constructor:
        # "y" is overridden and "q" is added.
        eq_(template_args, {"x": "x1", "y": "y2", "z": "z1", "q": "q1"})

    def test_tmpl_args_revision(self):
        env_file_fixture(
            """
context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
"""
        )
        script_file_fixture(
            """
# somearg: ${somearg}
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
"""
        )

        command.revision(self.cfg, message="some rev")
        script = ScriptDirectory.from_config(self.cfg)

        rev = script.get_revision("head")
        with open(rev.path) as f:
            text = f.read()
        assert "somearg: somevalue" in text

    def test_bad_render(self):
        env_file_fixture(
            """
context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
"""
        )
        script_file_fixture(
            """
    <% z = x + y %>
"""
        )

        try:
            command.revision(self.cfg, message="some rev")
        except CommandError as ce:
            m = re.match(
                r"^Template rendering failed; see (.+?) "
                "for a template-oriented",
                str(ce),
            )
            assert m, "Command error did not produce a file"
            # The message names a file containing the template source;
            # read it, then delete it so it is not left behind.
            with open(m.group(1)) as handle:
                contents = handle.read()
            os.remove(m.group(1))
            assert "<% z = x + y %>" in contents
        else:
            # Without this, a template that unexpectedly renders would let
            # the test pass without checking anything.
            raise AssertionError(
                "command.revision() did not raise CommandError "
                "for an unrenderable template"
            )
1179
1180
class DuplicateVersionLocationsTest(TestBase):
    """A version directory configured twice is loaded once, with a warning."""

    def setUp(self):
        self.env = staging_env()
        # "%(here)s/model1" duplicates a location the multi-dir fixture
        # already configures.
        self.cfg = _multi_dir_testing_config(
            extra_version_location="%(here)s/model1"
        )

        script = ScriptDirectory.from_config(self.cfg)
        self.model1 = util.rev_id()
        self.model2 = util.rev_id()
        self.model3 = util.rev_id()

        template = """\
"%s"
revision = '%s'
down_revision = None
branch_labels = ['%s']

from alembic import op


def upgrade():
    pass


def downgrade():
    pass

"""
        staging_dir = _get_staging_directory()
        roots = (
            (self.model1, "model1"),
            (self.model2, "model2"),
            (self.model3, "model3"),
        )
        for rev_id, label in roots:
            # One independent root revision per directory, branch-labelled
            # with the directory name.
            script.generate_revision(
                rev_id,
                label,
                refresh=True,
                version_path=os.path.join(staging_dir, label),
                head="base",
            )
            write_script(script, rev_id, template % (label, rev_id, label))

    def tearDown(self):
        clear_staging_env()

    def test_env_emits_warning(self):
        duplicate_file = os.path.realpath(
            os.path.join(
                _get_staging_directory(),
                "model1",
                "%s_model1.py" % self.model1,
            )
        )
        msg = (
            "File %s loaded twice! ignoring. "
            "Please ensure version_locations is unique." % (duplicate_file,)
        )
        with assertions.expect_warnings(msg, regex=False):
            script = ScriptDirectory.from_config(self.cfg)
            # Touch the heads; presumably this forces the revision map to
            # load inside the expect_warnings block.
            script.revision_map.heads
            walked = [rev.revision for rev in script.walk_revisions()]
            # Each revision appears once despite the duplicated location.
            eq_(walked, [self.model1, self.model2, self.model3])
1253
1254
1255class NormPathTest(TestBase):
1256    def setUp(self):
1257        self.env = staging_env()
1258
    def tearDown(self):
        """Clear the staging environment created in ``setUp``."""
        clear_staging_env()
1261
    def test_script_location(self):
        """Single default location: version paths go through normpath.

        ``os.path.normpath`` is patched with a fake that replaces every
        ``/`` with ``:NORM:``, so the expected values show whether the
        computed paths were passed through it.  Here the config has the one
        ``scripts/versions`` location, exposed as a 1-tuple.
        """
        config = _no_sql_testing_config()

        script = ScriptDirectory.from_config(config)

        def normpath(path):
            # Fake normpath whose application is visible in the result.
            return path.replace("/", ":NORM:")

        # Rebinds the name: the Mock delegates to the fake above through
        # side_effect.
        normpath = mock.Mock(side_effect=normpath)

        # NOTE(review): the expected values are deliberately computed inside
        # the patch too; os.path.abspath presumably routes through
        # os.path.normpath itself, so moving them out could change them.
        with mock.patch("os.path.normpath", normpath):
            eq_(
                script._version_locations,
                (
                    os.path.abspath(
                        os.path.join(
                            _get_staging_directory(), "scripts", "versions"
                        )
                    ).replace("/", ":NORM:"),
                ),
            )

            eq_(
                script.versions,
                os.path.abspath(
                    os.path.join(
                        _get_staging_directory(), "scripts", "versions"
                    )
                ).replace("/", ":NORM:"),
            )
1292
    def test_script_location_muliple(self):
        """Multiple locations: each of model1/2/3 goes through normpath.

        Same patching approach as the single-location test; with the
        multi-directory fixture, ``_version_locations`` is a list holding
        one normalized path per configured directory, in order.
        """
        config = _multi_dir_testing_config()

        script = ScriptDirectory.from_config(config)

        def normpath(path):
            # Fake normpath whose application is visible in the result.
            return path.replace("/", ":NORM:")

        # Rebinds the name: the Mock delegates to the fake above through
        # side_effect.
        normpath = mock.Mock(side_effect=normpath)

        # NOTE(review): expected values are computed inside the patch on
        # purpose; os.path.abspath presumably calls os.path.normpath too.
        with mock.patch("os.path.normpath", normpath):
            eq_(
                script._version_locations,
                [
                    os.path.abspath(
                        os.path.join(_get_staging_directory(), "model1/")
                    ).replace("/", ":NORM:"),
                    os.path.abspath(
                        os.path.join(_get_staging_directory(), "model2/")
                    ).replace("/", ":NORM:"),
                    os.path.abspath(
                        os.path.join(_get_staging_directory(), "model3/")
                    ).replace("/", ":NORM:"),
                ],
            )
1318