import logging
import re

from sqlalchemy import Column
from sqlalchemy import Numeric
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.sql.expression import ColumnClause
from sqlalchemy.sql.expression import UnaryExpression
from sqlalchemy.types import NULLTYPE

from .base import alter_column
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
from .base import compiles
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..autogenerate import render
from ..operations import ops
from ..operations import schemaobj
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import compat
from ..util import sqla_compat


log = logging.getLogger(__name__)


class PostgresqlImpl(DefaultImpl):
    """Migration implementation for the PostgreSQL dialect."""

    __dialect__ = "postgresql"
    transactional_ddl = True
    # FLOAT and DOUBLE PRECISION are registered as synonyms; presumably
    # consulted by DefaultImpl during type comparison -- confirm there.
    type_synonyms = DefaultImpl.type_synonyms + (
        {"FLOAT", "DOUBLE PRECISION"},
    )

    def prep_table_for_batch(self, table):
        """Drop each constraint of ``table`` that has a name.

        Constraints whose ``name`` is None are left untouched.
        """
        for constraint in table.constraints:
            if constraint.name is not None:
                self.drop_constraint(constraint)

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        """Return True when the reflected server default differs from the
        metadata server default.

        The table's autoincrement primary-key column (SERIAL) is never
        reported as changed.  Identical rendered strings compare equal; if
        exactly one side is None they compare different.  Otherwise both
        rendered defaults are compared by the database itself via
        ``SELECT <reflected> = <metadata>``.
        """
        # don't do defaults for SERIAL columns
        if (
            metadata_column.primary_key
            and metadata_column is metadata_column.table._autoincrement_column
        ):
            return False

        conn_col_default = rendered_inspector_default

        if conn_col_default == rendered_metadata_default:
            return False

        if None in (conn_col_default, rendered_metadata_default):
            # Both None would have compared equal above, so exactly one side
            # has a default: they differ.
            return True

        if compat.py2k:
            # look for a python 2 "u''" string and filter
            m = re.match(r"^u'(.*)'$", rendered_metadata_default)
            if m:
                rendered_metadata_default = "'%s'" % m.group(1)

        # check for unquoted string and quote for PG String types
        if (
            not isinstance(inspector_column.type, Numeric)
            and metadata_column.server_default is not None
            and isinstance(
                metadata_column.server_default.arg, compat.string_types
            )
            and not re.match(r"^'.*'$", rendered_metadata_default)
        ):
            rendered_metadata_default = "'%s'" % rendered_metadata_default

        # Both operands are rendered SQL default expressions, so they are
        # interpolated as SQL text rather than bound as parameters.
        return not self.connection.scalar(
            text(
                "SELECT %s = %s"
                % (conn_col_default, rendered_metadata_default)
            )
        )

    def alter_column(
        self,
        table_name,
        column_name,
        nullable=None,
        server_default=False,
        name=None,
        type_=None,
        schema=None,
        autoincrement=None,
        existing_type=None,
        existing_server_default=None,
        existing_nullable=None,
        existing_autoincrement=None,
        **kw
    ):
        """Alter a column, emitting ``ALTER COLUMN ... TYPE ... [USING ...]``
        for type changes.

        The type change (when ``type_`` is given) is emitted here as a
        :class:`.PostgresqlColumnType` element; every other change is
        delegated to ``DefaultImpl.alter_column`` (``type_`` is not passed
        on).

        :param postgresql_using: keyword popped from ``kw``; optional SQL
         text rendered as the ``USING`` clause of the type change.
        :raises util.CommandError: if ``postgresql_using`` is given without
         ``type_``.
        """
        using = kw.pop("postgresql_using", None)

        if using is not None and type_ is None:
            raise util.CommandError(
                "postgresql_using must be used with the type_ parameter"
            )

        if type_ is not None:
            self._exec(
                PostgresqlColumnType(
                    table_name,
                    column_name,
                    type_,
                    schema=schema,
                    using=using,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                )
            )

        super(PostgresqlImpl, self).alter_column(
            table_name,
            column_name,
            nullable=nullable,
            server_default=server_default,
            name=name,
            schema=schema,
            autoincrement=autoincrement,
            existing_type=existing_type,
            existing_server_default=existing_server_default,
            existing_nullable=existing_nullable,
            existing_autoincrement=existing_autoincrement,
            **kw
        )

    def autogen_column_reflect(self, inspector, table, column_info):
        """Drop the reflected default of a SERIAL-style integer column.

        If an INTEGER/BIGINT column's reflected default is
        ``nextval('<seq>'::regclass)`` and ``pg_depend`` shows that sequence
        depends on this very column, the ``"default"`` key is removed from
        ``column_info`` (mutated in place).
        """
        if column_info.get("default") and isinstance(
            column_info["type"], (INTEGER, BIGINT)
        ):
            seq_match = re.match(
                r"nextval\('(.+?)'::regclass\)", column_info["default"]
            )
            if seq_match:
                # Find the table column the sequence depends on.
                info = sqla_compat._exec_on_inspector(
                    inspector,
                    text(
                        "select c.relname, a.attname "
                        "from pg_class as c join "
                        "pg_depend d on d.objid=c.oid and "
                        "d.classid='pg_class'::regclass and "
                        "d.refclassid='pg_class'::regclass "
                        "join pg_class t on t.oid=d.refobjid "
                        "join pg_attribute a on a.attrelid=t.oid and "
                        "a.attnum=d.refobjsubid "
                        "where c.relkind='S' and c.relname=:seqname"
                    ),
                    seqname=seq_match.group(1),
                ).first()
                if info:
                    seqname, colname = info
                    if colname == column_info["name"]:
                        log.info(
                            "Detected sequence named '%s' as "
                            "owned by integer column '%s(%s)', "
                            "assuming SERIAL and omitting",
                            seqname,
                            table.name,
                            colname,
                        )
                        # sequence, and the owner is this column,
                        # its a SERIAL - whack it!
                        del column_info["default"]

    def correct_for_autogen_constraints(
        self,
        conn_unique_constraints,
        conn_indexes,
        metadata_unique_constraints,
        metadata_indexes,
    ):
        """Adjust reflected/metadata indexes in place before comparison.

        * Reflected indexes whose ``info`` has ``duplicates_constraint``
          set are removed from ``conn_indexes``.
        * Metadata indexes containing a non-``Column`` expression (after
          unwrapping ``UnaryExpression`` such as ``desc()``) are removed
          from ``metadata_indexes`` with a warning, unless a reflected index
          of the same name exists.
        """
        # Built before the duplicate removal below, as in prior behavior.
        conn_indexes_by_name = {c.name: c for c in conn_indexes}

        doubled_constraints = {
            index
            for index in conn_indexes
            if index.info.get("duplicates_constraint")
        }

        for ix in doubled_constraints:
            conn_indexes.remove(ix)

        for idx in list(metadata_indexes):
            if idx.name in conn_indexes_by_name:
                continue
            for expr in idx.expressions:
                while isinstance(expr, UnaryExpression):
                    expr = expr.element
                if not isinstance(expr, Column):
                    util.warn(
                        "autogenerate skipping functional index %s; "
                        "not supported by SQLAlchemy reflection" % idx.name
                    )
                    metadata_indexes.discard(idx)
                    # One warning per index; the index is already dropped.
                    break

    def render_type(self, type_, autogen_context):
        """Render a ``sqlalchemy.dialects.postgresql`` type for autogenerate.

        Dispatches to a ``_render_<visit_name>_type`` method when one is
        defined; returns False otherwise (and for non-PostgreSQL types).
        """
        mod = type(type_).__module__
        if not mod.startswith("sqlalchemy.dialects.postgresql"):
            return False

        meth = getattr(self, "_render_%s_type" % type_.__visit_name__, None)
        if meth is not None:
            return meth(type_, autogen_context)

        return False

    def _render_HSTORE_type(self, type_, autogen_context):
        """Render HSTORE, including its ``text_type`` subtype."""
        return render._render_type_w_subtype(
            type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
        )

    def _render_ARRAY_type(self, type_, autogen_context):
        """Render ARRAY, including its ``item_type`` subtype."""
        return render._render_type_w_subtype(
            type_, autogen_context, "item_type", r"(.+?\()"
        )

    def _render_JSON_type(self, type_, autogen_context):
        """Render JSON, including its ``astext_type`` subtype."""
        return render._render_type_w_subtype(
            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
        )

    def _render_JSONB_type(self, type_, autogen_context):
        """Render JSONB, including its ``astext_type`` subtype."""
        return render._render_type_w_subtype(
            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
        )


