1from typing import Any
2from typing import cast
3from typing import Dict
4from typing import List
5from typing import Optional
6from typing import Tuple
7from typing import TYPE_CHECKING
8from typing import Union
9
10from sqlalchemy import CheckConstraint
11from sqlalchemy import Column
12from sqlalchemy import ForeignKeyConstraint
13from sqlalchemy import Index
14from sqlalchemy import MetaData
15from sqlalchemy import PrimaryKeyConstraint
16from sqlalchemy import schema as sql_schema
17from sqlalchemy import Table
18from sqlalchemy import types as sqltypes
19from sqlalchemy.events import SchemaEventTarget
20from sqlalchemy.util import OrderedDict
21from sqlalchemy.util import topological
22
23from ..util import exc
24from ..util.sqla_compat import _columns_for_constraint
25from ..util.sqla_compat import _copy
26from ..util.sqla_compat import _ensure_scope_for_ddl
27from ..util.sqla_compat import _fk_is_self_referential
28from ..util.sqla_compat import _insert_inline
29from ..util.sqla_compat import _is_type_bound
30from ..util.sqla_compat import _remove_column_from_collection
31from ..util.sqla_compat import _select
32
33if TYPE_CHECKING:
34    from typing import Literal
35
36    from sqlalchemy.engine import Dialect
37    from sqlalchemy.sql.elements import ColumnClause
38    from sqlalchemy.sql.elements import quoted_name
39    from sqlalchemy.sql.functions import Function
40    from sqlalchemy.sql.schema import Constraint
41    from sqlalchemy.sql.type_api import TypeEngine
42
43    from ..ddl.impl import DefaultImpl
44
45
class BatchOperationsImpl:
    """Accumulates batch ALTER operations and replays them on flush.

    Each mutator method appends an ``(opname, args, kwargs)`` record to
    ``self.batch``.  :meth:`.flush` then either dispatches each record
    directly to the underlying migration ``impl`` (plain ALTER), or —
    when a recreate is required — builds an :class:`.ApplyBatchImpl`
    to perform the "move and copy" table-recreate workflow.
    """

    def __init__(
        self,
        operations,
        table_name,
        schema,
        recreate,
        copy_from,
        table_args,
        table_kwargs,
        reflect_args,
        reflect_kwargs,
        naming_convention,
        partial_reordering,
    ):
        self.operations = operations
        self.table_name = table_name
        self.schema = schema
        if recreate not in ("auto", "always", "never"):
            raise ValueError(
                "recreate may be one of 'auto', 'always', or 'never'."
            )
        self.recreate = recreate
        self.copy_from = copy_from
        self.table_args = table_args
        # shallow-copy the kwarg dicts so later setdefault() calls don't
        # alter the caller's dictionaries
        self.table_kwargs = dict(table_kwargs)
        self.reflect_args = reflect_args
        self.reflect_kwargs = dict(reflect_kwargs)
        # Always rebind a *fresh* copy of any incoming "listeners"
        # sequence before appending to it.  The previous
        # ``setdefault("listeners", list(...))`` form discarded the copy
        # whenever the key was already present, so the append below
        # mutated the caller's own list in place.
        self.reflect_kwargs["listeners"] = list(
            self.reflect_kwargs.get("listeners", ())
        )
        self.reflect_kwargs["listeners"].append(
            ("column_reflect", operations.impl.autogen_column_reflect)
        )
        self.naming_convention = naming_convention
        self.partial_reordering = partial_reordering
        # list of (opname, args, kwargs) records, replayed by flush()
        self.batch = []

    @property
    def dialect(self) -> "Dialect":
        """The SQLAlchemy :class:`.Dialect` in use for this migration."""
        return self.operations.impl.dialect

    @property
    def impl(self) -> "DefaultImpl":
        """The underlying migration implementation object."""
        return self.operations.impl

    def _should_recreate(self) -> bool:
        """Return True if the batched operations require a full table
        recreate, per the ``recreate`` policy ('auto' defers to the
        dialect impl)."""
        if self.recreate == "auto":
            return self.operations.impl.requires_recreate_in_batch(self)
        elif self.recreate == "always":
            return True
        else:
            return False

    def flush(self) -> None:
        """Replay all batched operations.

        Either emits each operation as a plain ALTER against the
        underlying impl, or — when a recreate is needed — reflects or
        reuses the existing table and hands everything to
        :class:`.ApplyBatchImpl` to do the "move and copy" recreate.
        """
        should_recreate = self._should_recreate()

        with _ensure_scope_for_ddl(self.impl.connection):
            if not should_recreate:
                # simple path: emit each op as its own ALTER statement
                for opname, arg, kw in self.batch:
                    fn = getattr(self.operations.impl, opname)
                    fn(*arg, **kw)
            else:
                if self.naming_convention:
                    m1 = MetaData(naming_convention=self.naming_convention)
                else:
                    m1 = MetaData()

                if self.copy_from is not None:
                    # user supplied the Table definition directly
                    existing_table = self.copy_from
                    reflected = False
                else:
                    # reflect the current table from the database
                    existing_table = Table(
                        self.table_name,
                        m1,
                        schema=self.schema,
                        autoload_with=self.operations.get_bind(),
                        *self.reflect_args,
                        **self.reflect_kwargs
                    )
                    reflected = True

                batch_impl = ApplyBatchImpl(
                    self.impl,
                    existing_table,
                    self.table_args,
                    self.table_kwargs,
                    reflected,
                    partial_reordering=self.partial_reordering,
                )
                # apply each batched op to the in-memory table model,
                # then emit the full recreate sequence
                for opname, arg, kw in self.batch:
                    fn = getattr(batch_impl, opname)
                    fn(*arg, **kw)

                batch_impl._create(self.impl)

    def alter_column(self, *arg, **kw) -> None:
        self.batch.append(("alter_column", arg, kw))

    def add_column(self, *arg, **kw) -> None:
        # positional insert is only meaningful in recreate mode
        if (
            "insert_before" in kw or "insert_after" in kw
        ) and not self._should_recreate():
            raise exc.CommandError(
                "Can't specify insert_before or insert_after when using "
                "ALTER; please specify recreate='always'"
            )
        self.batch.append(("add_column", arg, kw))

    def drop_column(self, *arg, **kw) -> None:
        self.batch.append(("drop_column", arg, kw))

    def add_constraint(self, const: "Constraint") -> None:
        self.batch.append(("add_constraint", (const,), {}))

    def drop_constraint(self, const: "Constraint") -> None:
        self.batch.append(("drop_constraint", (const,), {}))

    def rename_table(self, *arg, **kw):
        self.batch.append(("rename_table", arg, kw))

    def create_index(self, idx: "Index") -> None:
        self.batch.append(("create_index", (idx,), {}))

    def drop_index(self, idx: "Index") -> None:
        self.batch.append(("drop_index", (idx,), {}))

    def create_table_comment(self, table):
        self.batch.append(("create_table_comment", (table,), {}))

    def drop_table_comment(self, table):
        self.batch.append(("drop_table_comment", (table,), {}))

    def create_table(self, table):
        raise NotImplementedError("Can't create table in batch mode")

    def drop_table(self, table):
        raise NotImplementedError("Can't drop table in batch mode")

    def create_column_comment(self, column):
        self.batch.append(("create_column_comment", (column,), {}))
