"""Provide the 'autogenerate' feature which can produce migration operations
automatically."""

import contextlib
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import inspect

from . import compare
from . import render
from .. import util
from ..operations import ops

if TYPE_CHECKING:
    from sqlalchemy.engine import Connection
    from sqlalchemy.engine import Dialect
    from sqlalchemy.engine import Inspector
    from sqlalchemy.sql.schema import Column
    from sqlalchemy.sql.schema import ForeignKeyConstraint
    from sqlalchemy.sql.schema import Index
    from sqlalchemy.sql.schema import MetaData
    from sqlalchemy.sql.schema import Table
    from sqlalchemy.sql.schema import UniqueConstraint

    from alembic.config import Config
    from alembic.operations.ops import DowngradeOps
    from alembic.operations.ops import MigrationScript
    from alembic.operations.ops import UpgradeOps
    from alembic.runtime.migration import MigrationContext
    from alembic.script.base import Script
    from alembic.script.base import ScriptDirectory


def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any:
    """Compare a database schema to that given in a
    :class:`~sqlalchemy.schema.MetaData` instance.

    The database connection is presented in the context
    of a :class:`.MigrationContext` object, which
    provides database connectivity as well as optional
    comparison functions to use for datatypes and
    server defaults - see the "autogenerate" arguments
    at :meth:`.EnvironmentContext.configure`
    for details on these.

    The return format is a list of "diff" directives,
    each representing individual differences::

        from alembic.migration import MigrationContext
        from alembic.autogenerate import compare_metadata
        from sqlalchemy.schema import SchemaItem
        from sqlalchemy.types import TypeEngine
        from sqlalchemy import (create_engine, MetaData, Column,
                Integer, String, Table, text)
        import pprint

        engine = create_engine("sqlite://")

        with engine.begin() as conn:
            conn.execute(text('''
                create table foo (
                    id integer not null primary key,
                    old_data varchar,
                    x integer
                )'''))

            conn.execute(text('''
                create table bar (
                    data varchar
                )'''))

        metadata = MetaData()
        Table('foo', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', Integer),
            Column('x', Integer, nullable=False)
        )
        Table('bat', metadata,
            Column('info', String)
        )

        mc = MigrationContext.configure(engine.connect())

        diff = compare_metadata(mc, metadata)
        pprint.pprint(diff, indent=2, width=20)

    Output::

        [ ( 'add_table',
            Table('bat', MetaData(bind=None),
                Column('info', String(), table=<bat>), schema=None)),
          ( 'remove_table',
            Table(u'bar', MetaData(bind=None),
                Column(u'data', VARCHAR(), table=<bar>), schema=None)),
          ( 'add_column',
            None,
            'foo',
            Column('data', Integer(), table=<foo>)),
          ( 'remove_column',
            None,
            'foo',
            Column(u'old_data', VARCHAR(), table=None)),
          [ ( 'modify_nullable',
              None,
              'foo',
              u'x',
              { 'existing_server_default': None,
                'existing_type': INTEGER()},
              True,
              False)]]


    :param context: a :class:`.MigrationContext`
     instance.
    :param metadata: a :class:`~sqlalchemy.schema.MetaData`
     instance.

    .. seealso::

        :func:`.produce_migrations` - produces a :class:`.MigrationScript`
        structure based on metadata comparison.

    """

    migration_script = produce_migrations(context, metadata)
    return migration_script.upgrade_ops.as_diffs()


def produce_migrations(
    context: "MigrationContext", metadata: "MetaData"
) -> "MigrationScript":
    """Produce a :class:`.MigrationScript` structure based on schema
    comparison.

    This function does essentially what :func:`.compare_metadata` does,
    but then runs the resulting list of diffs to produce the full
    :class:`.MigrationScript` object.   For an example of what this looks like,
    see the example in :ref:`customizing_revision`.

    .. seealso::

        :func:`.compare_metadata` - returns more fundamental "diff"
        data from comparing a schema.

    """

    autogen_context = AutogenContext(context, metadata=metadata)

    # Start from an empty script; the comparison populates both the
    # upgrade and downgrade operation containers in place.
    migration_script = ops.MigrationScript(
        rev_id=None,
        upgrade_ops=ops.UpgradeOps([]),
        downgrade_ops=ops.DowngradeOps([]),
    )

    compare._populate_migration_script(autogen_context, migration_script)

    return migration_script