class PostgresqlColumnType(AlterColumn):
    """DDL element for ``ALTER TABLE ... ALTER COLUMN ... TYPE <type>``
    with an optional ``USING`` clause."""

    def __init__(self, name, column_name, type_, **kw):
        using = kw.pop("using", None)
        super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
        self.type_ = sqltypes.to_instance(type_)
        self.using = using


@compiles(RenameTable, "postgresql")
def visit_rename_table(element, compiler, **kw):
    """Compile ``ALTER TABLE <old> RENAME TO <new>``.

    The new name is rendered without a schema.
    """
    return "%s RENAME TO %s" % (
        alter_table(compiler, element.table_name, element.schema),
        format_table_name(compiler, element.new_table_name, None),
    )


@compiles(PostgresqlColumnType, "postgresql")
def visit_column_type(element, compiler, **kw):
    """Compile ``ALTER TABLE ... ALTER COLUMN ... TYPE <type> [USING ...]``."""
    return "%s %s %s %s" % (
        alter_table(compiler, element.table_name, element.schema),
        alter_column(compiler, element.column_name),
        "TYPE %s" % format_type(compiler, element.type_),
        "USING %s" % element.using if element.using else "",
    )


@compiles(ColumnComment, "postgresql")
def visit_column_comment(element, compiler, **kw):
    """Compile ``COMMENT ON COLUMN <table>.<column> IS <comment|NULL>``.

    The comment is rendered as a string literal; a None comment renders
    ``NULL``.
    """
    ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
    comment = (
        compiler.sql_compiler.render_literal_value(
            element.comment, sqltypes.String()
        )
        if element.comment is not None
        else "NULL"
    )

    return ddl.format(
        table_name=format_table_name(
            compiler, element.table_name, element.schema
        ),
        column_name=format_column_name(compiler, element.column_name),
        comment=comment,
    )


