1import logging
2import re
3
4from sqlalchemy import Column
5from sqlalchemy import Numeric
6from sqlalchemy import text
7from sqlalchemy import types as sqltypes
8from sqlalchemy.dialects.postgresql import BIGINT
9from sqlalchemy.dialects.postgresql import ExcludeConstraint
10from sqlalchemy.dialects.postgresql import INTEGER
11from sqlalchemy.sql.expression import ColumnClause
12from sqlalchemy.sql.expression import UnaryExpression
13from sqlalchemy.types import NULLTYPE
14
15from .base import alter_column
16from .base import alter_table
17from .base import AlterColumn
18from .base import ColumnComment
19from .base import compiles
20from .base import format_column_name
21from .base import format_table_name
22from .base import format_type
23from .base import RenameTable
24from .impl import DefaultImpl
25from .. import util
26from ..autogenerate import render
27from ..operations import ops
28from ..operations import schemaobj
29from ..operations.base import BatchOperations
30from ..operations.base import Operations
31from ..util import compat
32from ..util import sqla_compat
33
34
# Module logger; autogen_column_reflect uses it to report integer columns
# it treats as SERIAL and whose sequence default it omits.
log = logging.getLogger(__name__)
36
37
class PostgresqlImpl(DefaultImpl):
    """Migration implementation for the ``postgresql`` dialect.

    Builds on :class:`.DefaultImpl` with PostgreSQL-specific behavior:
    ``ALTER COLUMN ... TYPE ... USING`` support, SERIAL detection during
    reflection, autogenerate index corrections and custom rendering of
    PostgreSQL dialect types.
    """

    __dialect__ = "postgresql"
    # PostgreSQL can run DDL inside a transaction.
    transactional_ddl = True
    # Adds FLOAT / DOUBLE PRECISION to the inherited synonym groups;
    # presumably consulted by DefaultImpl when comparing column types
    # (TODO confirm against DefaultImpl).
    type_synonyms = DefaultImpl.type_synonyms + (
        {"FLOAT", "DOUBLE PRECISION"},
    )

    def prep_table_for_batch(self, table):
        """Drop every named constraint of ``table`` ahead of a batch operation.

        Constraints whose ``name`` is ``None`` are skipped, since there is
        no name to pass to the drop.
        """
        for constraint in table.constraints:
            if constraint.name is not None:
                self.drop_constraint(constraint)

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        """Return True if the reflected and metadata server defaults differ.

        Cheap textual checks are tried first. When both defaults are
        present but textually different, the database evaluates
        ``SELECT <reflected> = <metadata>``, so equivalent expressions
        spelled differently compare equal.

        :param inspector_column: reflected column; only its ``type`` is
         consulted.
        :param metadata_column: column from the model metadata.
        :param rendered_metadata_default: SQL text of the metadata default,
         or ``None``.
        :param rendered_inspector_default: SQL text of the reflected
         default, or ``None``.
        """
        # don't do defaults for SERIAL columns: the autoincrement primary
        # key gets its sequence default implicitly.
        if (
            metadata_column.primary_key
            and metadata_column is metadata_column.table._autoincrement_column
        ):
            return False

        conn_col_default = rendered_inspector_default

        defaults_equal = conn_col_default == rendered_metadata_default
        if defaults_equal:
            return False

        # Exactly one side is None at this point (both None compared equal
        # above), so the defaults differ.
        if None in (conn_col_default, rendered_metadata_default):
            return not defaults_equal

        if compat.py2k:
            # look for a python 2 "u''" string and filter
            m = re.match(r"^u'(.*)'$", rendered_metadata_default)
            if m:
                rendered_metadata_default = "'%s'" % m.group(1)

        # check for unquoted string and quote for PG String types: a plain
        # string server_default on a non-numeric column renders unquoted,
        # so wrap it to compare it as a SQL string literal.
        if (
            not isinstance(inspector_column.type, Numeric)
            and metadata_column.server_default is not None
            and isinstance(
                metadata_column.server_default.arg, compat.string_types
            )
            and not re.match(r"^'.*'$", rendered_metadata_default)
        ):
            rendered_metadata_default = "'%s'" % rendered_metadata_default

        # NOTE(review): both defaults are interpolated directly into SQL;
        # they originate from reflection and model metadata, not end users.
        return not self.connection.scalar(
            text(
                "SELECT %s = %s"
                % (conn_col_default, rendered_metadata_default)
            )
        )

    def alter_column(
        self,
        table_name,
        column_name,
        nullable=None,
        server_default=False,
        name=None,
        type_=None,
        schema=None,
        autoincrement=None,
        existing_type=None,
        existing_server_default=None,
        existing_nullable=None,
        existing_autoincrement=None,
        **kw
    ):
        """Alter a column, handling type changes the PostgreSQL way.

        A type change is emitted first as its own
        :class:`.PostgresqlColumnType` statement, which honours the
        ``postgresql_using`` keyword. Every other change is delegated to
        :meth:`.DefaultImpl.alter_column`, which is called without
        ``type_``.

        :raises util.CommandError: if ``postgresql_using`` is given
         without ``type_``.
        """
        using = kw.pop("postgresql_using", None)

        if using is not None and type_ is None:
            raise util.CommandError(
                "postgresql_using must be used with the type_ parameter"
            )

        if type_ is not None:
            self._exec(
                PostgresqlColumnType(
                    table_name,
                    column_name,
                    type_,
                    schema=schema,
                    using=using,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                )
            )

        # type_ was already handled above, so it is intentionally not
        # forwarded to the generic implementation.
        super(PostgresqlImpl, self).alter_column(
            table_name,
            column_name,
            nullable=nullable,
            server_default=server_default,
            name=name,
            schema=schema,
            autoincrement=autoincrement,
            existing_type=existing_type,
            existing_server_default=existing_server_default,
            existing_nullable=existing_nullable,
            existing_autoincrement=existing_autoincrement,
            **kw
        )

    def autogen_column_reflect(self, inspector, table, column_info):
        """Remove the reflected default of SERIAL-style integer columns.

        Applies when an INTEGER/BIGINT column's reflected default matches
        ``nextval('<seq>'::regclass)``. The catalogs (``pg_depend``) are
        then queried for the column that owns that sequence. If the owning
        column has the same name as this one, ``column_info["default"]``
        is deleted.

        NOTE(review): the query compares only the owning column's name, not
        its table, so a sequence owned by a same-named column of another
        table is also treated as SERIAL.

        :param inspector: inspector used to run the catalog query.
        :param table: table being reflected (used only for logging).
        :param column_info: reflected column dict; modified in place.
        """
        if column_info.get("default") and isinstance(
            column_info["type"], (INTEGER, BIGINT)
        ):
            seq_match = re.match(
                r"nextval\('(.+?)'::regclass\)", column_info["default"]
            )
            if seq_match:
                info = sqla_compat._exec_on_inspector(
                    inspector,
                    text(
                        "select c.relname, a.attname "
                        "from pg_class as c join "
                        "pg_depend d on d.objid=c.oid and "
                        "d.classid='pg_class'::regclass and "
                        "d.refclassid='pg_class'::regclass "
                        "join pg_class t on t.oid=d.refobjid "
                        "join pg_attribute a on a.attrelid=t.oid and "
                        "a.attnum=d.refobjsubid "
                        "where c.relkind='S' and c.relname=:seqname"
                    ),
                    seqname=seq_match.group(1),
                ).first()
                if info:
                    seqname, colname = info
                    if colname == column_info["name"]:
                        log.info(
                            "Detected sequence named '%s' as "
                            "owned by integer column '%s(%s)', "
                            "assuming SERIAL and omitting",
                            seqname,
                            table.name,
                            colname,
                        )
                        # sequence, and the owner is this column,
                        # its a SERIAL - whack it!
                        del column_info["default"]

    def correct_for_autogen_constraints(
        self,
        conn_unique_constraints,
        conn_indexes,
        metadata_unique_constraints,
        metadata_indexes,
    ):
        """Adjust index collections before autogenerate compares them.

        Reflected indexes whose ``info`` has ``duplicates_constraint`` set
        are removed from ``conn_indexes``.

        A metadata index is removed from ``metadata_indexes`` when both of
        these hold:

        - no reflected index shares its name;
        - it contains any non-column (functional) expression.

        A warning is issued once per skipped index.

        ``conn_indexes`` and ``metadata_indexes`` are modified in place;
        the unique-constraint arguments are not used here.
        """
        # Built before duplicates are removed, matching prior behavior.
        conn_indexes_by_name = {c.name: c for c in conn_indexes}

        doubled_constraints = {
            index
            for index in conn_indexes
            if index.info.get("duplicates_constraint")
        }

        for ix in doubled_constraints:
            conn_indexes.remove(ix)

        for idx in list(metadata_indexes):
            if idx.name in conn_indexes_by_name:
                continue
            for expr in idx.expressions:
                # unwrap ordering modifiers (e.g. desc()/asc()) to reach
                # the underlying expression
                while isinstance(expr, UnaryExpression):
                    expr = expr.element
                if not isinstance(expr, Column):
                    util.warn(
                        "autogenerate skipping functional index %s; "
                        "not supported by SQLAlchemy reflection" % idx.name
                    )
                    metadata_indexes.discard(idx)
                    # One functional expression is enough to skip this
                    # index; stop so the warning is emitted only once.
                    break

    def render_type(self, type_, autogen_context):
        """Render a PostgreSQL dialect type for autogenerate, if special-cased.

        Only types defined under ``sqlalchemy.dialects.postgresql`` are
        considered. For those, this dispatches to
        ``_render_<__visit_name__>_type`` when such a method exists.
        Otherwise it returns ``False``, presumably telling the caller to use
        its default rendering.
        """
        mod = type(type_).__module__
        if not mod.startswith("sqlalchemy.dialects.postgresql"):
            return False

        meth = getattr(self, "_render_%s_type" % type_.__visit_name__, None)
        if meth is not None:
            return meth(type_, autogen_context)

        return False

    def _render_HSTORE_type(self, type_, autogen_context):
        """Render HSTORE, including its ``text_type`` sub-type."""
        return render._render_type_w_subtype(
            type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
        )

    def _render_ARRAY_type(self, type_, autogen_context):
        """Render ARRAY, including its ``item_type`` sub-type."""
        return render._render_type_w_subtype(
            type_, autogen_context, "item_type", r"(.+?\()"
        )

    def _render_JSON_type(self, type_, autogen_context):
        """Render JSON, including its ``astext_type`` sub-type."""
        return render._render_type_w_subtype(
            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
        )

    def _render_JSONB_type(self, type_, autogen_context):
        """Render JSONB, including its ``astext_type`` sub-type."""
        return render._render_type_w_subtype(
            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
        )