def render_python_code(
    up_or_down_op: Union["UpgradeOps", "DowngradeOps"],
    sqlalchemy_module_prefix: str = "sa.",
    alembic_module_prefix: str = "op.",
    render_as_batch: bool = False,
    imports: Tuple[str, ...] = (),
    render_item: Optional[Callable] = None,
    migration_context: Optional["MigrationContext"] = None,
) -> str:
    """Render Python code given an :class:`.UpgradeOps` or
    :class:`.DowngradeOps` object.

    This is a convenience function that can be used to test the
    autogenerate output of a user-defined :class:`.MigrationScript` structure.

    :param up_or_down_op: the :class:`.UpgradeOps` or
     :class:`.DowngradeOps` container to render.
    :param sqlalchemy_module_prefix: prefix used for SQLAlchemy names in
     the rendered code.
    :param alembic_module_prefix: prefix used for Alembic ``op`` calls in
     the rendered code.
    :param render_as_batch: passed through as the ``render_as_batch``
     rendering option.
    :param imports: initial import directives placed into
     :attr:`.AutogenContext.imports`.
    :param render_item: optional rendering hook, passed through as the
     ``render_item`` rendering option.
    :param migration_context: context to render against; when omitted, a
     :class:`.MigrationContext` configured with SQLAlchemy's
     ``DefaultDialect`` is created.
    :return: the rendered, indented Python source.

    """
    opts = {
        "sqlalchemy_module_prefix": sqlalchemy_module_prefix,
        "alembic_module_prefix": alembic_module_prefix,
        "render_item": render_item,
        "render_as_batch": render_as_batch,
    }

    if migration_context is None:
        # Imported lazily; only needed when the caller supplies no context.
        from ..runtime.migration import MigrationContext
        from sqlalchemy.engine.default import DefaultDialect

        migration_context = MigrationContext.configure(
            dialect=DefaultDialect()
        )

    autogen_context = AutogenContext(migration_context, opts=opts)
    autogen_context.imports = set(imports)
    return render._indent(
        render._render_cmd_body(up_or_down_op, autogen_context)
    )


def _render_migration_diffs(
    context: "MigrationContext", template_args: Dict[Any, Any]
) -> None:
    """legacy, used by test_autogen_composition at the moment"""

    autogen_context = AutogenContext(context)

    upgrade_ops = ops.UpgradeOps([])
    compare._produce_net_changes(autogen_context, upgrade_ops)

    migration_script = ops.MigrationScript(
        rev_id=None,
        upgrade_ops=upgrade_ops,
        downgrade_ops=upgrade_ops.reverse(),
    )

    render._render_python_into_templatevars(
        autogen_context, migration_script, template_args
    )


