1from sqlalchemy import Boolean
2from sqlalchemy import Column
3from sqlalchemy import DATETIME
4from sqlalchemy import exc
5from sqlalchemy import Float
6from sqlalchemy import func
7from sqlalchemy import inspect
8from sqlalchemy import Integer
9from sqlalchemy import MetaData
10from sqlalchemy import Table
11from sqlalchemy import text
12from sqlalchemy import TIMESTAMP
13
14from alembic import op
15from alembic import util
16from alembic.autogenerate import api
17from alembic.autogenerate import compare
18from alembic.migration import MigrationContext
19from alembic.operations import ops
20from alembic.testing import assert_raises_message
21from alembic.testing import combinations
22from alembic.testing import config
23from alembic.testing.env import clear_staging_env
24from alembic.testing.env import staging_env
25from alembic.testing.fixtures import AlterColRoundTripFixture
26from alembic.testing.fixtures import op_fixture
27from alembic.testing.fixtures import TestBase
28from alembic.util import sqla_compat
29
30
class MySQLOpTest(TestBase):
    """SQL-rendering tests for alembic operations on MySQL and MariaDB.

    Each test builds an ``op_fixture`` for the dialect under test, runs one
    or more ``op.*`` directives against it, and then asserts either on the
    SQL captured by the fixture or on the exception raised while rendering.
    """

    def test_create_table_with_comment(self):
        context = op_fixture("mysql")
        op.create_table(
            "t2",
            Column("c1", Integer, primary_key=True),
            comment="This is a table comment",
        )
        context.assert_contains("COMMENT='This is a table comment'")

    def test_create_table_with_column_comments(self):
        context = op_fixture("mysql")
        op.create_table(
            "t2",
            Column("c1", Integer, primary_key=True, comment="c1 comment"),
            Column("c2", Integer, comment="c2 comment"),
            comment="This is a table comment",
        )

        context.assert_(
            "CREATE TABLE t2 "
            "(c1 INTEGER NOT NULL COMMENT 'c1 comment' AUTO_INCREMENT, "
            "c2 INTEGER COMMENT 'c2 comment', PRIMARY KEY (c1))"
            # NOTE(review): the expected SQL has no space between the
            # closing ")" and the table-level COMMENT option; presumably
            # SQLAlchemy's MySQL compiler emits table options without a
            # leading separator -- confirm upstream before changing this.
            "COMMENT='This is a table comment'"
        )

    def test_add_column_with_comment(self):
        context = op_fixture("mysql")
        op.add_column("t", Column("q", Integer, comment="This is a comment"))
        context.assert_(
            "ALTER TABLE t ADD COLUMN q INTEGER COMMENT 'This is a comment'"
        )

    def test_rename_column(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1", "c1", new_column_name="c2", existing_type=Integer
        )
        context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL")

    def test_rename_column_quotes_needed_one(self):
        context = op_fixture("mysql")
        op.alter_column(
            "MyTable",
            "ColumnOne",
            new_column_name="ColumnTwo",
            existing_type=Integer,
        )
        context.assert_(
            "ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL"
        )

    def test_rename_column_quotes_needed_two(self):
        context = op_fixture("mysql")
        op.alter_column(
            "my table",
            "column one",
            new_column_name="column two",
            existing_type=Integer,
        )
        context.assert_(
            "ALTER TABLE `my table` CHANGE `column one` "
            "`column two` INTEGER NULL"
        )

    def test_rename_column_serv_default(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            new_column_name="c2",
            existing_type=Integer,
            existing_server_default="q",
        )
        context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'")

    def test_rename_column_serv_compiled_default(self):
        # despite the test name, no new_column_name is passed, so the
        # expected output is a plain ALTER COLUMN ... SET DEFAULT rather
        # than a CHANGE statement.
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            existing_type=Integer,
            server_default=func.utc_thing(func.current_timestamp()),
        )
        # this is not a valid MySQL default but the point is to just
        # test SQL expression rendering
        context.assert_(
            "ALTER TABLE t1 ALTER COLUMN c1 "
            "SET DEFAULT utc_thing(CURRENT_TIMESTAMP)"
        )

    def test_rename_column_autoincrement(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            new_column_name="c2",
            existing_type=Integer,
            existing_autoincrement=True,
        )
        context.assert_(
            "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT"
        )

    def test_col_add_autoincrement(self):
        context = op_fixture("mysql")
        op.alter_column("t1", "c1", existing_type=Integer, autoincrement=True)
        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT")

    def test_col_remove_autoincrement(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            existing_type=Integer,
            existing_autoincrement=True,
            autoincrement=False,
        )
        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL")

    def test_col_dont_remove_server_default(self):
        # server_default=False leaves the existing default untouched; with
        # nothing else requested, no SQL at all is expected.
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            existing_type=Integer,
            existing_server_default="1",
            server_default=False,
        )

        context.assert_()

    def test_alter_column_drop_default(self):
        context = op_fixture("mysql")
        op.alter_column("t", "c", existing_type=Integer, server_default=None)
        context.assert_("ALTER TABLE t ALTER COLUMN c DROP DEFAULT")

    def test_alter_column_remove_schematype(self):
        # moving off a Boolean(create_constraint=True, name="ck1") type is
        # expected to emit only the MODIFY, with no statement for "ck1".
        context = op_fixture("mysql")
        op.alter_column(
            "t",
            "c",
            type_=Integer,
            existing_type=Boolean(create_constraint=True, name="ck1"),
            server_default=None,
        )
        context.assert_("ALTER TABLE t MODIFY c INTEGER NULL")

    def test_alter_column_modify_default(self):
        context = op_fixture("mysql")
        # notice we don't need the existing type on this one...
        op.alter_column("t", "c", server_default="1")
        context.assert_("ALTER TABLE t ALTER COLUMN c SET DEFAULT '1'")

    def test_alter_column_modify_datetime_default(self):
        # use CHANGE format when the datatype is DATETIME or TIMESTAMP,
        # as this is needed for a functional default which is what you'd
        # get with a DATETIME/TIMESTAMP.  Will also work in the very unlikely
        # case the default is a fixed timestamp value.
        context = op_fixture("mysql")
        op.alter_column(
            "t",
            "c",
            existing_type=DATETIME(),
            server_default=text("CURRENT_TIMESTAMP"),
        )
        context.assert_(
            "ALTER TABLE t CHANGE c c DATETIME NULL DEFAULT CURRENT_TIMESTAMP"
        )

    def test_alter_column_modify_programmatic_default(self):
        # test issue #736
        # when autogenerate.compare creates the operation object
        # programmatically, the server_default of the op has the full
        # DefaultClause present.   make sure the usual renderer works.
        context = op_fixture("mysql")

        m1 = MetaData()

        autogen_context = api.AutogenContext(context, m1)

        operation = ops.AlterColumnOp("t", "c")
        for fn in (
            compare._compare_nullable,
            compare._compare_type,
            compare._compare_server_default,
        ):
            fn(
                autogen_context,
                operation,
                None,
                "t",
                "c",
                Column("c", Float(), nullable=False, server_default=text("0")),
                Column("c", Float(), nullable=True, default=0),
            )
        op.invoke(operation)
        context.assert_("ALTER TABLE t MODIFY c FLOAT NULL DEFAULT 0")

    def test_col_not_nullable(self):
        context = op_fixture("mysql")
        op.alter_column("t1", "c1", nullable=False, existing_type=Integer)
        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL")

    def test_col_not_nullable_existing_serv_default(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            nullable=False,
            existing_type=Integer,
            existing_server_default="5",
        )
        context.assert_(
            "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT '5'"
        )

    def test_col_nullable(self):
        context = op_fixture("mysql")
        op.alter_column("t1", "c1", nullable=True, existing_type=Integer)
        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL")

    def test_col_multi_alter(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1", "c1", nullable=False, server_default="q", type_=Integer
        )
        context.assert_(
            "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT 'q'"
        )

    def test_alter_column_multi_alter_w_drop_default(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1", "c1", nullable=False, server_default=None, type_=Integer
        )
        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL")

    def test_col_alter_type_required(self):
        op_fixture("mysql")
        assert_raises_message(
            util.CommandError,
            "MySQL CHANGE/MODIFY COLUMN operations require the existing type.",
            op.alter_column,
            "t1",
            "c1",
            nullable=False,
            server_default="q",
        )

    def test_alter_column_add_comment(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            comment="This is a column comment",
            existing_type=Boolean(),
            schema="foo",
        )

        context.assert_(
            "ALTER TABLE foo.t1 MODIFY c1 BOOL NULL "
            "COMMENT 'This is a column comment'"
        )

    def test_alter_column_add_comment_quoting(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            comment="This is a 'column' comment",
            existing_type=Boolean(),
            schema="foo",
        )

        context.assert_(
            "ALTER TABLE foo.t1 MODIFY c1 BOOL NULL "
            "COMMENT 'This is a ''column'' comment'"
        )

    def test_alter_column_drop_comment(self):
        # comment=None drops the comment: the MODIFY is rendered without
        # any COMMENT clause.
        context = op_fixture("mysql")
        op.alter_column(
            "t",
            "c",
            existing_type=Boolean(),
            schema="foo",
            comment=None,
            existing_comment="This is a column comment",
        )

        context.assert_("ALTER TABLE foo.t MODIFY c BOOL NULL")

    def test_alter_column_existing_comment(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            nullable=False,
            existing_comment="existing column comment",
            existing_type=Integer,
        )

        context.assert_(
            "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL "
            "COMMENT 'existing column comment'"
        )

    def test_rename_column_existing_comment(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            new_column_name="newc1",
            existing_nullable=False,
            existing_comment="existing column comment",
            existing_type=Integer,
        )

        context.assert_(
            "ALTER TABLE t1 CHANGE c1 newc1 INTEGER NOT NULL "
            "COMMENT 'existing column comment'"
        )

    def test_alter_column_new_comment_replaces_existing(self):
        context = op_fixture("mysql")
        op.alter_column(
            "t1",
            "c1",
            nullable=False,
            comment="This is a column comment",
            existing_comment="existing column comment",
            existing_type=Integer,
        )

        context.assert_(
            "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL "
            "COMMENT 'This is a column comment'"
        )

    def test_create_table_comment(self):
        # this is handled by SQLAlchemy's compilers
        context = op_fixture("mysql")
        op.create_table_comment("t2", comment="t2 table", schema="foo")
        context.assert_("ALTER TABLE foo.t2 COMMENT 't2 table'")

    def test_drop_table_comment(self):
        # this is handled by SQLAlchemy's compilers
        context = op_fixture("mysql")
        op.drop_table_comment("t2", existing_comment="t2 table", schema="foo")
        context.assert_("ALTER TABLE foo.t2 COMMENT ''")

    @config.requirements.computed_columns_api
    def test_add_column_computed(self):
        context = op_fixture("mysql")
        op.add_column(
            "t1",
            Column("some_column", Integer, sqla_compat.Computed("foo * 5")),
        )
        context.assert_(
            "ALTER TABLE t1 ADD COLUMN some_column "
            "INTEGER GENERATED ALWAYS AS (foo * 5)"
        )

    def test_drop_fk(self):
        context = op_fixture("mysql")
        op.drop_constraint("f1", "t1", "foreignkey")
        context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1")

    def test_drop_fk_quoted(self):
        context = op_fixture("mysql")
        op.drop_constraint("MyFk", "MyTable", "foreignkey")
        context.assert_("ALTER TABLE `MyTable` DROP FOREIGN KEY `MyFk`")

    def test_drop_constraint_primary(self):
        context = op_fixture("mysql")
        op.drop_constraint("primary", "t1", type_="primary")
        context.assert_("ALTER TABLE t1 DROP PRIMARY KEY")

    def test_drop_unique(self):
        context = op_fixture("mysql")
        op.drop_constraint("f1", "t1", "unique")
        context.assert_("ALTER TABLE t1 DROP INDEX f1")

    def test_drop_unique_quoted(self):
        context = op_fixture("mysql")
        op.drop_constraint("MyUnique", "MyTable", "unique")
        context.assert_("ALTER TABLE `MyTable` DROP INDEX `MyUnique`")

    def test_drop_check_mariadb(self):
        context = op_fixture("mariadb")
        op.drop_constraint("f1", "t1", "check")
        context.assert_("ALTER TABLE t1 DROP CONSTRAINT f1")

    def test_drop_check_quoted_mariadb(self):
        context = op_fixture("mariadb")
        op.drop_constraint("MyCheck", "MyTable", "check")
        context.assert_("ALTER TABLE `MyTable` DROP CONSTRAINT `MyCheck`")

    def test_drop_check_mysql(self):
        context = op_fixture("mysql")
        op.drop_constraint("f1", "t1", "check")
        context.assert_("ALTER TABLE t1 DROP CHECK f1")

    def test_drop_check_quoted_mysql(self):
        context = op_fixture("mysql")
        op.drop_constraint("MyCheck", "MyTable", "check")
        context.assert_("ALTER TABLE `MyTable` DROP CHECK `MyCheck`")

    def test_drop_unknown(self):
        op_fixture("mysql")
        assert_raises_message(
            TypeError,
            "'type' can be one of 'check', 'foreignkey', "
            "'primary', 'unique', None",
            op.drop_constraint,
            "f1",
            "t1",
            "typo",
        )

    def test_drop_generic_constraint(self):
        # MySQL needs the constraint type to choose the DROP syntax, so
        # omitting type_ is expected to raise.
        op_fixture("mysql")
        assert_raises_message(
            NotImplementedError,
            "No generic 'DROP CONSTRAINT' in MySQL - please "
            "specify constraint type",
            op.drop_constraint,
            "f1",
            "t1",
        )

    @combinations(
        (lambda: sqla_compat.Computed("foo * 5"), lambda: None),
        (lambda: None, lambda: sqla_compat.Computed("foo * 5")),
        (
            lambda: sqla_compat.Computed("foo * 42"),
            lambda: sqla_compat.Computed("foo * 5"),
        ),
    )
    @config.requirements.computed_columns_api
    def test_alter_column_computed_not_supported(self, sd, esd):
        # use the MySQL dialect, like the rest of this class, so the
        # MySQL alter_column path is the one being exercised.
        op_fixture("mysql")
        assert_raises_message(
            exc.CompileError,
            'Adding or removing a "computed" construct, e.g. '
            "GENERATED ALWAYS AS, to or from an existing column is not "
            "supported.",
            op.alter_column,
            "t1",
            "c1",
            server_default=sd(),
            existing_server_default=esd(),
        )

    @combinations(
        (lambda: sqla_compat.Identity(), lambda: None),
        (lambda: None, lambda: sqla_compat.Identity()),
        (
            lambda: sqla_compat.Identity(),
            lambda: sqla_compat.Identity(),
        ),
    )
    @config.requirements.identity_columns_api
    def test_alter_column_identity_not_supported(self, sd, esd):
        # use the MySQL dialect, like the rest of this class, rather than
        # the generic default dialect.
        op_fixture("mysql")
        assert_raises_message(
            exc.CompileError,
            'Adding, removing or modifying an "identity" construct, '
            "e.g. GENERATED AS IDENTITY, to or from an existing "
            "column is not supported in this dialect.",
            op.alter_column,
            "t1",
            "c1",
            server_default=sd(),
            existing_server_default=esd(),
        )
