1"""Provide the 'autogenerate' feature which can produce migration operations
2automatically."""
3
4import contextlib
5from typing import Any
6from typing import Callable
7from typing import Dict
8from typing import Iterator
9from typing import Optional
10from typing import Set
11from typing import Tuple
12from typing import TYPE_CHECKING
13from typing import Union
14
15from sqlalchemy import inspect
16
17from . import compare
18from . import render
19from .. import util
20from ..operations import ops
21
22if TYPE_CHECKING:
23    from sqlalchemy.engine import Connection
24    from sqlalchemy.engine import Dialect
25    from sqlalchemy.engine import Inspector
26    from sqlalchemy.sql.schema import Column
27    from sqlalchemy.sql.schema import ForeignKeyConstraint
28    from sqlalchemy.sql.schema import Index
29    from sqlalchemy.sql.schema import MetaData
30    from sqlalchemy.sql.schema import Table
31    from sqlalchemy.sql.schema import UniqueConstraint
32
33    from alembic.config import Config
    from alembic.operations.ops import DowngradeOps
    from alembic.operations.ops import MigrationScript
    from alembic.operations.ops import UpgradeOps
    from alembic.runtime.migration import MigrationContext
    from alembic.script.base import Script
    from alembic.script.base import ScriptDirectory


def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any:
    """Compare a database schema to that given in a
    :class:`~sqlalchemy.schema.MetaData` instance.

    The database connection is presented in the context
    of a :class:`.MigrationContext` object, which
    provides database connectivity as well as optional
    comparison functions to use for datatypes and
    server defaults - see the "autogenerate" arguments
    at :meth:`.EnvironmentContext.configure`
    for details on these.

    The return format is a list of "diff" directives,
    each representing individual differences::

        from alembic.migration import MigrationContext
        from alembic.autogenerate import compare_metadata
        from sqlalchemy.schema import SchemaItem
        from sqlalchemy.types import TypeEngine
        from sqlalchemy import (create_engine, MetaData, Column,
                Integer, String, Table, text)
        import pprint

        engine = create_engine("sqlite://")

        with engine.begin() as conn:
            conn.execute(text('''
                create table foo (
                    id integer not null primary key,
                    old_data varchar,
                    x integer
                )'''))

            conn.execute(text('''
                create table bar (
                    data varchar
                )'''))

        metadata = MetaData()
        Table('foo', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', Integer),
            Column('x', Integer, nullable=False)
        )
        Table('bat', metadata,
            Column('info', String)
        )

        mc = MigrationContext.configure(engine.connect())

        diff = compare_metadata(mc, metadata)
        pprint.pprint(diff, indent=2, width=20)

    Output::

        [ ( 'add_table',
            Table('bat', MetaData(),
                Column('info', String(), table=<bat>), schema=None)),
          ( 'remove_table',
            Table('bar', MetaData(),
                Column('data', VARCHAR(), table=<bar>), schema=None)),
          ( 'add_column',
            None,
            'foo',
            Column('data', Integer(), table=<foo>)),
          ( 'remove_column',
            None,
            'foo',
            Column('old_data', VARCHAR(), table=None)),
          [ ( 'modify_nullable',
              None,
              'foo',
              'x',
              { 'existing_server_default': None,
                'existing_type': INTEGER()},
              True,
              False)]]


    :param context: a :class:`.MigrationContext`
     instance.
    :param metadata: a :class:`~sqlalchemy.schema.MetaData`
     instance.

    .. seealso::

        :func:`.produce_migrations` - produces a :class:`.MigrationScript`
        structure based on metadata comparison.

    """

    migration_script = produce_migrations(context, metadata)
    return migration_script.upgrade_ops.as_diffs()


def produce_migrations(
    context: "MigrationContext", metadata: "MetaData"
) -> "MigrationScript":
    """Produce a :class:`.MigrationScript` structure based on schema
    comparison.

    This function does essentially what :func:`.compare_metadata` does,
    but then runs the resulting list of diffs to produce the full
    :class:`.MigrationScript` object.  For an example of what this looks
    like, see :ref:`customizing_revision`.

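    A minimal usage sketch, assuming a database ``connection`` and a
    ``target_metadata`` collection established elsewhere by the calling
    application::

        from alembic.migration import MigrationContext
        from alembic.autogenerate import produce_migrations

        mc = MigrationContext.configure(connection)
        migration_script = produce_migrations(mc, target_metadata)

        # the resulting structure contains UpgradeOps / DowngradeOps
        # collections that can be inspected or rendered
        print(migration_script.upgrade_ops.as_diffs())
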
    .. seealso::

        :func:`.compare_metadata` - returns more fundamental "diff"
        data from comparing a schema.

    """

    autogen_context = AutogenContext(context, metadata=metadata)

    migration_script = ops.MigrationScript(
        rev_id=None,
        upgrade_ops=ops.UpgradeOps([]),
        downgrade_ops=ops.DowngradeOps([]),
    )

    compare._populate_migration_script(autogen_context, migration_script)

    return migration_script


def render_python_code(
    up_or_down_op: Union["UpgradeOps", "DowngradeOps"],
    sqlalchemy_module_prefix: str = "sa.",
    alembic_module_prefix: str = "op.",
    render_as_batch: bool = False,
    imports: Tuple[str, ...] = (),
    render_item: Optional[Callable] = None,
    migration_context: Optional["MigrationContext"] = None,
) -> str:
176    """Render Python code given an :class:`.UpgradeOps` or
177    :class:`.DowngradeOps` object.
178
179    This is a convenience function that can be used to test the
180    autogenerate output of a user-defined :class:`.MigrationScript` structure.
181
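    A minimal sketch, assuming a :class:`.MigrationScript` structure named
    ``migration_script``, e.g. as returned by :func:`.produce_migrations`::

        from alembic.autogenerate import render_python_code

        # render the body of the upgrade() function as a string of
        # Python code using the "op." and "sa." prefixes
        print(render_python_code(migration_script.upgrade_ops))
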
182    """
183    opts = {
184        "sqlalchemy_module_prefix": sqlalchemy_module_prefix,
185        "alembic_module_prefix": alembic_module_prefix,
186        "render_item": render_item,
187        "render_as_batch": render_as_batch,
188    }
189
190    if migration_context is None:
191        from ..runtime.migration import MigrationContext
192        from sqlalchemy.engine.default import DefaultDialect
193
194        migration_context = MigrationContext.configure(
195            dialect=DefaultDialect()
196        )
197
198    autogen_context = AutogenContext(migration_context, opts=opts)
199    autogen_context.imports = set(imports)
200    return render._indent(
201        render._render_cmd_body(up_or_down_op, autogen_context)
202    )
203
204
205def _render_migration_diffs(
206    context: "MigrationContext", template_args: Dict[Any, Any]
207) -> None:
208    """legacy, used by test_autogen_composition at the moment"""
209
210    autogen_context = AutogenContext(context)
211
212    upgrade_ops = ops.UpgradeOps([])
213    compare._produce_net_changes(autogen_context, upgrade_ops)
214
215    migration_script = ops.MigrationScript(
216        rev_id=None,
217        upgrade_ops=upgrade_ops,
218        downgrade_ops=upgrade_ops.reverse(),
219    )
220
221    render._render_python_into_templatevars(
222        autogen_context, migration_script, template_args
223    )
224
225
226class AutogenContext:
227    """Maintains configuration and state that's specific to an
228    autogenerate operation."""
229
230    metadata: Optional["MetaData"] = None
231    """The :class:`~sqlalchemy.schema.MetaData` object
232    representing the destination.
233
234    This object is the one that is passed within ``env.py``
235    to the :paramref:`.EnvironmentContext.configure.target_metadata`
236    parameter.  It represents the structure of :class:`.Table` and other
237    objects as stated in the current database model, and represents the
238    destination structure for the database being examined.
239
240    While the :class:`~sqlalchemy.schema.MetaData` object is primarily
241    known as a collection of :class:`~sqlalchemy.schema.Table` objects,
242    it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
243    that may be used by end-user schemes to store additional schema-level
244    objects that are to be compared in custom autogeneration schemes.
245
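    A typical sketch of how this is established in ``env.py``, where
    ``myapp.models`` and ``Base`` are hypothetical application-level names::

        from myapp.models import Base

        context.configure(
            connection=connection,
            target_metadata=Base.metadata,
        )
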
246    """
247
248    connection: Optional["Connection"] = None
249    """The :class:`~sqlalchemy.engine.base.Connection` object currently
250    connected to the database backend being compared.
251
252    This is obtained from the :attr:`.MigrationContext.bind` and is
253    ultimately set up in the ``env.py`` script.
254
255    """
256
257    dialect: Optional["Dialect"] = None
258    """The :class:`~sqlalchemy.engine.Dialect` object currently in use.
259
260    This is normally obtained from the
261    :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.
262
263    """
264
265    imports: Set[str] = None  # type: ignore[assignment]
266    """A ``set()`` which contains string Python import directives.
267
268    The directives are to be rendered into the ``${imports}`` section
269    of a script template.  The set is normally empty and can be modified
270    within hooks such as the
271    :paramref:`.EnvironmentContext.configure.render_item` hook.
272
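    For example, a hypothetical ``render_item`` hook rendering a custom
    ``MySpecialType`` might register the import it needs::

        def render_item(type_, obj, autogen_context):
            if type_ == "type" and isinstance(obj, MySpecialType):
                # MySpecialType is an application-defined type, not an
                # Alembic or SQLAlchemy built-in
                autogen_context.imports.add(
                    "from myapp.types import MySpecialType"
                )
                return "MySpecialType()"
            return False
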
    .. seealso::

        :ref:`autogen_render_types`

    """

    migration_context: "MigrationContext" = None  # type: ignore[assignment]
    """The :class:`.MigrationContext` established by the ``env.py`` script."""

    def __init__(
        self,
        migration_context: "MigrationContext",
        metadata: Optional["MetaData"] = None,
        opts: Optional[dict] = None,
        autogenerate: bool = True,
    ) -> None:

        if (
            autogenerate
            and migration_context is not None
            and migration_context.as_sql
        ):
            raise util.CommandError(
                "autogenerate can't use as_sql=True as it prevents querying "
                "the database for schema information"
            )

        if opts is None:
            opts = migration_context.opts

        self.metadata = metadata = (
            opts.get("target_metadata", None) if metadata is None else metadata
        )

        if (
            autogenerate
            and metadata is None
            and migration_context is not None
            and migration_context.script is not None
        ):
            raise util.CommandError(
                "Can't proceed with --autogenerate option; environment "
                "script %s does not provide "
                "a MetaData object or sequence of objects to the context."
                % (migration_context.script.env_py_location)
            )

        include_object = opts.get("include_object", None)
        include_name = opts.get("include_name", None)

        object_filters = []
        name_filters = []
        if include_object:
            object_filters.append(include_object)
        if include_name:
            name_filters.append(include_name)

        self._object_filters = object_filters
        self._name_filters = name_filters

        self.migration_context = migration_context
        if self.migration_context is not None:
            self.connection = self.migration_context.bind
            self.dialect = self.migration_context.dialect

        self.imports = set()
        self.opts: Dict[str, Any] = opts
        self._has_batch: bool = False

    @util.memoized_property
    def inspector(self) -> "Inspector":
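        """Return a memoized SQLAlchemy inspector for the current
        database connection.

        Raises ``TypeError`` if this :class:`.AutogenContext` has no
        database connection.

        """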
        if self.connection is None:
            raise TypeError(
                "can't return inspector as this "
                "AutogenContext has no database connection"
            )
        return inspect(self.connection)

    @contextlib.contextmanager
    def _within_batch(self) -> Iterator[None]:
        self._has_batch = True
        try:
            yield
        finally:
            self._has_batch = False

    def run_name_filters(
        self,
        name: Optional[str],
        type_: str,
        parent_names: Dict[str, Optional[str]],
    ) -> bool:
        """Run the context's name filters and return True if the targets
        should be part of the autogenerate operation.

        This method should be run for every kind of name encountered within the
        reflection side of an autogenerate operation, giving the environment
        the chance to filter what names should be reflected as database
        objects.  The filters here are produced directly via the
        :paramref:`.EnvironmentContext.configure.include_name` parameter.

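        As an illustration, a name filter of the kind consumed here might be
        configured in ``env.py`` to skip reflection of a hypothetical
        "legacy" schema::

            def include_name(name, type_, parent_names):
                # skip reflection of anything in the (hypothetical)
                # "legacy" schema
                if type_ == "schema":
                    return name != "legacy"
                return True

            context.configure(
                # ...
                include_name=include_name,
            )
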
372        """
373
374        if "schema_name" in parent_names:
375            if type_ == "table":
376                table_name = name
377            else:
378                table_name = parent_names.get("table_name", None)
379            if table_name:
380                schema_name = parent_names["schema_name"]
381                if schema_name:
382                    parent_names["schema_qualified_table_name"] = "%s.%s" % (
383                        schema_name,
384                        table_name,
385                    )
386                else:
387                    parent_names["schema_qualified_table_name"] = table_name
388
        for fn in self._name_filters:
            if not fn(name, type_, parent_names):
                return False

        return True

    def run_object_filters(
        self,
        object_: Union[
            "Table",
            "Index",
            "Column",
            "UniqueConstraint",
            "ForeignKeyConstraint",
        ],
        name: Optional[str],
        type_: str,
        reflected: bool,
        compare_to: Optional[
            Union["Table", "Index", "Column", "UniqueConstraint"]
        ],
    ) -> bool:
        """Run the context's object filters and return True if the targets
        should be part of the autogenerate operation.

        This method should be run for every kind of object encountered within
        an autogenerate operation, giving the environment the chance
        to filter what objects should be included in the comparison.
        The filters here are produced directly via the
        :paramref:`.EnvironmentContext.configure.include_object` parameter.

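        For example, an object filter of this kind might exclude columns
        that carry a hypothetical ``"skip_autogenerate"`` flag in their
        ``info`` dictionary::

            def include_object(object_, name, type_, reflected, compare_to):
                # the "skip_autogenerate" info key is an application-level
                # convention, not an Alembic built-in
                if type_ == "column" and object_.info.get(
                    "skip_autogenerate", False
                ):
                    return False
                return True

            context.configure(
                # ...
                include_object=include_object,
            )
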
421        """
422        for fn in self._object_filters:
423            if not fn(object_, name, type_, reflected, compare_to):
424                return False
425        else:
426            return True
427
428    run_filters = run_object_filters
429
    @util.memoized_property
    def sorted_tables(self):
        """Return an aggregate of the :attr:`.MetaData.sorted_tables`
        collection(s).

        For a sequence of :class:`.MetaData` objects, this
        concatenates the :attr:`.MetaData.sorted_tables` collection
        for each individual :class:`.MetaData` in the order of the
        sequence.  It does **not** collate the sorted tables collections.

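        For example, given ``target_metadata = [metadata_one, metadata_two]``
        in ``env.py``, the result here is equivalent to
        ``metadata_one.sorted_tables + metadata_two.sorted_tables``.
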
439        """
440        result = []
441        for m in util.to_list(self.metadata):
442            result.extend(m.sorted_tables)
443        return result
444
445    @util.memoized_property
446    def table_key_to_table(self):
447        """Return an aggregate  of the :attr:`.MetaData.tables` dictionaries.
448
449        The :attr:`.MetaData.tables` collection is a dictionary of table key
450        to :class:`.Table`; this method aggregates the dictionary across
451        multiple :class:`.MetaData` objects into one dictionary.
452
453        Duplicate table keys are **not** supported; if two :class:`.MetaData`
454        objects contain the same table key, an exception is raised.
455
456        """
457        result = {}
458        for m in util.to_list(self.metadata):
459            intersect = set(result).intersection(set(m.tables))
460            if intersect:
461                raise ValueError(
462                    "Duplicate table keys across multiple "
463                    "MetaData objects: %s"
464                    % (", ".join('"%s"' % key for key in sorted(intersect)))
465                )
466
467            result.update(m.tables)
468        return result
469

class RevisionContext:
    """Maintains configuration and state that's specific to a revision
    file generation operation."""

    def __init__(
        self,
        config: "Config",
        script_directory: "ScriptDirectory",
        command_args: Dict[str, Any],
        process_revision_directives: Optional[Callable] = None,
    ) -> None:
        self.config = config
        self.script_directory = script_directory
        self.command_args = command_args
        self.process_revision_directives = process_revision_directives
        self.template_args = {
            "config": config  # Let templates use config for
            # e.g. multiple databases
        }
        self.generated_revisions = [self._default_revision()]

    def _to_script(
        self, migration_script: "MigrationScript"
    ) -> Optional["Script"]:
        template_args: Dict[str, Any] = self.template_args.copy()

        if getattr(migration_script, "_needs_render", False):
            autogen_context = self._last_autogen_context

            # clear out existing imports if we are doing multiple
            # renders
            autogen_context.imports = set()
            if migration_script.imports:
                autogen_context.imports.update(migration_script.imports)
            render._render_python_into_templatevars(
                autogen_context, migration_script, template_args
            )

        assert migration_script.rev_id is not None
        return self.script_directory.generate_revision(
            migration_script.rev_id,
            migration_script.message,
            refresh=True,
            head=migration_script.head,
            splice=migration_script.splice,
            branch_labels=migration_script.branch_label,
            version_path=migration_script.version_path,
            depends_on=migration_script.depends_on,
            **template_args
        )

    def run_autogenerate(
        self, rev: tuple, migration_context: "MigrationContext"
    ):
        self._run_environment(rev, migration_context, True)

    def run_no_autogenerate(
        self, rev: tuple, migration_context: "MigrationContext"
    ):
        self._run_environment(rev, migration_context, False)

    def _run_environment(
        self,
        rev: tuple,
        migration_context: "MigrationContext",
        autogenerate: bool,
    ):
        if autogenerate:
            if self.command_args["sql"]:
                raise util.CommandError(
                    "Using --sql with --autogenerate does not make any sense"
                )
            if set(self.script_directory.get_revisions(rev)) != set(
                self.script_directory.get_revisions("heads")
            ):
                raise util.CommandError("Target database is not up to date.")

        upgrade_token = migration_context.opts["upgrade_token"]
        downgrade_token = migration_context.opts["downgrade_token"]

        migration_script = self.generated_revisions[-1]
        if not getattr(migration_script, "_needs_render", False):
            migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
            migration_script.downgrade_ops_list[
                -1
            ].downgrade_token = downgrade_token
            migration_script._needs_render = True
        else:
            migration_script._upgrade_ops.append(
                ops.UpgradeOps([], upgrade_token=upgrade_token)
            )
            migration_script._downgrade_ops.append(
                ops.DowngradeOps([], downgrade_token=downgrade_token)
            )

        autogen_context = AutogenContext(
            migration_context, autogenerate=autogenerate
        )
        self._last_autogen_context: AutogenContext = autogen_context

        if autogenerate:
            compare._populate_migration_script(
                autogen_context, migration_script
            )

        if self.process_revision_directives:
            self.process_revision_directives(
                migration_context, rev, self.generated_revisions
            )

        hook = migration_context.opts["process_revision_directives"]
        if hook:
            hook(migration_context, rev, self.generated_revisions)

        for migration_script in self.generated_revisions:
            migration_script._needs_render = True

    def _default_revision(self) -> "MigrationScript":
        command_args: Dict[str, Any] = self.command_args
        op = ops.MigrationScript(
            rev_id=command_args["rev_id"] or util.rev_id(),
            message=command_args["message"],
            upgrade_ops=ops.UpgradeOps([]),
            downgrade_ops=ops.DowngradeOps([]),
            head=command_args["head"],
            splice=command_args["splice"],
            branch_label=command_args["branch_label"],
            version_path=command_args["version_path"],
            depends_on=command_args["depends_on"],
        )
        return op

    def generate_scripts(self) -> Iterator[Optional["Script"]]:
        for generated_revision in self.generated_revisions:
            yield self._to_script(generated_revision)