class AutogenContext:
    """Maintains configuration and state that's specific to an
    autogenerate operation."""

    metadata: Optional["MetaData"] = None
    """The :class:`~sqlalchemy.schema.MetaData` object
    representing the destination.

    This object is the one that is passed within ``env.py``
    to the :paramref:`.EnvironmentContext.configure.target_metadata`
    parameter.  It represents the structure of :class:`.Table` and other
    objects as stated in the current database model, and represents the
    destination structure for the database being examined.

    While the :class:`~sqlalchemy.schema.MetaData` object is primarily
    known as a collection of :class:`~sqlalchemy.schema.Table` objects,
    it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
    that may be used by end-user schemes to store additional schema-level
    objects that are to be compared in custom autogeneration schemes.

    """

    connection: Optional["Connection"] = None
    """The :class:`~sqlalchemy.engine.base.Connection` object currently
    connected to the database backend being compared.

    This is obtained from the :attr:`.MigrationContext.bind` and is
    ultimately set up in the ``env.py`` script.

    """

    dialect: Optional["Dialect"] = None
    """The :class:`~sqlalchemy.engine.Dialect` object currently in use.

    This is normally obtained from the
    :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.

    """

    imports: Set[str] = None  # type: ignore[assignment]
    """A ``set()`` which contains string Python import directives.

    The directives are to be rendered into the ``${imports}`` section
    of a script template.  The set is normally empty and can be modified
    within hooks such as the
    :paramref:`.EnvironmentContext.configure.render_item` hook.

    .. seealso::

        :ref:`autogen_render_types`

    """

    migration_context: "MigrationContext" = None  # type: ignore[assignment]
    """The :class:`.MigrationContext` established by the ``env.py`` script."""

    def __init__(
        self,
        migration_context: "MigrationContext",
        metadata: Optional["MetaData"] = None,
        opts: Optional[dict] = None,
        autogenerate: bool = True,
    ) -> None:
        """Set up the context from a migration context and options.

        :param migration_context: the :class:`.MigrationContext` in use.
        :param metadata: target metadata; when ``None``, the
         ``"target_metadata"`` option is used instead.
        :param opts: option dictionary; defaults to
         ``migration_context.opts``.
        :param autogenerate: when true, require a live (non ``as_sql``)
         connection and, if an env script is present, target metadata.
        :raises util.CommandError: if autogenerate is requested in
         ``as_sql`` mode, or without target metadata while an env script
         is in use.
        """

        if (
            autogenerate
            and migration_context is not None
            and migration_context.as_sql
        ):
            raise util.CommandError(
                "autogenerate can't use as_sql=True as it prevents querying "
                "the database for schema information"
            )

        if opts is None:
            opts = migration_context.opts

        self.metadata = metadata = (
            opts.get("target_metadata", None) if metadata is None else metadata
        )

        if (
            autogenerate
            and metadata is None
            and migration_context is not None
            and migration_context.script is not None
        ):
            raise util.CommandError(
                "Can't proceed with --autogenerate option; environment "
                "script %s does not provide "
                "a MetaData object or sequence of objects to the context."
                % (migration_context.script.env_py_location)
            )

        include_object = opts.get("include_object", None)
        include_name = opts.get("include_name", None)

        object_filters = []
        name_filters = []
        if include_object:
            object_filters.append(include_object)
        if include_name:
            name_filters.append(include_name)

        self._object_filters = object_filters
        self._name_filters = name_filters

        self.migration_context = migration_context
        if self.migration_context is not None:
            self.connection = self.migration_context.bind
            self.dialect = self.migration_context.dialect

        self.imports = set()
        self.opts: Dict[str, Any] = opts
        self._has_batch: bool = False

    @util.memoized_property
    def inspector(self) -> "Inspector":
        """Return an inspector for :attr:`.connection` (memoized).

        :raises TypeError: if this context has no database connection.
        """
        if self.connection is None:
            raise TypeError(
                "can't return inspector as this "
                "AutogenContext has no database connection"
            )
        return inspect(self.connection)

    @contextlib.contextmanager
    def _within_batch(self) -> Iterator[None]:
        """Mark this context as rendering inside a batch operation block.

        The flag is reset in a ``finally`` clause so that an exception
        raised inside the ``with`` body does not leave the context stuck
        in batch mode for subsequent renders.
        """
        self._has_batch = True
        try:
            yield
        finally:
            self._has_batch = False

    def run_name_filters(
        self,
        name: Optional[str],
        type_: str,
        parent_names: Dict[str, Optional[str]],
    ) -> bool:
        """Run the context's name filters and return True if the targets
        should be part of the autogenerate operation.

        This method should be run for every kind of name encountered within the
        reflection side of an autogenerate operation, giving the environment
        the chance to filter what names should be reflected as database
        objects.  The filters here are produced directly via the
        :paramref:`.EnvironmentContext.configure.include_name` parameter.

        Note that when ``parent_names`` contains a ``"schema_name"`` key
        and a table name can be determined, ``parent_names`` is updated
        **in place** with a ``"schema_qualified_table_name"`` entry before
        the filters are invoked.

        """

        if "schema_name" in parent_names:
            # For a table, the name being filtered is the table name itself;
            # for anything else, the owning table comes from parent_names.
            if type_ == "table":
                table_name = name
            else:
                table_name = parent_names.get("table_name", None)
            if table_name:
                schema_name = parent_names["schema_name"]
                if schema_name:
                    parent_names["schema_qualified_table_name"] = "%s.%s" % (
                        schema_name,
                        table_name,
                    )
                else:
                    parent_names["schema_qualified_table_name"] = table_name

        # All filters must accept the name; the first rejection wins.
        for fn in self._name_filters:

            if not fn(name, type_, parent_names):
                return False
        else:
            return True

    def run_object_filters(
        self,
        object_: Union[
            "Table",
            "Index",
            "Column",
            "UniqueConstraint",
            "ForeignKeyConstraint",
        ],
        name: Optional[str],
        type_: str,
        reflected: bool,
        compare_to: Optional[
            Union["Table", "Index", "Column", "UniqueConstraint"]
        ],
    ) -> bool:
        """Run the context's object filters and return True if the targets
        should be part of the autogenerate operation.

        This method should be run for every kind of object encountered within
        an autogenerate operation, giving the environment the chance
        to filter what objects should be included in the comparison.
        The filters here are produced directly via the
        :paramref:`.EnvironmentContext.configure.include_object` parameter.

        """
        # All filters must accept the object; the first rejection wins.
        for fn in self._object_filters:
            if not fn(object_, name, type_, reflected, compare_to):
                return False
        else:
            return True

    # Backwards-compatible alias for run_object_filters.
    run_filters = run_object_filters

    @util.memoized_property
    def sorted_tables(self):
        """Return an aggregate of the :attr:`.MetaData.sorted_tables` collection(s).

        For a sequence of :class:`.MetaData` objects, this
        concatenates the :attr:`.MetaData.sorted_tables` collection
        for each individual :class:`.MetaData` in the order of the
        sequence.   It does **not** collate the sorted tables collections.

        """
        result = []
        for m in util.to_list(self.metadata):
            result.extend(m.sorted_tables)
        return result

    @util.memoized_property
    def table_key_to_table(self):
        """Return an aggregate  of the :attr:`.MetaData.tables` dictionaries.

        The :attr:`.MetaData.tables` collection is a dictionary of table key
        to :class:`.Table`; this method aggregates the dictionary across
        multiple :class:`.MetaData` objects into one dictionary.

        Duplicate table keys are **not** supported; if two :class:`.MetaData`
        objects contain the same table key, an exception is raised.

        """
        result = {}
        for m in util.to_list(self.metadata):
            intersect = set(result).intersection(set(m.tables))
            if intersect:
                raise ValueError(
                    "Duplicate table keys across multiple "
                    "MetaData objects: %s"
                    % (", ".join('"%s"' % key for key in sorted(intersect)))
                )

            result.update(m.tables)
        return result