510
511
class MySQLBackendOpTest(AlterColRoundTripFixture, TestBase):
    """Backend round trips that add server defaults to date/time columns.

    Every test hands a TIMESTAMP or DATETIME column type and a functional
    server default to ``_run_alter_col``; some also pass a ``compare``
    default for the reflected side.
    """

    __only_on__ = "mysql", "mariadb"
    __backend__ = True

    # SQLAlchemy reflection bundles the ON UPDATE part into the server
    # default reflection, see
    # https://github.com/sqlalchemy/sqlalchemy/issues/4652
    _BUNDLED_ONUPDATE = "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"

    def _roundtrip_server_default(self, type_, default_sql, compare_sql=None):
        """Call ``_run_alter_col`` giving ``type_`` a ``text()`` default.

        :param type_: the column type instance.
        :param default_sql: SQL text of the new server default.
        :param compare_sql: optional SQL text; when given it is passed to
          ``_run_alter_col`` as ``compare={"server_default": text(...)}``
          (presumably the spelling the backend reflects the default as).
        """
        new_default = {"server_default": text(default_sql)}
        if compare_sql is None:
            self._run_alter_col({"type": type_}, new_default)
            return
        self._run_alter_col(
            {"type": type_},
            new_default,
            compare={"server_default": text(compare_sql)},
        )

    def test_add_timestamp_server_default_current_timestamp(self):
        self._roundtrip_server_default(TIMESTAMP(), "CURRENT_TIMESTAMP")

    def test_add_datetime_server_default_current_timestamp(self):
        self._roundtrip_server_default(DATETIME(), "CURRENT_TIMESTAMP")

    def test_add_timestamp_server_default_now(self):
        self._roundtrip_server_default(
            TIMESTAMP(), "NOW()", compare_sql="CURRENT_TIMESTAMP"
        )

    def test_add_datetime_server_default_now(self):
        self._roundtrip_server_default(
            DATETIME(), "NOW()", compare_sql="CURRENT_TIMESTAMP"
        )

    def test_add_timestamp_server_default_current_timestamp_bundle_onupdate(
        self,
    ):
        self._roundtrip_server_default(TIMESTAMP(), self._BUNDLED_ONUPDATE)

    def test_add_datetime_server_default_current_timestamp_bundle_onupdate(
        self,
    ):
        self._roundtrip_server_default(DATETIME(), self._BUNDLED_ONUPDATE)
