1from sqlalchemy import BigInteger
2from sqlalchemy import Boolean
3from sqlalchemy import Column
4from sqlalchemy import DateTime
5from sqlalchemy import Float
6from sqlalchemy import func
7from sqlalchemy import Index
8from sqlalchemy import inspect
9from sqlalchemy import Integer
10from sqlalchemy import Interval
11from sqlalchemy import MetaData
12from sqlalchemy import Numeric
13from sqlalchemy import Sequence
14from sqlalchemy import String
15from sqlalchemy import Table
16from sqlalchemy import text
17from sqlalchemy import types
18from sqlalchemy.dialects.postgresql import ARRAY
19from sqlalchemy.dialects.postgresql import BYTEA
20from sqlalchemy.dialects.postgresql import HSTORE
21from sqlalchemy.dialects.postgresql import JSON
22from sqlalchemy.dialects.postgresql import JSONB
23from sqlalchemy.dialects.postgresql import UUID
24from sqlalchemy.sql import column
25from sqlalchemy.sql import false
26from sqlalchemy.sql import table
27
28from alembic import autogenerate
29from alembic import command
30from alembic import op
31from alembic import util
32from alembic.autogenerate import api
33from alembic.autogenerate.compare import _compare_server_default
34from alembic.autogenerate.compare import _compare_tables
35from alembic.autogenerate.compare import _render_server_default_for_compare
36from alembic.migration import MigrationContext
37from alembic.operations import Operations
38from alembic.operations import ops
39from alembic.script import ScriptDirectory
40from alembic.testing import config
41from alembic.testing import eq_
42from alembic.testing import eq_ignore_whitespace
43from alembic.testing import provide_metadata
44from alembic.testing.env import _no_sql_testing_config
45from alembic.testing.env import clear_staging_env
46from alembic.testing.env import staging_env
47from alembic.testing.env import write_script
48from alembic.testing.fixtures import capture_context_buffer
49from alembic.testing.fixtures import op_fixture
50from alembic.testing.fixtures import TestBase
51from alembic.util import sqla_compat
52
53
class PostgresqlOpTest(TestBase):
    """DDL rendering of ``op.*`` directives for the PostgreSQL dialect.

    Each test builds an ``op_fixture("postgresql")`` context, invokes one
    operation and compares the SQL captured by the fixture.
    """

    def test_rename_table_postgresql(self):
        ctx = op_fixture("postgresql")
        op.rename_table("t1", "t2")
        ctx.assert_("ALTER TABLE t1 RENAME TO t2")

    def test_rename_table_schema_postgresql(self):
        ctx = op_fixture("postgresql")
        op.rename_table("t1", "t2", schema="foo")
        ctx.assert_("ALTER TABLE foo.t1 RENAME TO t2")

    def test_create_index_postgresql_expressions(self):
        ctx = op_fixture("postgresql")
        where_clause = text("locations.coordinates != Null")
        op.create_index(
            "geocoded",
            "locations",
            [text("lower(coordinates)")],
            postgresql_where=where_clause,
        )
        ctx.assert_(
            "CREATE INDEX geocoded ON locations (lower(coordinates)) "
            "WHERE locations.coordinates != Null"
        )

    def test_create_index_postgresql_where(self):
        ctx = op_fixture("postgresql")
        where_clause = text("locations.coordinates != Null")
        op.create_index(
            "geocoded",
            "locations",
            ["coordinates"],
            postgresql_where=where_clause,
        )
        ctx.assert_(
            "CREATE INDEX geocoded ON locations (coordinates) "
            "WHERE locations.coordinates != Null"
        )

    def test_create_index_postgresql_concurrently(self):
        ctx = op_fixture("postgresql")
        op.create_index(
            "geocoded", "locations", ["coordinates"], postgresql_concurrently=True
        )
        expected = "CREATE INDEX CONCURRENTLY geocoded ON locations (coordinates)"
        ctx.assert_(expected)

    def test_drop_index_postgresql_concurrently(self):
        ctx = op_fixture("postgresql")
        op.drop_index("geocoded", "locations", postgresql_concurrently=True)
        ctx.assert_("DROP INDEX CONCURRENTLY geocoded")

    def test_alter_column_type_using(self):
        ctx = op_fixture("postgresql")
        op.alter_column("t", "c", type_=Integer, postgresql_using="c::integer")
        expected = "ALTER TABLE t ALTER COLUMN c TYPE INTEGER USING c::integer"
        ctx.assert_(expected)

    def test_col_w_pk_is_serial(self):
        ctx = op_fixture("postgresql")
        pk_col = Column("q", Integer, primary_key=True)
        op.add_column("some_table", pk_col)
        ctx.assert_("ALTER TABLE some_table ADD COLUMN q SERIAL NOT NULL")

    def test_create_exclude_constraint(self):
        ctx = op_fixture("postgresql")
        op.create_exclude_constraint(
            "ex1", "t1", ("x", ">"), where="x > 5", using="gist"
        )
        ctx.assert_(
            "ALTER TABLE t1 ADD CONSTRAINT ex1 EXCLUDE USING gist (x WITH >) "
            "WHERE (x > 5)"
        )

    def test_create_exclude_constraint_quoted_literal(self):
        ctx = op_fixture("postgresql")
        element = (column("SomeColumn"), ">")
        op.create_exclude_constraint(
            "ex1",
            "SomeTable",
            element,
            where='"SomeColumn" > 5',
            using="gist",
        )
        ctx.assert_(
            'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE USING gist '
            '("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)'
        )

    def test_create_exclude_constraint_quoted_column(self):
        ctx = op_fixture("postgresql")
        element = (column("SomeColumn"), ">")
        op.create_exclude_constraint(
            "ex1",
            "SomeTable",
            element,
            where=column("SomeColumn") > 5,
            using="gist",
        )
        ctx.assert_(
            'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE '
            'USING gist ("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)'
        )

    @config.requirements.comments_api
    def test_add_column_with_comment(self):
        ctx = op_fixture("postgresql")
        commented = Column("q", Integer, comment="This is a comment")
        op.add_column("t", commented)
        ctx.assert_(
            "ALTER TABLE t ADD COLUMN q INTEGER",
            "COMMENT ON COLUMN t.q IS 'This is a comment'",
        )

    @config.requirements.comments_api
    def test_alter_column_with_comment(self):
        ctx = op_fixture("postgresql")
        op.alter_column(
            "t",
            "c",
            nullable=False,
            existing_type=Boolean(),
            schema="foo",
            comment="This is a column comment",
        )
        ctx.assert_(
            "ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL",
            "COMMENT ON COLUMN foo.t.c IS 'This is a column comment'",
        )

    @config.requirements.comments_api
    def test_alter_column_add_comment(self):
        ctx = op_fixture("postgresql")
        op.alter_column(
            "t",
            "c",
            existing_type=Boolean(),
            schema="foo",
            comment="This is a column comment",
        )
        ctx.assert_("COMMENT ON COLUMN foo.t.c IS 'This is a column comment'")

    @config.requirements.comments_api
    def test_alter_column_add_comment_table_and_column_quoting(self):
        ctx = op_fixture("postgresql")
        op.alter_column(
            "T",
            "C",
            existing_type=Boolean(),
            schema="foo",
            comment="This is a column comment",
        )
        ctx.assert_(
            'COMMENT ON COLUMN foo."T"."C" IS \'This is a column comment\''
        )

    @config.requirements.comments_api
    def test_alter_column_add_comment_quoting(self):
        ctx = op_fixture("postgresql")
        op.alter_column(
            "t",
            "c",
            existing_type=Boolean(),
            schema="foo",
            comment="This is a column 'comment'",
        )
        # embedded single quotes must come out doubled
        ctx.assert_(
            "COMMENT ON COLUMN foo.t.c IS 'This is a column ''comment'''"
        )

    @config.requirements.comments_api
    def test_alter_column_drop_comment(self):
        ctx = op_fixture("postgresql")
        op.alter_column(
            "t",
            "c",
            existing_type=Boolean(),
            schema="foo",
            comment=None,
            existing_comment="This is a column comment",
        )
        ctx.assert_("COMMENT ON COLUMN foo.t.c IS NULL")

    @config.requirements.comments_api
    def test_create_table_with_comment(self):
        ctx = op_fixture("postgresql")
        op.create_table(
            "t2",
            Column("c1", Integer, primary_key=True),
            Column("c2", Integer),
            comment="t2 comment",
        )
        ctx.assert_(
            "CREATE TABLE t2 (c1 SERIAL NOT NULL, "
            "c2 INTEGER, PRIMARY KEY (c1))",
            "COMMENT ON TABLE t2 IS 't2 comment'",
        )

    @config.requirements.comments_api
    def test_create_table_with_column_comments(self):
        ctx = op_fixture("postgresql")
        op.create_table(
            "t2",
            Column("c1", Integer, primary_key=True, comment="c1 comment"),
            Column("c2", Integer, comment="c2 comment"),
            comment="t2 comment",
        )
        ctx.assert_(
            "CREATE TABLE t2 (c1 SERIAL NOT NULL, "
            "c2 INTEGER, PRIMARY KEY (c1))",
            "COMMENT ON TABLE t2 IS 't2 comment'",
            "COMMENT ON COLUMN t2.c1 IS 'c1 comment'",
            "COMMENT ON COLUMN t2.c2 IS 'c2 comment'",
        )

    @config.requirements.comments_api
    def test_create_table_comment(self):
        # rendering comes from SQLAlchemy's own DDL compiler
        ctx = op_fixture("postgresql")
        op.create_table_comment("t2", comment="t2 table", schema="foo")
        ctx.assert_("COMMENT ON TABLE foo.t2 IS 't2 table'")

    @config.requirements.comments_api
    def test_drop_table_comment(self):
        # rendering comes from SQLAlchemy's own DDL compiler
        ctx = op_fixture("postgresql")
        op.drop_table_comment("t2", existing_comment="t2 table", schema="foo")
        ctx.assert_("COMMENT ON TABLE foo.t2 IS NULL")

    @config.requirements.computed_columns_api
    def test_add_column_computed(self):
        ctx = op_fixture("postgresql")
        computed_col = Column(
            "some_column", Integer, sqla_compat.Computed("foo * 5")
        )
        op.add_column("t1", computed_col)
        ctx.assert_(
            "ALTER TABLE t1 ADD COLUMN some_column "
            "INTEGER GENERATED ALWAYS AS (foo * 5) STORED"
        )