251
252
class PostgresqlColumnType(AlterColumn):
    """ALTER COLUMN element carrying a new type and an optional USING clause."""

    def __init__(self, name, column_name, type_, **kw):
        # "using" is specific to this element, so remove it before the
        # remaining keywords are handed to AlterColumn.
        using_expr = kw.pop("using", None)
        super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
        # Normalize a type class into a type instance.
        new_type = sqltypes.to_instance(type_)
        self.type_ = new_type
        self.using = using_expr
259
260
@compiles(RenameTable, "postgresql")
def visit_rename_table(element, compiler, **kw):
    """Compile a table rename for PostgreSQL.

    The ``alter_table`` prefix for the current (schema-qualified) table is
    followed by ``RENAME TO`` and the new name. The new name is formatted
    without a schema, because PostgreSQL's RENAME TO takes an unqualified
    name.
    """
    current = alter_table(compiler, element.table_name, element.schema)
    renamed = format_table_name(compiler, element.new_table_name, None)
    return "%s RENAME TO %s" % (current, renamed)
267
268
@compiles(PostgresqlColumnType, "postgresql")
def visit_column_type(element, compiler, **kw):
    """Compile an ALTER COLUMN ... TYPE statement, with optional USING.

    The four parts are always joined by single spaces. When ``using`` is
    empty, the last part is an empty string and the output therefore ends
    with a space.
    """
    table_part = alter_table(compiler, element.table_name, element.schema)
    column_part = alter_column(compiler, element.column_name)
    type_part = "TYPE %s" % format_type(compiler, element.type_)
    if element.using:
        using_part = "USING %s" % element.using
    else:
        using_part = ""
    return "%s %s %s %s" % (table_part, column_part, type_part, using_part)