187
188
class ApplyBatchImpl:
    """Implements the "move and copy" table recreate workflow.

    A working model of the table's columns, constraints and indexes is
    captured from the source :class:`.Table`; mutator methods
    (``alter_column()``, ``add_column()``, etc.) modify that model
    in-memory, and ``_create()`` then creates a temp table from the
    model, copies the data over, drops the original, and renames the
    temp table into place.
    """

    def __init__(
        self,
        impl: "DefaultImpl",
        table: "Table",
        table_args: tuple,
        table_kwargs: Dict[str, Any],
        reflected: bool,
        partial_reordering: tuple = (),
    ) -> None:
        self.impl = impl
        self.table = table  # this is a Table object
        self.table_args = table_args
        self.table_kwargs = table_kwargs
        self.temp_table_name = self._calc_temp_name(table.name)
        # the replacement Table; built lazily by
        # _transfer_elements_to_new_table()
        self.new_table: Optional[Table] = None

        self.partial_reordering = partial_reordering  # tuple of tuples
        # dependency pairs (before_col, after_col) collected by
        # add_column() for topological ordering of new columns
        self.add_col_ordering: Tuple[
            Tuple[str, str], ...
        ] = ()  # tuple of tuples

        # per-column transfer info; the "expr" entry is the expression
        # SELECTed from the old table to populate that column.  Newly
        # added columns get an empty dict (no data to transfer).
        self.column_transfers = OrderedDict(
            (c.name, {"expr": c}) for c in self.table.c
        )
        self.existing_ordering = list(self.column_transfers)

        self.reflected = reflected
        self._grab_table_elements()

    @classmethod
    def _calc_temp_name(cls, tablename: "quoted_name") -> str:
        """Return the temp table name, truncated to 50 characters to
        stay within common identifier length limits."""
        return ("_alembic_tmp_%s" % tablename)[0:50]

    def _grab_table_elements(self) -> None:
        """Capture copies of the source table's columns, plus its
        constraints, indexes and dialect kwargs, into mutable
        collections on ``self``."""
        schema = self.table.schema
        self.columns: Dict[str, "Column"] = OrderedDict()
        for c in self.table.c:
            c_copy = _copy(c, schema=schema)
            # indexes/uniques are re-created separately, not inline on
            # the column copy
            c_copy.unique = c_copy.index = False
            # ensure that the type object was copied,
            # as we may need to modify it in-place
            if isinstance(c.type, SchemaEventTarget):
                assert c_copy.type is not c.type
            self.columns[c.name] = c_copy
        self.named_constraints: Dict[str, "Constraint"] = {}
        self.unnamed_constraints = []
        # column-level named constraints (e.g. generated by
        # Boolean/Enum types on a non-reflected table), tracked so
        # drop_constraint() can remove them from the column copy
        self.col_named_constraints = {}
        self.indexes: Dict[str, "Index"] = {}
        self.new_indexes: Dict[str, "Index"] = {}

        for const in self.table.constraints:
            if _is_type_bound(const):
                # type-bound constraints (Boolean/Enum CHECKs) travel
                # with their type object; don't carry them separately
                continue
            elif (
                self.reflected
                and isinstance(const, CheckConstraint)
                and not const.name
            ):
                # TODO: we are skipping unnamed reflected CheckConstraint
                # because
                # we have no way to determine _is_type_bound() for these.
                pass
            elif const.name:
                self.named_constraints[const.name] = const
            else:
                self.unnamed_constraints.append(const)

        if not self.reflected:
            for col in self.table.c:
                for const in col.constraints:
                    if const.name:
                        self.col_named_constraints[const.name] = (col, const)

        for idx in self.table.indexes:
            self.indexes[idx.name] = idx

        for k in self.table.kwargs:
            self.table_kwargs.setdefault(k, self.table.kwargs[k])

    def _adjust_self_columns_for_partial_reordering(self) -> None:
        """Reorder ``self.columns`` / ``self.column_transfers`` via a
        topological sort over (before, after) dependency pairs, coming
        either from an explicit ``partial_reordering`` spec or from the
        existing column order plus add_column() insert positions."""
        pairs = set()

        col_by_idx = list(self.columns)

        if self.partial_reordering:
            # user-specified ordering tuples: each adjacent pair in a
            # tuple becomes an ordering constraint
            for tuple_ in self.partial_reordering:
                for index, elem in enumerate(tuple_):
                    if index > 0:
                        pairs.add((tuple_[index - 1], elem))
        else:
            # otherwise preserve the pre-existing relative order
            for index, elem in enumerate(self.existing_ordering):
                if index > 0:
                    pairs.add((col_by_idx[index - 1], elem))

        pairs.update(self.add_col_ordering)

        # this can happen if some columns were dropped and not removed
        # from existing_ordering.  this should be prevented already, but
        # conservatively making sure this didn't happen
        pairs_list = [p for p in pairs if p[0] != p[1]]

        sorted_ = list(
            topological.sort(pairs_list, col_by_idx, deterministic_order=True)
        )
        self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
        self.column_transfers = OrderedDict(
            (k, self.column_transfers[k]) for k in sorted_
        )

    def _transfer_elements_to_new_table(self) -> None:
        """Build ``self.new_table`` under the temp name, copying over
        columns and any constraints whose columns all still exist."""
        assert self.new_table is None, "Can only create new table once"

        m = MetaData()
        schema = self.table.schema

        if self.partial_reordering or self.add_col_ordering:
            self._adjust_self_columns_for_partial_reordering()

        self.new_table = new_table = Table(
            self.temp_table_name,
            m,
            *(list(self.columns.values()) + list(self.table_args)),
            schema=schema,
            **self.table_kwargs
        )

        for const in (
            list(self.named_constraints.values()) + self.unnamed_constraints
        ):

            const_columns = set(
                [c.key for c in _columns_for_constraint(const)]
            )

            # skip constraints referencing columns that were dropped
            if not const_columns.issubset(self.column_transfers):
                continue

            const_copy: "Constraint"
            if isinstance(const, ForeignKeyConstraint):
                if _fk_is_self_referential(const):
                    # for self-referential constraint, refer to the
                    # *original* table name, and not _alembic_batch_temp.
                    # This is consistent with how we're handling
                    # FK constraints from other tables; we assume SQLite
                    # no foreign keys just keeps the names unchanged, so
                    # when we rename back, they match again.
                    const_copy = _copy(
                        const, schema=schema, target_table=self.table
                    )
                else:
                    # "target_table" for ForeignKeyConstraint.copy() is
                    # only used if the FK is detected as being
                    # self-referential, which we are handling above.
                    const_copy = _copy(const, schema=schema)
            else:
                const_copy = _copy(
                    const, schema=schema, target_table=new_table
                )
            if isinstance(const, ForeignKeyConstraint):
                # ensure the referenced table exists (as a stub if
                # necessary) in the temp MetaData so the FK resolves
                self._setup_referent(m, const)
            new_table.append_constraint(const_copy)

    def _gather_indexes_from_both_tables(self) -> List["Index"]:
        """Return the surviving indexes from the original table plus
        any indexes added via create_index(), the latter re-targeted at
        the new table's columns."""
        assert self.new_table is not None
        idx: List[Index] = []
        idx.extend(self.indexes.values())
        for index in self.new_indexes.values():
            idx.append(
                Index(
                    index.name,
                    unique=index.unique,
                    *[self.new_table.c[col] for col in index.columns.keys()],
                    **index.kwargs
                )
            )
        return idx

    def _setup_referent(
        self, metadata: "MetaData", constraint: "ForeignKeyConstraint"
    ) -> None:
        """Ensure the table referenced by *constraint* exists in
        *metadata*, creating a stub Table with NULLTYPE columns if
        needed so the FK can resolve.

        NOTE(review): the colspec parsing assumes the form
        "[schema.]table.column" — ``parts[-2]`` is the table name and a
        3-part spec carries the schema first; specs without a column
        part would mis-parse.  Confirm against how _get_colspec() is
        populated upstream.
        """
        spec = constraint.elements[
            0
        ]._get_colspec()  # type:ignore[attr-defined]
        parts = spec.split(".")
        tname = parts[-2]
        if len(parts) == 3:
            referent_schema = parts[0]
        else:
            referent_schema = None

        # a self-referential FK (pointing at the temp table) needs no
        # stub referent
        if tname != self.temp_table_name:
            key = sql_schema._get_table_key(tname, referent_schema)

            def colspec(elem: Any):
                return elem._get_colspec()

            if key in metadata.tables:
                # stub already exists; add any missing referenced columns
                t = metadata.tables[key]
                for elem in constraint.elements:
                    colname = colspec(elem).split(".")[-1]
                    if colname not in t.c:
                        t.append_column(Column(colname, sqltypes.NULLTYPE))
            else:
                Table(
                    tname,
                    metadata,
                    *[
                        Column(n, sqltypes.NULLTYPE)
                        for n in [
                            colspec(elem).split(".")[-1]
                            for elem in constraint.elements
                        ]
                    ],
                    schema=referent_schema
                )

    def _create(self, op_impl: "DefaultImpl") -> None:
        """Execute the recreate: create the temp table, copy the data,
        drop the original, rename the temp into place, then re-create
        indexes."""
        self._transfer_elements_to_new_table()

        op_impl.prep_table_for_batch(self, self.table)
        assert self.new_table is not None
        op_impl.create_table(self.new_table)

        try:
            # INSERT INTO new SELECT ... FROM old, transferring only
            # columns that have a transfer expression (i.e. existed in
            # the old table)
            op_impl._exec(
                _insert_inline(self.new_table).from_select(
                    list(
                        k
                        for k, transfer in self.column_transfers.items()
                        if "expr" in transfer
                    ),
                    _select(
                        *[
                            transfer["expr"]
                            for transfer in self.column_transfers.values()
                            if "expr" in transfer
                        ]
                    ),
                )
            )
            op_impl.drop_table(self.table)
        except:
            # bare except is deliberate: clean up the temp table on
            # *any* failure (including BaseException) before re-raising
            op_impl.drop_table(self.new_table)
            raise
        else:
            op_impl.rename_table(
                self.temp_table_name, self.table.name, schema=self.table.schema
            )
            # temporarily adopt the final name so index DDL targets the
            # renamed table; restore the temp name afterwards
            self.new_table.name = self.table.name
            try:
                for idx in self._gather_indexes_from_both_tables():
                    op_impl.create_index(idx)
            finally:
                self.new_table.name = self.temp_table_name

    def alter_column(
        self,
        table_name: str,
        column_name: str,
        nullable: Optional[bool] = None,
        server_default: Optional[Union["Function", str, bool]] = False,
        name: Optional[str] = None,
        type_: Optional["TypeEngine"] = None,
        autoincrement: None = None,
        comment: Union[str, "Literal[False]"] = False,
        **kw
    ) -> None:
        """Apply an ALTER COLUMN to the in-memory column copy.

        ``False`` sentinels on *server_default* / *comment* mean "no
        change"; ``None`` means "remove".
        """
        existing = self.columns[column_name]
        existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
        if name is not None and name != column_name:
            # note that we don't change '.key' - we keep referring
            # to the renamed column by its old key in _create().  neat!
            existing.name = name
            existing_transfer["name"] = name

            # pop named constraints for Boolean/Enum for rename
            if (
                "existing_type" in kw
                and isinstance(kw["existing_type"], SchemaEventTarget)
                and kw["existing_type"].name  # type:ignore[attr-defined]
            ):
                self.named_constraints.pop(
                    kw["existing_type"].name, None  # type:ignore[attr-defined]
                )

        if type_ is not None:
            type_ = sqltypes.to_instance(type_)
            # old type is being discarded so turn off eventing
            # rules. Alternatively we can
            # erase the events set up by this type, but this is simpler.
            # we also ignore the drop_constraint that will come here from
            # Operations.implementation_for(alter_column)

            if isinstance(existing.type, SchemaEventTarget):
                existing.type._create_events = (  # type:ignore[attr-defined]
                    existing.type.create_constraint  # type:ignore[attr-defined] # noqa
                ) = False

            # give the dialect impl a chance to add a CAST to the
            # transfer expression for the type change
            self.impl.cast_for_batch_migrate(
                existing, existing_transfer, type_
            )

            existing.type = type_

            # we *dont* however set events for the new type, because
            # alter_column is invoked from
            # Operations.implementation_for(alter_column) which already
            # will emit an add_constraint()

        if nullable is not None:
            existing.nullable = nullable
        if server_default is not False:
            if server_default is None:
                existing.server_default = None
            else:
                # attach the default to the column via _set_parent
                sql_schema.DefaultClause(
                    server_default
                )._set_parent(  # type:ignore[attr-defined]
                    existing
                )
        if autoincrement is not None:
            existing.autoincrement = bool(autoincrement)

        if comment is not False:
            existing.comment = comment

    def _setup_dependencies_for_add_column(
        self,
        colname: str,
        insert_before: Optional[str],
        insert_after: Optional[str],
    ) -> None:
        """Record (before, after) ordering pairs for a newly added
        column so the topological sort places it correctly; with no
        position given, the column is anchored after the last existing
        column."""
        index_cols = self.existing_ordering
        col_indexes = {name: i for i, name in enumerate(index_cols)}

        if not self.partial_reordering:
            if insert_after:
                if not insert_before:
                    if insert_after in col_indexes:
                        # insert after an existing column
                        idx = col_indexes[insert_after] + 1
                        if idx < len(index_cols):
                            insert_before = index_cols[idx]
                    else:
                        # insert after a column that is also new
                        insert_before = dict(self.add_col_ordering)[
                            insert_after
                        ]
            if insert_before:
                if not insert_after:
                    if insert_before in col_indexes:
                        # insert before an existing column
                        idx = col_indexes[insert_before] - 1
                        if idx >= 0:
                            insert_after = index_cols[idx]
                    else:
                        # insert before a column that is also new
                        insert_after = dict(
                            (b, a) for a, b in self.add_col_ordering
                        )[insert_before]

        if insert_before:
            self.add_col_ordering += ((colname, insert_before),)
        if insert_after:
            self.add_col_ordering += ((insert_after, colname),)

        if (
            not self.partial_reordering
            and not insert_before
            and not insert_after
            and col_indexes
        ):
            # default: append after the current last column
            self.add_col_ordering += ((index_cols[-1], colname),)

    def add_column(
        self,
        table_name: str,
        column: "Column",
        insert_before: Optional[str] = None,
        insert_after: Optional[str] = None,
        **kw
    ) -> None:
        """Add a copy of *column* to the in-memory model; no transfer
        expression is recorded since the old table has no data for it."""
        self._setup_dependencies_for_add_column(
            column.name, insert_before, insert_after
        )
        # we copy the column because operations.add_column()
        # gives us a Column that is part of a Table already.
        self.columns[column.name] = _copy(column, schema=self.table.schema)
        self.column_transfers[column.name] = {}

    def drop_column(
        self, table_name: str, column: Union["ColumnClause", "Column"], **kw
    ) -> None:
        """Remove *column* from the in-memory model, including from the
        primary key collection and the existing ordering."""
        if column.name in self.table.primary_key.columns:
            _remove_column_from_collection(
                self.table.primary_key.columns, column
            )
        del self.columns[column.name]
        del self.column_transfers[column.name]
        self.existing_ordering.remove(column.name)

        # pop named constraints for Boolean/Enum for rename
        if (
            "existing_type" in kw
            and isinstance(kw["existing_type"], SchemaEventTarget)
            and kw["existing_type"].name  # type:ignore[attr-defined]
        ):
            self.named_constraints.pop(
                kw["existing_type"].name, None  # type:ignore[attr-defined]
            )

    def create_column_comment(self, column):
        """the batch table creation function will issue create_column_comment
        on the real "impl" as part of the create table process.

        That is, the Column object will have the comment on it already,
        so when it is received by add_column() it will be a normal part of
        the CREATE TABLE and doesn't need an extra step here.

        """

    def create_table_comment(self, table):
        """the batch table creation function will issue create_table_comment
        on the real "impl" as part of the create table process.

        """

    def drop_table_comment(self, table):
        """the batch table creation function will issue drop_table_comment
        on the real "impl" as part of the create table process.

        """

    def add_constraint(self, const: "Constraint") -> None:
        """Register a named constraint to be included on the new table;
        a new PRIMARY KEY displaces any existing unnamed one."""
        if not const.name:
            raise ValueError("Constraint must have a name")
        if isinstance(const, sql_schema.PrimaryKeyConstraint):
            if self.table.primary_key in self.unnamed_constraints:
                self.unnamed_constraints.remove(self.table.primary_key)

        self.named_constraints[const.name] = const

    def drop_constraint(self, const: "Constraint") -> None:
        """Remove a named constraint from the model, whether it is a
        column-level (Boolean/Enum) or table-level constraint; silently
        ignore type-bound constraints, which travel with their type."""
        if not const.name:
            raise ValueError("Constraint must have a name")
        try:
            if const.name in self.col_named_constraints:
                col, const = self.col_named_constraints.pop(const.name)

                for col_const in list(self.columns[col.name].constraints):
                    if col_const.name == const.name:
                        self.columns[col.name].constraints.remove(col_const)
            else:
                const = self.named_constraints.pop(cast(str, const.name))
        except KeyError:
            if _is_type_bound(const):
                # type-bound constraints are only included in the new
                # table via their type object in any case, so ignore the
                # drop_constraint() that comes here via the
                # Operations.implementation_for(alter_column)
                return
            raise ValueError("No such constraint: '%s'" % const.name)
        else:
            if isinstance(const, PrimaryKeyConstraint):
                # dropping the PK also clears the flag on its columns
                for col in const.columns:
                    self.columns[col.name].primary_key = False

    def create_index(self, idx: "Index") -> None:
        """Register a new index to be created after the table swap."""
        self.new_indexes[idx.name] = idx

    def drop_index(self, idx: "Index") -> None:
        """Remove an existing index from the set carried to the new
        table."""
        try:
            del self.indexes[idx.name]
        except KeyError:
            raise ValueError("No such index: '%s'" % idx.name)

    def rename_table(self, *arg, **kw):
        """Renaming a table within batch mode is not implemented."""
        raise NotImplementedError("TODO")
668