300
301
class PGAutocommitBlockTest(TestBase):
    """Run a statement inside ``MigrationContext.autocommit_block()`` nested
    in a per-migration transaction, against a live PostgreSQL backend.

    The statement used, ``ALTER TYPE ... ADD VALUE``, is one PostgreSQL may
    refuse to run inside a transaction block (depending on server version),
    which is why it goes through the autocommit block.
    """

    __only_on__ = "postgresql"
    __backend__ = True

    def setUp(self):
        self.conn = conn = config.db.connect()

        with conn.begin():
            conn.execute(
                text("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')")
            )

    def tearDown(self):
        # Drop the enum created in setUp, and always release the connection
        # afterwards; previously the connection was never closed.
        try:
            with self.conn.begin():
                self.conn.execute(text("DROP TYPE mood"))
        finally:
            self.conn.close()

    def test_alter_enum(self):
        context = MigrationContext.configure(connection=self.conn)
        with context.begin_transaction(_per_migration=True):
            with context.autocommit_block():
                context.execute(text("ALTER TYPE mood ADD VALUE 'soso'"))
323
324
class PGOfflineEnumTest(TestBase):
    """Offline (``sql=True``) upgrade/downgrade output for a PostgreSQL ENUM.

    Two revision scripts are compared: one where the ENUM is declared inline
    in ``op.create_table`` and one where it is created and dropped
    explicitly.
    """

    def setUp(self):
        staging_env()
        self.cfg = _no_sql_testing_config()
        self.rid = util.rev_id()
        self.script = ScriptDirectory.from_config(self.cfg)
        self.script.generate_revision(self.rid, None, refresh=True)

    def tearDown(self):
        clear_staging_env()

    def _write_revision(self, template):
        # each template carries one "%s" placeholder for the revision id
        write_script(self.script, self.rid, template % self.rid)

    def _inline_enum_script(self):
        self._write_revision(
            """
revision = '%s'
down_revision = None

from alembic import op
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy import Column


def upgrade():
    op.create_table("sometable",
        Column("data", ENUM("one", "two", "three", name="pgenum"))
    )


def downgrade():
    op.drop_table("sometable")
"""
        )

    def _distinct_enum_script(self):
        self._write_revision(
            """
revision = '%s'
down_revision = None

from alembic import op
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy import Column


def upgrade():
    enum = ENUM("one", "two", "three", name="pgenum", create_type=False)
    enum.create(op.get_bind(), checkfirst=False)
    op.create_table("sometable",
        Column("data", enum)
    )


def downgrade():
    op.drop_table("sometable")
    ENUM(name="pgenum").drop(op.get_bind(), checkfirst=False)

"""
        )

    def test_offline_inline_enum_create(self):
        self._inline_enum_script()
        with capture_context_buffer() as buf:
            command.upgrade(self.cfg, self.rid, sql=True)
        emitted = buf.getvalue()
        assert (
            "CREATE TYPE pgenum AS ENUM ('one', 'two', 'three')" in emitted
        )
        assert "CREATE TABLE sometable (\n    data pgenum\n)" in emitted

    def test_offline_inline_enum_drop(self):
        self._inline_enum_script()
        with capture_context_buffer() as buf:
            command.downgrade(self.cfg, "%s:base" % self.rid, sql=True)
        emitted = buf.getvalue()
        assert "DROP TABLE sometable" in emitted
        # the inline ENUM gets no DROP TYPE, as no events are emitted here
        assert "DROP TYPE pgenum" not in emitted

    def test_offline_distinct_enum_create(self):
        self._distinct_enum_script()
        with capture_context_buffer() as buf:
            command.upgrade(self.cfg, self.rid, sql=True)
        emitted = buf.getvalue()
        assert (
            "CREATE TYPE pgenum AS ENUM ('one', 'two', 'three')" in emitted
        )
        assert "CREATE TABLE sometable (\n    data pgenum\n)" in emitted

    def test_offline_distinct_enum_drop(self):
        self._distinct_enum_script()
        with capture_context_buffer() as buf:
            command.downgrade(self.cfg, "%s:base" % self.rid, sql=True)
        emitted = buf.getvalue()
        assert "DROP TABLE sometable" in emitted
        assert "DROP TYPE pgenum" in emitted