277
278
@compiles(ColumnComment, "postgresql")
def visit_column_comment(element, compiler, **kw):
    """Compile ``COMMENT ON COLUMN`` for PostgreSQL.

    A ``None`` comment is rendered as ``NULL``, which removes the comment.
    Any other value is rendered as a string literal by the SQL compiler.
    """
    ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"

    if element.comment is None:
        comment_sql = "NULL"
    else:
        comment_sql = compiler.sql_compiler.render_literal_value(
            element.comment, sqltypes.String()
        )

    table_sql = format_table_name(compiler, element.table_name, element.schema)
    column_sql = format_column_name(compiler, element.column_name)

    return ddl.format(
        table_name=table_sql, column_name=column_sql, comment=comment_sql
    )
297
298
@Operations.register_operation("create_exclude_constraint")
@BatchOperations.register_operation(
    "create_exclude_constraint", "batch_create_exclude_constraint"
)
@ops.AddConstraintOp.register_add_constraint("exclude_constraint")
class CreateExcludeConstraintOp(ops.AddConstraintOp):
    """Represent a create exclude constraint operation."""

    constraint_type = "exclude"

    def __init__(
        self,
        constraint_name,
        table_name,
        elements,
        where=None,
        schema=None,
        _orig_constraint=None,
        **kw
    ):
        self.constraint_name = constraint_name
        self.table_name = table_name
        # sequence of (expression, operator) pairs
        self.elements = elements
        self.where = where
        self.schema = schema
        self._orig_constraint = _orig_constraint
        # remaining options (e.g. deferrable, initially, using) are passed
        # through to ExcludeConstraint by to_constraint()
        self.kw = kw

    @classmethod
    def from_constraint(cls, constraint):
        """Build an op from an existing ``ExcludeConstraint``.

        Each ``(expr, name, op)`` entry of ``constraint._render_exprs`` is
        reduced to ``(expr, op)``. The constraint itself is retained so
        that :meth:`to_constraint` can return it unchanged.
        """
        constraint_table = sqla_compat._table_for_constraint(constraint)

        return cls(
            constraint.name,
            constraint_table.name,
            [(expr, op) for expr, name, op in constraint._render_exprs],
            where=constraint.where,
            schema=constraint_table.schema,
            _orig_constraint=constraint,
            deferrable=constraint.deferrable,
            initially=constraint.initially,
            using=constraint.using,
        )

    def to_constraint(self, migration_context=None):
        """Return the ``ExcludeConstraint`` this op describes.

        If the op was built by :meth:`from_constraint`, the original
        constraint is returned. Otherwise:

        - a new constraint is created;
        - a stub table is created through :class:`.SchemaObjects`;
        - one NULLTYPE column is added to the table per rendered
          expression name (presumably so the elements can resolve against
          the table);
        - the constraint is appended to the table.
        """
        if self._orig_constraint is not None:
            return self._orig_constraint
        schema_obj = schemaobj.SchemaObjects(migration_context)
        t = schema_obj.table(self.table_name, schema=self.schema)
        excl = ExcludeConstraint(
            *self.elements,
            name=self.constraint_name,
            where=self.where,
            **self.kw
        )
        for expr, name, oper in excl._render_exprs:
            t.append_column(Column(name, NULLTYPE))
        t.append_constraint(excl)
        return excl

    @classmethod
    def create_exclude_constraint(
        cls, operations, constraint_name, table_name, *elements, **kw
    ):
        """Issue an alter to create an EXCLUDE constraint using the
        current migration context.

        .. note::  This method is Postgresql specific, and additionally
           requires at least SQLAlchemy 1.0.

        e.g.::

            from alembic import op

            op.create_exclude_constraint(
                "user_excl",
                "user",
                ("period", '&&'),
                ("group", '='),
                where=("group != 'some group'"),
            )

        Note that the expressions work the same way as that of
        the ``ExcludeConstraint`` object itself; if plain strings are
        passed, quoting rules must be applied manually.

        :param constraint_name: Name of the constraint.
        :param table_name: String name of the source table.
        :param elements: exclude conditions, each an
         ``(expression, operator)`` tuple.
        :param where: SQL expression or SQL string with optional WHERE
         clause.
        :param deferrable: optional bool. If set, emit DEFERRABLE or
         NOT DEFERRABLE when issuing DDL for this constraint.
        :param initially: optional string. If set, emit INITIALLY <value>
         when issuing DDL for this constraint.
        :param using: optional string naming the index method (e.g.
         ``gist``); passed through to ``ExcludeConstraint``.
        :param schema: Optional schema name to operate within.

        .. versionadded:: 0.9.0

        """
        op = cls(constraint_name, table_name, elements, **kw)
        return operations.invoke(op)

    @classmethod
    def batch_create_exclude_constraint(
        cls, operations, constraint_name, *elements, **kw
    ):
        """Issue a "create exclude constraint" instruction using the
        current batch migration context.

        The table name and schema are taken from the batch context; any
        ``schema`` passed in ``kw`` is overridden.

        .. note::  This method is Postgresql specific, and additionally
           requires at least SQLAlchemy 1.0.

        .. versionadded:: 0.9.0

        .. seealso::

            :meth:`.Operations.create_exclude_constraint`

        """
        kw["schema"] = operations.impl.schema
        op = cls(constraint_name, operations.impl.table_name, elements, **kw)
        return operations.invoke(op)
