from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import Index
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema as sql_schema
from sqlalchemy import Table
from sqlalchemy import types as sqltypes
from sqlalchemy.events import SchemaEventTarget
from sqlalchemy.util import OrderedDict
from sqlalchemy.util import topological

from ..util import exc
from ..util.sqla_compat import _columns_for_constraint
from ..util.sqla_compat import _copy
from ..util.sqla_compat import _ensure_scope_for_ddl
from ..util.sqla_compat import _fk_is_self_referential
from ..util.sqla_compat import _insert_inline
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection
from ..util.sqla_compat import _select

if TYPE_CHECKING:
    from typing import Literal

    from sqlalchemy.engine import Dialect
    from sqlalchemy.sql.elements import ColumnClause
    from sqlalchemy.sql.elements import quoted_name
    from sqlalchemy.sql.functions import Function
    from sqlalchemy.sql.schema import Constraint
    from sqlalchemy.sql.type_api import TypeEngine

    from ..ddl.impl import DefaultImpl


class BatchOperationsImpl:
    """Record batch-mode operations and apply them all on :meth:`flush`.

    Each operation is stored as an ``(opname, args, kwargs)`` tuple in
    ``self.batch``. On flush the recorded operations are either dispatched
    directly to the migration impl (when no table rebuild is needed) or
    replayed against an :class:`ApplyBatchImpl`, which rebuilds the table.
    """

    def __init__(
        self,
        operations,
        table_name,
        schema,
        recreate,
        copy_from,
        table_args,
        table_kwargs,
        reflect_args,
        reflect_kwargs,
        naming_convention,
        partial_reordering,
    ):
        """Set up an empty batch for ``table_name``.

        :param recreate: one of ``"auto"``, ``"always"`` or ``"never"``;
         see :meth:`_should_recreate`.
        :param copy_from: a ``Table`` to use instead of reflecting the
         existing table, or ``None`` to reflect it.
        :param table_args: extra positional arguments for the new Table.
        :param table_kwargs: extra keyword arguments for the new Table
         (copied, so the caller's mapping is not modified).
        :param reflect_args: extra positional arguments for reflection.
        :param reflect_kwargs: extra keyword arguments for reflection
         (copied, including any ``listeners`` sequence).
        :param naming_convention: optional naming convention for the
         ``MetaData`` used when reflecting.
        :param partial_reordering: explicit column orderings passed on to
         :class:`ApplyBatchImpl`.
        :raises ValueError: if ``recreate`` is not a recognized value.
        """
        self.operations = operations
        self.table_name = table_name
        self.schema = schema
        if recreate not in ("auto", "always", "never"):
            raise ValueError(
                "recreate may be one of 'auto', 'always', or 'never'."
            )
        self.recreate = recreate
        self.copy_from = copy_from
        self.table_args = table_args
        self.table_kwargs = dict(table_kwargs)
        self.reflect_args = reflect_args
        self.reflect_kwargs = dict(reflect_kwargs)
        # dict() above is only a shallow copy: a caller-supplied
        # "listeners" list would still be shared with the caller. Always
        # replace it with a fresh list before appending, so the caller's
        # reflect_kwargs is never mutated (and a tuple is accepted too).
        self.reflect_kwargs["listeners"] = list(
            self.reflect_kwargs.get("listeners", ())
        )
        self.reflect_kwargs["listeners"].append(
            ("column_reflect", operations.impl.autogen_column_reflect)
        )
        self.naming_convention = naming_convention
        self.partial_reordering = partial_reordering
        # recorded (opname, args, kwargs) tuples, replayed by flush()
        self.batch: List[Tuple[str, tuple, Dict[str, Any]]] = []

    @property
    def dialect(self) -> "Dialect":
        """The dialect of the underlying migration impl."""
        return self.operations.impl.dialect

    @property
    def impl(self) -> "DefaultImpl":
        """The underlying migration impl."""
        return self.operations.impl

    def _should_recreate(self) -> bool:
        """Return True if the batch must be applied by rebuilding the table.

        ``"auto"`` defers to the impl's ``requires_recreate_in_batch()``;
        ``"always"`` and ``"never"`` are unconditional.
        """
        if self.recreate == "auto":
            return self.operations.impl.requires_recreate_in_batch(self)
        elif self.recreate == "always":
            return True
        else:
            return False

    def flush(self) -> None:
        """Apply every recorded operation.

        Without a rebuild, each operation is called on the migration impl
        directly. With a rebuild, the existing table is taken from
        ``copy_from`` or reflected, the operations are replayed on an
        :class:`ApplyBatchImpl`, and the new table is then created.
        """
        should_recreate = self._should_recreate()

        with _ensure_scope_for_ddl(self.impl.connection):
            if not should_recreate:
                for opname, arg, kw in self.batch:
                    fn = getattr(self.operations.impl, opname)
                    fn(*arg, **kw)
            else:
                if self.naming_convention:
                    m1 = MetaData(naming_convention=self.naming_convention)
                else:
                    m1 = MetaData()

                if self.copy_from is not None:
                    existing_table = self.copy_from
                    reflected = False
                else:
                    existing_table = Table(
                        self.table_name,
                        m1,
                        schema=self.schema,
                        autoload_with=self.operations.get_bind(),
                        *self.reflect_args,
                        **self.reflect_kwargs
                    )
                    reflected = True

                batch_impl = ApplyBatchImpl(
                    self.impl,
                    existing_table,
                    self.table_args,
                    self.table_kwargs,
                    reflected,
                    partial_reordering=self.partial_reordering,
                )
                for opname, arg, kw in self.batch:
                    fn = getattr(batch_impl, opname)
                    fn(*arg, **kw)

                batch_impl._create(self.impl)

    def alter_column(self, *arg, **kw) -> None:
        """Record an ``alter_column`` operation."""
        self.batch.append(("alter_column", arg, kw))

    def add_column(self, *arg, **kw) -> None:
        """Record an ``add_column`` operation.

        :raises CommandError: if ``insert_before`` / ``insert_after`` is
         given but the batch will not rebuild the table, since plain ALTER
         cannot position a column.
        """
        if (
            "insert_before" in kw or "insert_after" in kw
        ) and not self._should_recreate():
            raise exc.CommandError(
                "Can't specify insert_before or insert_after when using "
                "ALTER; please specify recreate='always'"
            )
        self.batch.append(("add_column", arg, kw))

    def drop_column(self, *arg, **kw) -> None:
        """Record a ``drop_column`` operation."""
        self.batch.append(("drop_column", arg, kw))

    def add_constraint(self, const: "Constraint") -> None:
        """Record an ``add_constraint`` operation."""
        self.batch.append(("add_constraint", (const,), {}))

    def drop_constraint(self, const: "Constraint") -> None:
        """Record a ``drop_constraint`` operation."""
        self.batch.append(("drop_constraint", (const,), {}))

    def rename_table(self, *arg, **kw):
        """Record a ``rename_table`` operation."""
        self.batch.append(("rename_table", arg, kw))

    def create_index(self, idx: "Index") -> None:
        """Record a ``create_index`` operation."""
        self.batch.append(("create_index", (idx,), {}))

    def drop_index(self, idx: "Index") -> None:
        """Record a ``drop_index`` operation."""
        self.batch.append(("drop_index", (idx,), {}))

    def create_table_comment(self, table):
        """Record a ``create_table_comment`` operation."""
        self.batch.append(("create_table_comment", (table,), {}))

    def drop_table_comment(self, table):
        """Record a ``drop_table_comment`` operation."""
        self.batch.append(("drop_table_comment", (table,), {}))

    def create_table(self, table):
        """Not supported in batch mode.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("Can't create table in batch mode")

    def drop_table(self, table):
        """Not supported in batch mode.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("Can't drop table in batch mode")

    def create_column_comment(self, column):
        """Record a ``create_column_comment`` operation."""
        self.batch.append(("create_column_comment", (column,), {}))