426
427
class PostgresqlInlineLiteralTest(TestBase):
    """Execute an UPDATE built from ``op.inline_literal()`` values that
    contain ``%`` against a live PostgreSQL table."""

    __only_on__ = "postgresql"
    __backend__ = True

    @classmethod
    def setup_class(cls):
        cls.bind = config.db
        with config.db.connect() as connection:
            for statement in (
                """
                create table tab (
                    col varchar(50)
                )
            """,
                """
                insert into tab (col) values
                    ('old data 1'),
                    ('old data 2.1'),
                    ('old data 3')
            """,
            ):
                connection.execute(text(statement))

    @classmethod
    def teardown_class(cls):
        with cls.bind.connect() as connection:
            connection.execute(text("drop table tab"))

    def setUp(self):
        self.conn = self.bind.connect()
        self.op = Operations(MigrationContext.configure(self.conn))

    def tearDown(self):
        self.conn.close()

    def test_inline_percent(self):
        # TODO: the LIKE pattern "%.%" holds raw percent signs that would
        # otherwise need escaping.  With no_parameters=True the statement is
        # sent to the DBAPI without a parameter collection, so the driver
        # should not treat "%" as a format marker.
        tab = table("tab", column("col"))
        stmt = (
            tab.update()
            .where(tab.c.col.like(self.op.inline_literal("%.%")))
            .values(col=self.op.inline_literal("new data"))
        )
        self.op.execute(stmt, execution_options={"no_parameters": True})

        # only 'old data 2.1' contains a "." and should have been updated
        updated_count = self.conn.execute(
            text("select count(*) from tab where col='new data'")
        ).scalar()
        eq_(updated_count, 1)