570
571
class MySQLDefaultCompareTest(TestBase):
    """Backend tests for server-default comparison on MySQL and MariaDB.

    Each test creates a one-column table named ``test`` on the live
    database, reflects it, and asks the migration context's impl
    (``compare_server_default``) whether the reflected default differs
    from that of a second, in-memory definition of the same column.
    """

    __only_on__ = "mysql", "mariadb"
    __backend__ = True

    @classmethod
    def setup_class(cls):
        cls.bind = config.db
        staging_env()
        # Keep a reference to the connection so teardown_class can close
        # it; it was previously opened here and never released.  A private
        # name is used so we don't shadow anything TestBase might define
        # (such as a ``connection`` fixture).
        cls._compare_connection = cls.bind.connect()
        context = MigrationContext.configure(
            connection=cls._compare_connection,
            opts={"compare_type": True, "compare_server_default": True},
        )
        connection = context.bind
        cls.autogen_context = {
            "imports": set(),
            "connection": connection,
            "dialect": connection.dialect,
            "context": context,
        }

    @classmethod
    def teardown_class(cls):
        try:
            conn = getattr(cls, "_compare_connection", None)
            if conn is not None:
                conn.close()
        finally:
            # always clean up the staging environment, even if closing the
            # connection fails
            clear_staging_env()

    def setUp(self):
        self.metadata = MetaData()

    def tearDown(self):
        with config.db.begin() as conn:
            self.metadata.drop_all(conn)

    def _compare_default_roundtrip(self, type_, txt, alternate=None):
        """Create a column with default ``txt`` and compare it to ``alternate``.

        :param type_: column type used for both table definitions.
        :param txt: SQL text of the server default for the table created
          on the database, or None for no server default.
        :param alternate: SQL text of the server default to compare
          against.  When given, a difference is expected; when omitted,
          ``txt`` is reused and no difference is expected.
        """
        if alternate:
            expected = True
        else:
            alternate = txt
            expected = False
        t = Table(
            "test",
            self.metadata,
            Column(
                "somecol", type_, server_default=text(txt) if txt else None
            ),
        )
        t2 = Table(
            "test",
            MetaData(),
            Column(
                "somecol",
                type_,
                # guard against text(None) when neither default is given
                server_default=text(alternate) if alternate else None,
            ),
        )
        assert (
            self._compare_default(t, t2, t2.c.somecol, alternate) is expected
        )

    def _compare_default(self, t1, t2, col, rendered):
        """Create ``t1`` on the database, reflect it, and compare defaults.

        Returns whatever the impl's ``compare_server_default`` returns for
        the reflected first column versus ``col``; the roundtrip helper
        treats ``True`` as "defaults differ".  ``t2`` is accepted but not
        used here.
        """
        t1.create(self.bind)
        insp = inspect(self.bind)
        cols = insp.get_columns(t1.name)
        refl = Table(t1.name, MetaData())
        sqla_compat._reflect_table(insp, refl, None)
        ctx = self.autogen_context["context"]
        return ctx.impl.compare_server_default(
            refl.c[cols[0]["name"]], col, rendered, cols[0]["default"]
        )

    def test_compare_timestamp_current_timestamp(self):
        self._compare_default_roundtrip(TIMESTAMP(), "CURRENT_TIMESTAMP")

    def test_compare_timestamp_current_timestamp_diff(self):
        self._compare_default_roundtrip(TIMESTAMP(), None, "CURRENT_TIMESTAMP")

    def test_compare_timestamp_current_timestamp_bundle_onupdate(self):
        self._compare_default_roundtrip(
            TIMESTAMP(), "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
        )

    def test_compare_timestamp_current_timestamp_diff_bundle_onupdate(self):
        self._compare_default_roundtrip(
            TIMESTAMP(), None, "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
        )

    def test_compare_integer_from_none(self):
        self._compare_default_roundtrip(Integer(), None, "0")

    def test_compare_integer_same(self):
        self._compare_default_roundtrip(Integer(), "5")

    def test_compare_integer_diff(self):
        self._compare_default_roundtrip(Integer(), "5", "7")

    def test_compare_boolean_same(self):
        self._compare_default_roundtrip(Boolean(), "1")

    def test_compare_boolean_diff(self):
        self._compare_default_roundtrip(Boolean(), "1", "0")
666