424
425
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
def _add_exclude_constraint(autogen_context, op):
    """Render a CreateExcludeConstraintOp as a create_exclude_constraint call."""
    constraint = op.to_constraint()
    return _exclude_constraint(constraint, autogen_context, alter=True)
429
430
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(constraint, autogen_context):
    """Render an ExcludeConstraint inline within a table definition.

    A user-supplied renderer for ``"exclude"`` takes precedence. The
    built-in rendering is used only when that hook returns ``False``.
    """
    user_output = render._user_defined_render(
        "exclude", constraint, autogen_context
    )
    if user_output is False:
        return _exclude_constraint(constraint, autogen_context, False)
    return user_output
440
441
def _postgresql_autogenerate_prefix(autogen_context):
    """Return the ``postgresql.`` prefix used in rendered migration code.

    When the context tracks imports (``imports`` is not ``None``), the
    matching ``from sqlalchemy.dialects import postgresql`` line is
    recorded so the generated script can resolve the prefix.
    """
    prefix = "postgresql."
    tracked = autogen_context.imports
    if tracked is None:
        return prefix
    tracked.add("from sqlalchemy.dialects import postgresql")
    return prefix
448
449
def _exclude_constraint(constraint, autogen_context, alter):
    """Render an ``ExcludeConstraint`` as Python source for a migration.

    :param constraint: the ExcludeConstraint to render.
    :param autogen_context: autogenerate context; supplies batch state and
     the module prefixes for the rendered call.
    :param alter: if true, render a ``create_exclude_constraint(...)``
     operation call whose positional arguments are:

     - the constraint name;
     - the table name (omitted in batch mode);
     - the ``(element, operator)`` pairs.

     ``schema=`` is included outside batch mode when the table has a
     schema. Otherwise, render an inline ``ExcludeConstraint(...)``
     expression with ``name=`` as a keyword.
    :return: the rendered source string.
    """
    opts = []

    has_batch = autogen_context._has_batch

    # Option values are stringified here and rendered with %r below, so
    # e.g. deferrable=True appears as deferrable='True' in the output.
    if constraint.deferrable:
        opts.append(("deferrable", str(constraint.deferrable)))
    if constraint.initially:
        opts.append(("initially", str(constraint.initially)))
    if constraint.using:
        opts.append(("using", str(constraint.using)))
    if not has_batch and alter and constraint.table.schema:
        opts.append(("schema", render._ident(constraint.table.schema)))
    if not alter and constraint.name:
        opts.append(
            ("name", render._render_gen_name(autogen_context, constraint.name))
        )

    def _expr_where_opts():
        # Arguments shared by both output forms: the (element, operator)
        # pairs, an optional where= clause, then the keyword options.
        args = [
            "(%s, %r)"
            % (_render_potential_column(sqltext, autogen_context), opstring)
            for sqltext, _, opstring in constraint._render_exprs
        ]
        if constraint.where is not None:
            args.append(
                "where=%s"
                % render._render_potential_expr(
                    constraint.where, autogen_context
                )
            )
        args.extend(["%s=%r" % (k, v) for k, v in opts])
        return args

    if alter:
        args = [
            repr(render._render_gen_name(autogen_context, constraint.name))
        ]
        if not has_batch:
            args += [repr(render._ident(constraint.table.name))]
        args.extend(_expr_where_opts())
        return "%(prefix)screate_exclude_constraint(%(args)s)" % {
            "prefix": render._alembic_autogenerate_prefix(autogen_context),
            "args": ", ".join(args),
        }
    else:
        args = _expr_where_opts()
        return "%(prefix)sExcludeConstraint(%(args)s)" % {
            "prefix": _postgresql_autogenerate_prefix(autogen_context),
            "args": ", ".join(args),
        }
514
515
def _render_potential_column(value, autogen_context):
    """Render one exclude-constraint element as Python source.

    A ``ColumnClause`` becomes ``<sqlalchemy prefix>column('<name>')``.
    Anything else goes through the generic expression renderer without
    wrapping it in ``text()``.
    """
    if not isinstance(value, ColumnClause):
        return render._render_potential_expr(
            value, autogen_context, wrap_in_text=False
        )

    template = "%(prefix)scolumn(%(name)r)"
    sa_prefix = render._sqlalchemy_autogenerate_prefix(autogen_context)
    return template % {"prefix": sa_prefix, "name": value.name}
529