484
485
class PostgresqlDefaultCompareTest(TestBase):
    """Compare reflected PostgreSQL server defaults against metadata server
    defaults using alembic's autogenerate comparison functions."""

    __only_on__ = "postgresql"
    __backend__ = True

    @classmethod
    def setup_class(cls):
        cls.bind = config.db
        staging_env()
        # Keep a handle on the connection backing the migration context so
        # teardown_class can close it; it used to be opened anonymously and
        # never released.
        cls.conn = cls.bind.connect()
        cls.migration_context = MigrationContext.configure(
            connection=cls.conn,
            opts={"compare_type": True, "compare_server_default": True},
        )

    def setUp(self):
        self.metadata = MetaData(self.bind)
        self.autogen_context = api.AutogenContext(self.migration_context)

    @classmethod
    def teardown_class(cls):
        try:
            cls.conn.close()
        finally:
            clear_staging_env()

    def tearDown(self):
        self.metadata.drop_all()

    def _compare_default_roundtrip(
        self, type_, orig_default, alternate=None, diff_expected=None
    ):
        """Create table ``test`` with ``orig_default``, reflect it, and
        assert whether comparing it with ``alternate`` yields a diff.

        :param type_: column type used for both the created and the
          comparison table.
        :param orig_default: server default of the table actually created.
        :param alternate: server default of the in-memory comparison table;
          falls back to ``orig_default`` when omitted.
        :param diff_expected: whether a diff is expected; when omitted it is
          ``True`` exactly when ``alternate`` was given.
        """
        diff_expected = (
            diff_expected
            if diff_expected is not None
            else alternate is not None
        )
        if alternate is None:
            alternate = orig_default

        t1 = Table(
            "test",
            self.metadata,
            Column("somecol", type_, server_default=orig_default),
        )
        t2 = Table(
            "test",
            MetaData(),
            Column("somecol", type_, server_default=alternate),
        )

        t1.create(self.bind)

        insp = inspect(self.bind)
        cols = insp.get_columns(t1.name)
        # wrap the reflected default text as a server default so it can be
        # compared like a metadata column
        insp_col = Column(
            "somecol", cols[0]["type"], server_default=text(cols[0]["default"])
        )
        op = ops.AlterColumnOp("test", "somecol")
        _compare_server_default(
            self.autogen_context,
            op,
            None,
            "test",
            "somecol",
            insp_col,
            t2.c.somecol,
        )

        diffs = op.to_diff_tuple()
        eq_(bool(diffs), diff_expected)

    def _compare_default(self, t1, t2, col, rendered):
        """Create ``t1`` if missing and return the dialect impl's
        ``compare_server_default`` result for ``col`` / ``rendered`` against
        the reflected default of ``t1``'s first column.

        ``t2`` is accepted for call-site symmetry but is not used.
        """
        t1.create(self.bind, checkfirst=True)
        insp = inspect(self.bind)
        cols = insp.get_columns(t1.name)
        ctx = self.autogen_context.migration_context

        return ctx.impl.compare_server_default(
            None, col, rendered, cols[0]["default"]
        )

    def test_compare_string_blank_default(self):
        self._compare_default_roundtrip(String(8), "")

    def test_compare_string_nonblank_default(self):
        self._compare_default_roundtrip(String(8), "hi")

    def test_compare_interval_str(self):
        # this form shouldn't be used but testing here
        # for compatibility
        self._compare_default_roundtrip(Interval, "14 days")

    @config.requirements.postgresql_uuid_ossp
    def test_compare_uuid_text(self):
        self._compare_default_roundtrip(UUID, text("uuid_generate_v4()"))

    def test_compare_interval_text(self):
        self._compare_default_roundtrip(Interval, text("'14 days'"))

    def test_compare_array_of_integer_text(self):
        self._compare_default_roundtrip(
            ARRAY(Integer), text("(ARRAY[]::integer[])")
        )

    def test_compare_current_timestamp_text(self):
        self._compare_default_roundtrip(
            DateTime(), text("TIMEZONE('utc', CURRENT_TIMESTAMP)")
        )

    def test_compare_current_timestamp_fn_w_binds(self):
        self._compare_default_roundtrip(
            DateTime(), func.timezone("utc", func.current_timestamp())
        )

    def test_compare_integer_str(self):
        self._compare_default_roundtrip(Integer(), "5")

    def test_compare_integer_text(self):
        self._compare_default_roundtrip(Integer(), text("5"))

    def test_compare_integer_text_diff(self):
        self._compare_default_roundtrip(Integer(), text("5"), "7")

    def test_compare_float_str(self):
        self._compare_default_roundtrip(Float(), "5.2")

    def test_compare_float_text(self):
        self._compare_default_roundtrip(Float(), text("5.2"))

    def test_compare_float_no_diff1(self):
        self._compare_default_roundtrip(
            Float(), text("5.2"), "5.2", diff_expected=False
        )

    def test_compare_float_no_diff2(self):
        self._compare_default_roundtrip(
            Float(), "5.2", text("5.2"), diff_expected=False
        )

    def test_compare_float_no_diff3(self):
        self._compare_default_roundtrip(
            Float(), text("5"), text("5.0"), diff_expected=False
        )

    def test_compare_float_no_diff4(self):
        self._compare_default_roundtrip(
            Float(), "5", "5.0", diff_expected=False
        )

    def test_compare_float_no_diff5(self):
        self._compare_default_roundtrip(
            Float(), text("5"), "5.0", diff_expected=False
        )

    def test_compare_float_no_diff6(self):
        self._compare_default_roundtrip(
            Float(), "5", text("5.0"), diff_expected=False
        )

    def test_compare_numeric_no_diff(self):
        self._compare_default_roundtrip(
            Numeric(), text("5"), "5.0", diff_expected=False
        )

    def test_compare_unicode_literal(self):
        self._compare_default_roundtrip(String(), u"im a default")

    # TODO: will need to actually eval() the repr() and
    # spend more effort figuring out exactly the kind of expression
    # to use
    def _TODO_test_compare_character_str_w_singlequote(self):
        self._compare_default_roundtrip(String(), "hel''lo")

    def test_compare_character_str(self):
        self._compare_default_roundtrip(String(), "hello")

    def test_compare_character_text(self):
        self._compare_default_roundtrip(String(), text("'hello'"))

    def test_compare_character_str_diff(self):
        self._compare_default_roundtrip(String(), "hello", "there")

    def test_compare_character_text_diff(self):
        self._compare_default_roundtrip(
            String(), text("'hello'"), text("'there'")
        )

    def test_primary_key_skip(self):
        """Test that SERIAL cols are just skipped"""
        t1 = Table(
            "sometable", self.metadata, Column("id", Integer, primary_key=True)
        )
        t2 = Table(
            "sometable", MetaData(), Column("id", Integer, primary_key=True)
        )
        assert not self._compare_default(t1, t2, t2.c.id, "")