@Operations.register_operation("create_exclude_constraint")
@BatchOperations.register_operation(
    "create_exclude_constraint", "batch_create_exclude_constraint"
)
@ops.AddConstraintOp.register_add_constraint("exclude_constraint")
class CreateExcludeConstraintOp(ops.AddConstraintOp):
    """Represent a create exclude constraint operation."""

    constraint_type = "exclude"

    def __init__(
        self,
        constraint_name,
        table_name,
        elements,
        where=None,
        schema=None,
        _orig_constraint=None,
        **kw
    ):
        self.constraint_name = constraint_name
        self.table_name = table_name
        # sequence of (expression, operator) pairs
        self.elements = elements
        self.where = where
        self.schema = schema
        self._orig_constraint = _orig_constraint
        # extra ExcludeConstraint keywords (deferrable, initially, using...)
        self.kw = kw

    @classmethod
    def from_constraint(cls, constraint):
        """Build the op from an existing ``ExcludeConstraint``."""
        constraint_table = sqla_compat._table_for_constraint(constraint)

        return cls(
            constraint.name,
            constraint_table.name,
            [(expr, op) for expr, name, op in constraint._render_exprs],
            where=constraint.where,
            schema=constraint_table.schema,
            _orig_constraint=constraint,
            deferrable=constraint.deferrable,
            initially=constraint.initially,
            using=constraint.using,
        )

    def to_constraint(self, migration_context=None):
        """Return the ``ExcludeConstraint`` for this op.

        The original constraint is returned when the op was built from
        one; otherwise a new ``ExcludeConstraint`` is attached to a stub
        table that gets a ``NULLTYPE`` column for each referenced name.
        """
        if self._orig_constraint is not None:
            return self._orig_constraint
        schema_obj = schemaobj.SchemaObjects(migration_context)
        t = schema_obj.table(self.table_name, schema=self.schema)
        excl = ExcludeConstraint(
            *self.elements,
            name=self.constraint_name,
            where=self.where,
            **self.kw
        )
        for expr, name, oper in excl._render_exprs:
            t.append_column(Column(name, NULLTYPE))
        t.append_constraint(excl)
        return excl

    @classmethod
    def create_exclude_constraint(
        cls, operations, constraint_name, table_name, *elements, **kw
    ):
        """Issue an alter to create an EXCLUDE constraint using the
        current migration context.

        .. note::  This method is Postgresql specific, and additionally
           requires at least SQLAlchemy 1.0.

        e.g.::

            from alembic import op

            op.create_exclude_constraint(
                "user_excl",
                "user",

                ("period", '&&'),
                ("group", '='),
                where=("group != 'some group'")

            )

        Note that the expressions work the same way as that of
        the ``ExcludeConstraint`` object itself; if plain strings are
        passed, quoting rules must be applied manually.

        :param constraint_name: Name of the constraint.
        :param table_name: String name of the source table.
        :param elements: exclude conditions.
        :param where: SQL expression or SQL string with optional WHERE
         clause.
        :param deferrable: optional bool. If set, emit DEFERRABLE or
         NOT DEFERRABLE when issuing DDL for this constraint.
        :param initially: optional string. If set, emit INITIALLY <value>
         when issuing DDL for this constraint.
        :param using: optional string; passed through to the
         ``ExcludeConstraint``.
        :param schema: Optional schema name to operate within.

        .. versionadded:: 0.9.0

        """
        op = cls(constraint_name, table_name, elements, **kw)
        return operations.invoke(op)

    @classmethod
    def batch_create_exclude_constraint(
        cls, operations, constraint_name, *elements, **kw
    ):
        """Issue a "create exclude constraint" instruction using the
        current batch migration context.

        .. note::  This method is Postgresql specific, and additionally
           requires at least SQLAlchemy 1.0.

        .. versionadded:: 0.9.0

        .. seealso::

            :meth:`.Operations.create_exclude_constraint`

        """
        # Table and schema come from the enclosing batch context.
        kw["schema"] = operations.impl.schema
        op = cls(constraint_name, operations.impl.table_name, elements, **kw)
        return operations.invoke(op)