class RevisionContext:
    """Maintains configuration and state that's specific to a revision
    file generation operation."""

    def __init__(
        self,
        config: "Config",
        script_directory: "ScriptDirectory",
        command_args: Dict[str, Any],
        process_revision_directives: Optional[Callable] = None,
    ) -> None:
        """Set up the revision context.

        :param config: the :class:`.Config` in use; also exposed to script
         templates as ``config``.
        :param script_directory: the :class:`.ScriptDirectory` that
         revision files are generated into.
        :param command_args: command arguments; keys such as ``"rev_id"``,
         ``"message"``, ``"head"``, ``"splice"``, ``"branch_label"``,
         ``"version_path"``, ``"depends_on"`` and ``"sql"`` are read.
        :param process_revision_directives: optional hook called with
         ``(migration_context, rev, generated_revisions)``.
        """
        self.config = config
        self.script_directory = script_directory
        self.command_args = command_args
        self.process_revision_directives = process_revision_directives
        # Let templates use config for e.g. multiple databases.
        self.template_args = {"config": config}
        self.generated_revisions = [self._default_revision()]

    def _to_script(
        self, migration_script: "MigrationScript"
    ) -> Optional["Script"]:
        """Render (if needed) and write one :class:`.MigrationScript` to a
        revision file via the script directory."""
        template_args: Dict[str, Any] = self.template_args.copy()

        if getattr(migration_script, "_needs_render", False):
            autogen_context = self._last_autogen_context

            # clear out existing imports if we are doing multiple
            # renders
            autogen_context.imports = set()
            if migration_script.imports:
                autogen_context.imports.update(migration_script.imports)
            render._render_python_into_templatevars(
                autogen_context, migration_script, template_args
            )

        assert migration_script.rev_id is not None
        return self.script_directory.generate_revision(
            migration_script.rev_id,
            migration_script.message,
            refresh=True,
            head=migration_script.head,
            splice=migration_script.splice,
            branch_labels=migration_script.branch_label,
            version_path=migration_script.version_path,
            depends_on=migration_script.depends_on,
            **template_args
        )

    def run_autogenerate(
        self, rev: tuple, migration_context: "MigrationContext"
    ) -> None:
        """Run the environment with autogenerate comparison enabled."""
        self._run_environment(rev, migration_context, True)

    def run_no_autogenerate(
        self, rev: tuple, migration_context: "MigrationContext"
    ) -> None:
        """Run the environment without autogenerate comparison."""
        self._run_environment(rev, migration_context, False)

    def _run_environment(
        self,
        rev: tuple,
        migration_context: "MigrationContext",
        autogenerate: bool,
    ) -> None:
        """Populate the most recent generated revision for this environment
        run, then invoke any ``process_revision_directives`` hooks.

        :raises util.CommandError: when autogenerating with ``--sql`` or
         when the database revision(s) do not match the current heads.
        """
        if autogenerate:
            if self.command_args["sql"]:
                raise util.CommandError(
                    "Using --sql with --autogenerate does not make any sense"
                )
            if set(self.script_directory.get_revisions(rev)) != set(
                self.script_directory.get_revisions("heads")
            ):
                raise util.CommandError("Target database is not up to date.")

        upgrade_token = migration_context.opts["upgrade_token"]
        downgrade_token = migration_context.opts["downgrade_token"]

        migration_script = self.generated_revisions[-1]
        if not getattr(migration_script, "_needs_render", False):
            # First environment run for this script: tag the existing
            # default ops containers with this environment's tokens.
            migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
            migration_script.downgrade_ops_list[
                -1
            ].downgrade_token = downgrade_token
            migration_script._needs_render = True
        else:
            # Subsequent runs (e.g. multiple databases) append a new pair
            # of ops containers keyed by this environment's tokens.
            migration_script._upgrade_ops.append(
                ops.UpgradeOps([], upgrade_token=upgrade_token)
            )
            migration_script._downgrade_ops.append(
                ops.DowngradeOps([], downgrade_token=downgrade_token)
            )

        autogen_context = AutogenContext(
            migration_context, autogenerate=autogenerate
        )
        self._last_autogen_context: AutogenContext = autogen_context

        if autogenerate:
            compare._populate_migration_script(
                autogen_context, migration_script
            )

        if self.process_revision_directives:
            self.process_revision_directives(
                migration_context, rev, self.generated_revisions
            )

        hook = migration_context.opts["process_revision_directives"]
        if hook:
            hook(migration_context, rev, self.generated_revisions)

        # Hooks may have replaced or added scripts; ensure all get rendered.
        for migration_script in self.generated_revisions:
            migration_script._needs_render = True

    def _default_revision(self) -> "MigrationScript":
        """Build the initial, empty :class:`.MigrationScript` from the
        command arguments."""
        command_args: Dict[str, Any] = self.command_args
        op = ops.MigrationScript(
            rev_id=command_args["rev_id"] or util.rev_id(),
            message=command_args["message"],
            upgrade_ops=ops.UpgradeOps([]),
            downgrade_ops=ops.DowngradeOps([]),
            head=command_args["head"],
            splice=command_args["splice"],
            branch_label=command_args["branch_label"],
            version_path=command_args["version_path"],
            depends_on=command_args["depends_on"],
        )
        return op

    def generate_scripts(self) -> Iterator[Optional["Script"]]:
        """Yield the generated :class:`.Script` for each revision, in order."""
        for generated_revision in self.generated_revisions:
            yield self._to_script(generated_revision)