678
679
class PostgresqlDetectSerialTest(TestBase):
    """Check how the reflected server default of a (possibly sequence-backed)
    integer/numeric column is rendered for autogenerate comparison."""

    __only_on__ = "postgresql"
    __backend__ = True

    @classmethod
    def setup_class(cls):
        cls.bind = config.db
        staging_env()

    @classmethod
    def teardown_class(cls):
        clear_staging_env()

    def setUp(self):
        self.conn = self.bind.connect()
        self.migration_context = MigrationContext.configure(
            connection=self.conn,
            opts={"compare_type": True, "compare_server_default": True},
        )
        self.autogen_context = api.AutogenContext(self.migration_context)

    def tearDown(self):
        self.conn.close()

    @provide_metadata
    def _expect_default(self, c_expected, col, seq=None):
        """Create table ``t`` holding ``col`` and assert its reflected
        server default renders as ``c_expected``.

        The check is made twice: on the table obtained from a comparison
        where ``t`` appears only in the first name set, and on the
        ``existing_server_default`` reported when ``t`` is compared against
        a metadata table whose column ``x`` is ``BigInteger``.

        :param seq: optional Sequence attached to the test metadata before
          the table is created.
        """
        Table("t", self.metadata, col)
        self.autogen_context.metadata = self.metadata
        if seq:
            seq._set_metadata(self.metadata)
        self.metadata.create_all(config.db)

        # first pass: "t" is listed only in the first set of table names
        first_ops = ops.UpgradeOps(ops=[])
        _compare_tables(
            {(None, "t")},
            set(),
            inspect(config.db),
            first_ops,
            self.autogen_context,
        )
        reflected = first_ops.as_diffs()[0][1]
        eq_(
            _render_server_default_for_compare(
                reflected.c.x.server_default,
                reflected.c.x,
                self.autogen_context,
            ),
            c_expected,
        )

        # second pass: "t" in both sets, compared with a BigInteger column
        second_ops = ops.UpgradeOps(ops=[])
        other_metadata = MetaData()
        Table("t", other_metadata, Column("x", BigInteger()))
        self.autogen_context.metadata = other_metadata
        _compare_tables(
            {(None, "t")},
            {(None, "t")},
            inspect(config.db),
            second_ops,
            self.autogen_context,
        )
        existing = second_ops.as_diffs()[0][0][4]["existing_server_default"]
        eq_(
            _render_server_default_for_compare(
                existing, reflected.c.x, self.autogen_context
            ),
            c_expected,
        )

    def test_serial(self):
        self._expect_default(None, Column("x", Integer, primary_key=True))

    def test_separate_seq(self):
        seq = Sequence("x_id_seq")
        col = Column(
            "x", Integer, server_default=seq.next_value(), primary_key=True
        )
        self._expect_default("nextval('x_id_seq'::regclass)", col, seq)

    def test_numeric(self):
        seq = Sequence("x_id_seq")
        col = Column(
            "x",
            Numeric(8, 2),
            server_default=seq.next_value(),
            primary_key=True,
        )
        self._expect_default("nextval('x_id_seq'::regclass)", col, seq)

    def test_no_default(self):
        col = Column("x", Integer, autoincrement=False, primary_key=True)
        self._expect_default(None, col)