@render.renderers.dispatch_for(CreateExcludeConstraintOp)
def _add_exclude_constraint(autogen_context, op):
    """Render a CreateExcludeConstraintOp as an ``op.create_...`` call."""
    return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)


@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(constraint, autogen_context):
    """Render an ExcludeConstraint inline (e.g. within a create_table).

    A user-defined renderer for ``"exclude"`` takes precedence when it
    returns something other than False.
    """
    rendered = render._user_defined_render(
        "exclude", constraint, autogen_context
    )
    if rendered is not False:
        return rendered

    return _exclude_constraint(constraint, autogen_context, False)


def _postgresql_autogenerate_prefix(autogen_context):
    """Return the ``postgresql.`` prefix, registering its import if the
    context tracks imports."""
    imports = autogen_context.imports
    if imports is not None:
        imports.add("from sqlalchemy.dialects import postgresql")
    return "postgresql."
def _exclude_constraint(constraint, autogen_context, alter):
    """Render an ``ExcludeConstraint`` as Python source for autogenerate.

    :param constraint: the ``ExcludeConstraint`` to render.
    :param autogen_context: autogenerate context; supplies the import
     prefixes and the ``_has_batch`` flag.
    :param alter: when True, render a ``create_exclude_constraint(...)``
     call with the constraint name (and, outside batch mode, the table
     name) as leading positional arguments; when False, render an inline
     ``postgresql.ExcludeConstraint(...)`` with ``name`` as a keyword.
    :return: the rendered source string.
    """
    has_batch = autogen_context._has_batch

    # Keyword options, appended after the element and where= arguments.
    opts = []
    if constraint.deferrable:
        opts.append(("deferrable", str(constraint.deferrable)))
    if constraint.initially:
        opts.append(("initially", str(constraint.initially)))
    if constraint.using:
        opts.append(("using", str(constraint.using)))
    if not has_batch and alter and constraint.table.schema:
        opts.append(("schema", render._ident(constraint.table.schema)))
    if not alter and constraint.name:
        opts.append(
            ("name", render._render_gen_name(autogen_context, constraint.name))
        )

    args = []
    if alter:
        # create_exclude_constraint() takes the constraint name and, unless
        # inside a batch block, the table name positionally.
        args.append(
            repr(render._render_gen_name(autogen_context, constraint.name))
        )
        if not has_batch:
            args.append(repr(render._ident(constraint.table.name)))

    # One "(expression, operator)" tuple per excluded element.
    args.extend(
        "(%s, %r)"
        % (_render_potential_column(sqltext, autogen_context), opstring)
        for sqltext, _name, opstring in constraint._render_exprs
    )
    if constraint.where is not None:
        args.append(
            "where=%s"
            % render._render_potential_expr(constraint.where, autogen_context)
        )
    args.extend("%s=%r" % (k, v) for k, v in opts)

    if alter:
        prefix = render._alembic_autogenerate_prefix(autogen_context)
        funcname = "create_exclude_constraint"
    else:
        prefix = _postgresql_autogenerate_prefix(autogen_context)
        funcname = "ExcludeConstraint"
    return "%s%s(%s)" % (prefix, funcname, ", ".join(args))


def _render_potential_column(value, autogen_context):
    """Render a single exclude-constraint element.

    A ``ColumnClause`` is rendered as ``<sqlalchemy prefix>column(<name>)``;
    any other value is delegated to ``render._render_potential_expr`` with
    ``wrap_in_text=False``.
    """
    if isinstance(value, ColumnClause):
        template = "%(prefix)scolumn(%(name)r)"

        return template % {
            "prefix": render._sqlalchemy_autogenerate_prefix(autogen_context),
            "name": value.name,
        }

    return render._render_potential_expr(
        value, autogen_context, wrap_in_text=False
    )