class ApplyBatchImpl:
    """Apply batch operations by rebuilding ("move and copy") a table.

    Replayed operations mutate in-memory copies of the original table's
    columns, constraints and indexes. :meth:`_create` then:

    1. creates a new table under a temporary name;
    2. copies the rows over with ``INSERT .. SELECT``;
    3. drops the original table;
    4. renames the new table into place;
    5. creates the indexes.
    """

    def __init__(
        self,
        impl: "DefaultImpl",
        table: "Table",
        table_args: tuple,
        table_kwargs: Dict[str, Any],
        reflected: bool,
        partial_reordering: tuple = (),
    ) -> None:
        """Capture the state of ``table`` so it can be rebuilt.

        :param impl: the migration impl; used for
         ``cast_for_batch_migrate()`` when a column type changes.
        :param table: the existing table (reflected or ``copy_from``).
        :param table_args: extra positional arguments for the new Table.
        :param table_kwargs: extra keyword arguments for the new Table.
         Keys from ``table.kwargs`` not present here are filled in, and
         this dict is updated in place.
        :param reflected: whether ``table`` was obtained by reflection.
        :param partial_reordering: tuples of column names; consecutive
         names in each tuple must appear in that order in the new table.
        """
        self.impl = impl
        self.table = table  # this is a Table object
        self.table_args = table_args
        self.table_kwargs = table_kwargs
        self.temp_table_name = self._calc_temp_name(table.name)
        self.new_table: Optional[Table] = None

        self.partial_reordering = partial_reordering  # tuple of tuples
        # (a, b) pairs meaning "column a comes before column b", recorded
        # for columns added in this batch
        self.add_col_ordering: Tuple[
            Tuple[str, str], ...
        ] = ()  # tuple of tuples

        # per-column data-transfer info keyed by original column name; an
        # "expr" entry is selected from the old table into the new one
        self.column_transfers = OrderedDict(
            (c.name, {"expr": c}) for c in self.table.c
        )
        # names of the original columns, in their original order, minus
        # any dropped in this batch
        self.existing_ordering = list(self.column_transfers)

        self.reflected = reflected
        self._grab_table_elements()

    @classmethod
    def _calc_temp_name(cls, tablename: "quoted_name") -> str:
        """Return the temporary table name: ``_alembic_tmp_<name>``,
        truncated to 50 characters."""
        return ("_alembic_tmp_%s" % tablename)[0:50]

    def _grab_table_elements(self) -> None:
        """Copy columns and collect constraints / indexes from the table."""
        schema = self.table.schema
        self.columns: Dict[str, "Column"] = OrderedDict()
        for c in self.table.c:
            c_copy = _copy(c, schema=schema)
            # unique / index flags are not carried on the copied columns;
            # constraints and indexes are collected separately below
            c_copy.unique = c_copy.index = False
            # ensure that the type object was copied,
            # as we may need to modify it in-place
            if isinstance(c.type, SchemaEventTarget):
                assert c_copy.type is not c.type
            self.columns[c.name] = c_copy
        self.named_constraints: Dict[str, "Constraint"] = {}
        self.unnamed_constraints = []
        self.col_named_constraints = {}
        self.indexes: Dict[str, "Index"] = {}
        self.new_indexes: Dict[str, "Index"] = {}

        for const in self.table.constraints:
            if _is_type_bound(const):
                continue
            elif (
                self.reflected
                and isinstance(const, CheckConstraint)
                and not const.name
            ):
                # TODO: we are skipping unnamed reflected CheckConstraint
                # because
                # we have no way to determine _is_type_bound() for these.
                pass
            elif const.name:
                self.named_constraints[const.name] = const
            else:
                self.unnamed_constraints.append(const)

        if not self.reflected:
            for col in self.table.c:
                for const in col.constraints:
                    if const.name:
                        self.col_named_constraints[const.name] = (col, const)

        for idx in self.table.indexes:
            self.indexes[idx.name] = idx

        # explicit table_kwargs win over the existing table's kwargs
        for k in self.table.kwargs:
            self.table_kwargs.setdefault(k, self.table.kwargs[k])

    def _adjust_self_columns_for_partial_reordering(self) -> None:
        """Reorder ``self.columns`` / ``self.column_transfers``.

        The ordering is a topological sort honoring:

        - ``partial_reordering``, when given;
        - otherwise, the existing column order;
        - in both cases, the ordering pairs recorded for added columns.
        """
        pairs = set()

        col_by_idx = list(self.columns)

        if self.partial_reordering:
            for tuple_ in self.partial_reordering:
                # each consecutive pair in a tuple is an ordering constraint
                pairs.update(zip(tuple_, tuple_[1:]))
        else:
            # keep the surviving original columns in their existing order;
            # pair the names of existing_ordering with each other directly
            # rather than relying on self.columns having the same prefix
            pairs.update(
                zip(self.existing_ordering, self.existing_ordering[1:])
            )

        pairs.update(self.add_col_ordering)

        # this can happen if some columns were dropped and not removed
        # from existing_ordering. this should be prevented already, but
        # conservatively making sure this didn't happen
        pairs_list = [p for p in pairs if p[0] != p[1]]

        sorted_ = list(
            topological.sort(pairs_list, col_by_idx, deterministic_order=True)
        )
        self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
        self.column_transfers = OrderedDict(
            (k, self.column_transfers[k]) for k in sorted_
        )

    def _transfer_elements_to_new_table(self) -> None:
        """Build ``self.new_table`` (under the temporary name) from the
        collected columns and constraints.

        Constraints referring to a column no longer present in
        ``column_transfers`` are left out.
        """
        assert self.new_table is None, "Can only create new table once"

        m = MetaData()
        schema = self.table.schema

        if self.partial_reordering or self.add_col_ordering:
            self._adjust_self_columns_for_partial_reordering()

        self.new_table = new_table = Table(
            self.temp_table_name,
            m,
            *(list(self.columns.values()) + list(self.table_args)),
            schema=schema,
            **self.table_kwargs
        )

        for const in (
            list(self.named_constraints.values()) + self.unnamed_constraints
        ):

            const_columns = {c.key for c in _columns_for_constraint(const)}

            if not const_columns.issubset(self.column_transfers):
                continue

            const_copy: "Constraint"
            if isinstance(const, ForeignKeyConstraint):
                if _fk_is_self_referential(const):
                    # for self-referential constraint, refer to the
                    # *original* table name, and not _alembic_batch_temp.
                    # This is consistent with how we're handling
                    # FK constraints from other tables; we assume SQLite
                    # no foreign keys just keeps the names unchanged, so
                    # when we rename back, they match again.
                    const_copy = _copy(
                        const, schema=schema, target_table=self.table
                    )
                else:
                    # "target_table" for ForeignKeyConstraint.copy() is
                    # only used if the FK is detected as being
                    # self-referential, which we are handling above.
                    const_copy = _copy(const, schema=schema)
            else:
                const_copy = _copy(
                    const, schema=schema, target_table=new_table
                )
            if isinstance(const, ForeignKeyConstraint):
                self._setup_referent(m, const)
            new_table.append_constraint(const_copy)

    def _gather_indexes_from_both_tables(self) -> List["Index"]:
        """Return the indexes to create on the rebuilt table.

        These are the original table's indexes that were not dropped,
        plus copies of the indexes created in this batch, rebuilt against
        the new table's columns.
        """
        assert self.new_table is not None
        idx: List[Index] = []
        idx.extend(self.indexes.values())
        for index in self.new_indexes.values():
            idx.append(
                Index(
                    index.name,
                    unique=index.unique,
                    *[self.new_table.c[col] for col in index.columns.keys()],
                    **index.kwargs
                )
            )
        return idx

    def _setup_referent(
        self, metadata: "MetaData", constraint: "ForeignKeyConstraint"
    ) -> None:
        """Add a stub of the FK's referred table to ``metadata``.

        The stub has ``NULLTYPE`` columns named after the referred
        columns. Presumably this lets the copied ForeignKeyConstraint
        resolve its target within the new MetaData.

        If a table of that name is already present, any missing referred
        columns are appended to it. References to the temporary table
        name are skipped.
        """
        spec = constraint.elements[
            0
        ]._get_colspec()  # type:ignore[attr-defined]
        parts = spec.split(".")
        tname = parts[-2]
        if len(parts) == 3:
            referent_schema = parts[0]
        else:
            referent_schema = None

        if tname != self.temp_table_name:
            key = sql_schema._get_table_key(tname, referent_schema)

            def colspec(elem: Any):
                return elem._get_colspec()

            if key in metadata.tables:
                t = metadata.tables[key]
                for elem in constraint.elements:
                    colname = colspec(elem).split(".")[-1]
                    if colname not in t.c:
                        t.append_column(Column(colname, sqltypes.NULLTYPE))
            else:
                Table(
                    tname,
                    metadata,
                    *[
                        Column(n, sqltypes.NULLTYPE)
                        for n in [
                            colspec(elem).split(".")[-1]
                            for elem in constraint.elements
                        ]
                    ],
                    schema=referent_schema
                )

    def _create(self, op_impl: "DefaultImpl") -> None:
        """Rebuild the table using ``op_impl``.

        Steps:

        1. create the temporary table;
        2. copy the rows with ``INSERT .. SELECT``;
        3. drop the original table;
        4. rename the temporary table to the original name;
        5. create the indexes.

        If the copy or the drop fails, the temporary table is dropped and
        the error re-raised.
        """
        self._transfer_elements_to_new_table()

        op_impl.prep_table_for_batch(self, self.table)
        assert self.new_table is not None
        op_impl.create_table(self.new_table)

        try:
            op_impl._exec(
                _insert_inline(self.new_table).from_select(
                    list(
                        k
                        for k, transfer in self.column_transfers.items()
                        if "expr" in transfer
                    ),
                    _select(
                        *[
                            transfer["expr"]
                            for transfer in self.column_transfers.values()
                            if "expr" in transfer
                        ]
                    ),
                )
            )
            op_impl.drop_table(self.table)
        except BaseException:
            # clean up the half-built temporary table on any failure
            # (including interrupts), then propagate the original error
            op_impl.drop_table(self.new_table)
            raise
        else:
            op_impl.rename_table(
                self.temp_table_name, self.table.name, schema=self.table.schema
            )
            # indexes must be created against the final table name; the
            # temporary name is restored on the Table object afterwards
            self.new_table.name = self.table.name
            try:
                for idx in self._gather_indexes_from_both_tables():
                    op_impl.create_index(idx)
            finally:
                self.new_table.name = self.temp_table_name

    def alter_column(
        self,
        table_name: str,
        column_name: str,
        nullable: Optional[bool] = None,
        server_default: Optional[Union["Function", str, bool]] = False,
        name: Optional[str] = None,
        type_: Optional["TypeEngine"] = None,
        autoincrement: Optional[bool] = None,
        comment: Union[str, "Literal[False]"] = False,
        **kw
    ) -> None:
        """Apply a column alteration to the in-memory column copy.

        ``None`` for ``nullable`` / ``name`` / ``type_`` / ``autoincrement``
        and ``False`` for ``server_default`` / ``comment`` mean "leave
        unchanged"; ``server_default=None`` removes the server default.

        :raises KeyError: if ``column_name`` is not a current column.
        """
        existing = self.columns[column_name]
        existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
        if name is not None and name != column_name:
            # note that we don't change '.key' - we keep referring
            # to the renamed column by its old key in _create(). neat!
            existing.name = name
            existing_transfer["name"] = name

            # pop named constraints for Boolean/Enum for rename
            if (
                "existing_type" in kw
                and isinstance(kw["existing_type"], SchemaEventTarget)
                and kw["existing_type"].name  # type:ignore[attr-defined]
            ):
                self.named_constraints.pop(
                    kw["existing_type"].name, None  # type:ignore[attr-defined]
                )

        if type_ is not None:
            type_ = sqltypes.to_instance(type_)
            # old type is being discarded so turn off eventing
            # rules. Alternatively we can
            # erase the events set up by this type, but this is simpler.
            # we also ignore the drop_constraint that will come here from
            # Operations.implementation_for(alter_column)

            if isinstance(existing.type, SchemaEventTarget):
                existing.type._create_events = (  # type:ignore[attr-defined]
                    existing.type.create_constraint  # type:ignore[attr-defined] # noqa
                ) = False

            self.impl.cast_for_batch_migrate(
                existing, existing_transfer, type_
            )

            existing.type = type_

            # we *dont* however set events for the new type, because
            # alter_column is invoked from
            # Operations.implementation_for(alter_column) which already
            # will emit an add_constraint()

        if nullable is not None:
            existing.nullable = nullable
        if server_default is not False:
            if server_default is None:
                existing.server_default = None
            else:
                sql_schema.DefaultClause(
                    server_default
                )._set_parent(  # type:ignore[attr-defined]
                    existing
                )
        if autoincrement is not None:
            existing.autoincrement = bool(autoincrement)

        if comment is not False:
            existing.comment = comment

    def _setup_dependencies_for_add_column(
        self,
        colname: str,
        insert_before: Optional[str],
        insert_after: Optional[str],
    ) -> None:
        """Record ordering constraints for a column added in this batch.

        Each pair ``(a, b)`` appended to ``self.add_col_ordering`` means
        "``a`` comes before ``b``". When only one neighbor is given,
        the other neighbor is inferred, so the new column lands directly
        next to the named one:

        - for an existing column, the inference uses the existing order;
        - for a column added earlier in this batch, it uses the pairs
          already recorded, if any.

        With no position (and no ``partial_reordering``), the column is
        placed after the last existing column.

        :raises KeyError: if ``insert_before`` / ``insert_after`` names a
         column that neither exists nor was added in this batch.
        """
        index_cols = self.existing_ordering
        col_indexes = {name: i for i, name in enumerate(index_cols)}

        if not self.partial_reordering:
            if insert_after and not insert_before:
                if insert_after in col_indexes:
                    # insert after an existing column
                    idx = col_indexes[insert_after] + 1
                    if idx < len(index_cols):
                        insert_before = index_cols[idx]
                elif insert_after in self.columns:
                    # insert after a column that is also new; it may have
                    # no recorded successor yet (e.g. it was appended at
                    # the end), in which case only "after" applies
                    insert_before = dict(self.add_col_ordering).get(
                        insert_after
                    )
                else:
                    raise KeyError(insert_after)
            if insert_before and not insert_after:
                if insert_before in col_indexes:
                    # insert before an existing column
                    idx = col_indexes[insert_before] - 1
                    if idx >= 0:
                        insert_after = index_cols[idx]
                elif insert_before in self.columns:
                    # insert before a column that is also new; it may have
                    # no recorded predecessor yet, in which case only
                    # "before" applies
                    insert_after = {
                        b: a for a, b in self.add_col_ordering
                    }.get(insert_before)
                else:
                    raise KeyError(insert_before)

        if insert_before:
            self.add_col_ordering += ((colname, insert_before),)
        if insert_after:
            self.add_col_ordering += ((insert_after, colname),)

        if (
            not self.partial_reordering
            and not insert_before
            and not insert_after
            and col_indexes
        ):
            self.add_col_ordering += ((index_cols[-1], colname),)

    def add_column(
        self,
        table_name: str,
        column: "Column",
        insert_before: Optional[str] = None,
        insert_after: Optional[str] = None,
        **kw
    ) -> None:
        """Add a column to the rebuilt table.

        The new column has no data transfer entry, so no values are copied
        into it from the old table.
        """
        self._setup_dependencies_for_add_column(
            column.name, insert_before, insert_after
        )
        # we copy the column because operations.add_column()
        # gives us a Column that is part of a Table already.
        self.columns[column.name] = _copy(column, schema=self.table.schema)
        self.column_transfers[column.name] = {}

    def drop_column(
        self, table_name: str, column: Union["ColumnClause", "Column"], **kw
    ) -> None:
        """Remove a column from the rebuilt table and from the data copy.

        The column is also removed from the original table's primary key
        collection, if present there.
        """
        if column.name in self.table.primary_key.columns:
            _remove_column_from_collection(
                self.table.primary_key.columns, column
            )
        del self.columns[column.name]
        del self.column_transfers[column.name]
        self.existing_ordering.remove(column.name)

        # pop named constraints for Boolean/Enum for rename
        if (
            "existing_type" in kw
            and isinstance(kw["existing_type"], SchemaEventTarget)
            and kw["existing_type"].name  # type:ignore[attr-defined]
        ):
            self.named_constraints.pop(
                kw["existing_type"].name, None  # type:ignore[attr-defined]
            )

    def create_column_comment(self, column):
        """the batch table creation function will issue create_column_comment
        on the real "impl" as part of the create table process.

        That is, the Column object will have the comment on it already,
        so when it is received by add_column() it will be a normal part of
        the CREATE TABLE and doesn't need an extra step here.

        """

    def create_table_comment(self, table):
        """the batch table creation function will issue create_table_comment
        on the real "impl" as part of the create table process.

        """

    def drop_table_comment(self, table):
        """the batch table creation function will issue drop_table_comment
        on the real "impl" as part of the create table process.

        """

    def add_constraint(self, const: "Constraint") -> None:
        """Add a named constraint to the rebuilt table.

        Adding a primary key discards the original table's primary key if
        it was collected as an unnamed constraint.

        :raises ValueError: if the constraint has no name.
        """
        if not const.name:
            raise ValueError("Constraint must have a name")
        if isinstance(const, sql_schema.PrimaryKeyConstraint):
            if self.table.primary_key in self.unnamed_constraints:
                self.unnamed_constraints.remove(self.table.primary_key)

        self.named_constraints[const.name] = const

    def drop_constraint(self, const: "Constraint") -> None:
        """Remove a named constraint from the rebuilt table.

        Both table-level constraints and named column-level constraints
        are handled. Dropping a primary key clears ``primary_key`` on its
        columns. A missing type-bound constraint is ignored.

        :raises ValueError: if the constraint has no name, or no such
         constraint is known.
        """
        if not const.name:
            raise ValueError("Constraint must have a name")
        try:
            if const.name in self.col_named_constraints:
                col, const = self.col_named_constraints.pop(const.name)

                for col_const in list(self.columns[col.name].constraints):
                    if col_const.name == const.name:
                        self.columns[col.name].constraints.remove(col_const)
            else:
                const = self.named_constraints.pop(cast(str, const.name))
        except KeyError:
            if _is_type_bound(const):
                # type-bound constraints are only included in the new
                # table via their type object in any case, so ignore the
                # drop_constraint() that comes here via the
                # Operations.implementation_for(alter_column)
                return
            raise ValueError("No such constraint: '%s'" % const.name)
        else:
            if isinstance(const, PrimaryKeyConstraint):
                for col in const.columns:
                    self.columns[col.name].primary_key = False

    def create_index(self, idx: "Index") -> None:
        """Queue an index to be created on the rebuilt table."""
        self.new_indexes[idx.name] = idx

    def drop_index(self, idx: "Index") -> None:
        """Omit an existing index from the rebuilt table.

        :raises ValueError: if no index of that name exists.
        """
        try:
            del self.indexes[idx.name]
        except KeyError:
            raise ValueError("No such index: '%s'" % idx.name)

    def rename_table(self, *arg, **kw):
        """Not implemented for table rebuilds.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("TODO")