781
782
783class PostgresqlAutogenRenderTest(TestBase):
784    def setUp(self):
785        ctx_opts = {
786            "sqlalchemy_module_prefix": "sa.",
787            "alembic_module_prefix": "op.",
788            "target_metadata": MetaData(),
789        }
790        context = MigrationContext.configure(
791            dialect_name="postgresql", opts=ctx_opts
792        )
793
794        self.autogen_context = api.AutogenContext(context)
795
796    def test_render_add_index_pg_where(self):
797        autogen_context = self.autogen_context
798
799        m = MetaData()
800        t = Table("t", m, Column("x", String), Column("y", String))
801
802        idx = Index(
803            "foo_idx", t.c.x, t.c.y, postgresql_where=(t.c.y == "something")
804        )
805
806        op_obj = ops.CreateIndexOp.from_index(idx)
807
808        eq_ignore_whitespace(
809            autogenerate.render_op_text(autogen_context, op_obj),
810            """op.create_index('foo_idx', 't', \
811['x', 'y'], unique=False, """
812            """postgresql_where=sa.text(!U"y = 'something'"))""",
813        )
814
815    def test_render_server_default_native_boolean(self):
816        c = Column(
817            "updated_at", Boolean(), server_default=false(), nullable=False
818        )
819        result = autogenerate.render._render_column(c, self.autogen_context)
820        eq_ignore_whitespace(
821            result,
822            "sa.Column('updated_at', sa.Boolean(), "
823            "server_default=sa.text(!U'false'), "
824            "nullable=False)",
825        )
826
827    def test_postgresql_array_type(self):
828
829        eq_ignore_whitespace(
830            autogenerate.render._repr_type(
831                ARRAY(Integer), self.autogen_context
832            ),
833            "postgresql.ARRAY(sa.Integer())",
834        )
835
836        eq_ignore_whitespace(
837            autogenerate.render._repr_type(
838                ARRAY(DateTime(timezone=True)), self.autogen_context
839            ),
840            "postgresql.ARRAY(sa.DateTime(timezone=True))",
841        )
842
843        eq_ignore_whitespace(
844            autogenerate.render._repr_type(
845                ARRAY(BYTEA, as_tuple=True, dimensions=2), self.autogen_context
846            ),
847            "postgresql.ARRAY(postgresql.BYTEA(), "
848            "as_tuple=True, dimensions=2)",
849        )
850
851        assert (
852            "from sqlalchemy.dialects import postgresql"
853            in self.autogen_context.imports
854        )
855
856    def test_postgresql_hstore_subtypes(self):
857        eq_ignore_whitespace(
858            autogenerate.render._repr_type(HSTORE(), self.autogen_context),
859            "postgresql.HSTORE(text_type=sa.Text())",
860        )
861
862        eq_ignore_whitespace(
863            autogenerate.render._repr_type(
864                HSTORE(text_type=String()), self.autogen_context
865            ),
866            "postgresql.HSTORE(text_type=sa.String())",
867        )
868
869        eq_ignore_whitespace(
870            autogenerate.render._repr_type(
871                HSTORE(text_type=BYTEA()), self.autogen_context
872            ),
873            "postgresql.HSTORE(text_type=postgresql.BYTEA())",
874        )
875
876        assert (
877            "from sqlalchemy.dialects import postgresql"
878            in self.autogen_context.imports
879        )
880
881    def test_generic_array_type(self):
882
883        eq_ignore_whitespace(
884            autogenerate.render._repr_type(
885                types.ARRAY(Integer), self.autogen_context
886            ),
887            "sa.ARRAY(sa.Integer())",
888        )
889
890        eq_ignore_whitespace(
891            autogenerate.render._repr_type(
892                types.ARRAY(DateTime(timezone=True)), self.autogen_context
893            ),
894            "sa.ARRAY(sa.DateTime(timezone=True))",
895        )
896
897        assert (
898            "from sqlalchemy.dialects import postgresql"
899            not in self.autogen_context.imports
900        )
901
902        eq_ignore_whitespace(
903            autogenerate.render._repr_type(
904                types.ARRAY(BYTEA, as_tuple=True, dimensions=2),
905                self.autogen_context,
906            ),
907            "sa.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)",
908        )
909
910        assert (
911            "from sqlalchemy.dialects import postgresql"
912            in self.autogen_context.imports
913        )
914
915    def test_array_type_user_defined_inner(self):
916        def repr_type(typestring, object_, autogen_context):
917            if typestring == "type" and isinstance(object_, String):
918                return "foobar.MYVARCHAR"
919            else:
920                return False
921
922        self.autogen_context.opts.update(render_item=repr_type)
923
924        eq_ignore_whitespace(
925            autogenerate.render._repr_type(
926                ARRAY(String), self.autogen_context
927            ),
928            "postgresql.ARRAY(foobar.MYVARCHAR)",
929        )
930
931    def test_add_exclude_constraint(self):
932        from sqlalchemy.dialects.postgresql import ExcludeConstraint
933
934        autogen_context = self.autogen_context
935
936        m = MetaData()
937        t = Table("t", m, Column("x", String), Column("y", String))
938
939        op_obj = ops.AddConstraintOp.from_constraint(
940            ExcludeConstraint(
941                (t.c.x, ">"), where=t.c.x != 2, using="gist", name="t_excl_x"
942            )
943        )
944
945        eq_ignore_whitespace(
946            autogenerate.render_op_text(autogen_context, op_obj),
947            "op.create_exclude_constraint('t_excl_x', "
948            "'t', (sa.column('x'), '>'), "
949            "where=sa.text(!U'x != 2'), using='gist')",
950        )
951
952    def test_add_exclude_constraint_case_sensitive(self):
953        from sqlalchemy.dialects.postgresql import ExcludeConstraint
954
955        autogen_context = self.autogen_context
956
957        m = MetaData()
958        t = Table(
959            "TTAble", m, Column("XColumn", String), Column("YColumn", String)
960        )
961
962        op_obj = ops.AddConstraintOp.from_constraint(
963            ExcludeConstraint(
964                (t.c.XColumn, ">"),
965                where=t.c.XColumn != 2,
966                using="gist",
967                name="t_excl_x",
968            )
969        )
970
971        eq_ignore_whitespace(
972            autogenerate.render_op_text(autogen_context, op_obj),
973            "op.create_exclude_constraint('t_excl_x', 'TTAble', "
974            "(sa.column('XColumn'), '>'), "
975            "where=sa.text(!U'\"XColumn\" != 2'), using='gist')",
976        )
977
978    def test_inline_exclude_constraint(self):
979        from sqlalchemy.dialects.postgresql import ExcludeConstraint
980
981        autogen_context = self.autogen_context
982
983        m = MetaData()
984        t = Table(
985            "t",
986            m,
987            Column("x", String),
988            Column("y", String),
989            ExcludeConstraint(
990                (column("x"), ">"),
991                using="gist",
992                where="x != 2",
993                name="t_excl_x",
994            ),
995        )
996
997        op_obj = ops.CreateTableOp.from_table(t)
998
999        eq_ignore_whitespace(
1000            autogenerate.render_op_text(autogen_context, op_obj),
1001            "op.create_table('t',sa.Column('x', sa.String(), nullable=True),"
1002            "sa.Column('y', sa.String(), nullable=True),"
1003            "postgresql.ExcludeConstraint((sa.column('x'), '>'), "
1004            "where=sa.text(!U'x != 2'), using='gist', name='t_excl_x')"
1005            ")",
1006        )
1007
1008    def test_inline_exclude_constraint_case_sensitive(self):
1009        from sqlalchemy.dialects.postgresql import ExcludeConstraint
1010
1011        autogen_context = self.autogen_context
1012
1013        m = MetaData()
1014        t = Table(
1015            "TTable", m, Column("XColumn", String), Column("YColumn", String)
1016        )
1017        ExcludeConstraint(
1018            (t.c.XColumn, ">"),
1019            using="gist",
1020            where='"XColumn" != 2',
1021            name="TExclX",
1022        )
1023
1024        op_obj = ops.CreateTableOp.from_table(t)
1025
1026        eq_ignore_whitespace(
1027            autogenerate.render_op_text(autogen_context, op_obj),
1028            "op.create_table('TTable',sa.Column('XColumn', sa.String(), "
1029            "nullable=True),"
1030            "sa.Column('YColumn', sa.String(), nullable=True),"
1031            "postgresql.ExcludeConstraint((sa.column('XColumn'), '>'), "
1032            "where=sa.text(!U'\"XColumn\" != 2'), using='gist', "
1033            "name='TExclX'))",
1034        )
1035
1036    def test_json_type(self):
1037        eq_ignore_whitespace(
1038            autogenerate.render._repr_type(JSON(), self.autogen_context),
1039            "postgresql.JSON(astext_type=sa.Text())",
1040        )
1041
1042    def test_jsonb_type(self):
1043        eq_ignore_whitespace(
1044            autogenerate.render._repr_type(JSONB(), self.autogen_context),
1045            "postgresql.JSONB(astext_type=sa.Text())",
1046        )
1047