from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union

from alembic import util
from alembic.operations import ops

if TYPE_CHECKING:
    from alembic.operations.ops import AddColumnOp
    from alembic.operations.ops import AlterColumnOp
    from alembic.operations.ops import CreateTableOp
    from alembic.operations.ops import MigrateOperation
    from alembic.operations.ops import MigrationScript
    from alembic.operations.ops import ModifyTableOps
    from alembic.operations.ops import OpContainer
    from alembic.runtime.migration import MigrationContext
    from alembic.script.revision import Revision


class Rewriter:
    """A helper object that allows easy 'rewriting' of ops streams.

    The :class:`.Rewriter` object is intended to be passed along
    to the
    :paramref:`.EnvironmentContext.configure.process_revision_directives`
    parameter in an ``env.py`` script.    Once constructed, any number
    of "rewrites" functions can be associated with it, which will be given
    the opportunity to modify the structure without having to have explicit
    knowledge of the overall structure.

    The function is passed the :class:`.MigrationContext` object and
    ``revision`` tuple that are passed to the
    :paramref:`.EnvironmentContext.configure.process_revision_directives`
    function normally, and the third argument is an individual directive
    of the type noted in the decorator.  The function has the choice of
    returning a single op directive, which normally can be the directive
    that was actually passed, or a new directive to replace it, or a list
    of zero or more directives to replace it.

    .. seealso::

        :ref:`autogen_rewriter` - usage example

    """

    # Class-level dispatcher mapping op container / op types to the
    # traversal methods registered below via ``_traverse.dispatch_for``.
    _traverse = util.Dispatcher()

    # Callables (normally other Rewriter objects) to run, in order, after
    # this rewriter has processed the directives.  Kept as a tuple rather
    # than a single slot so that ``a.chain(b).chain(c)`` runs ``b`` and
    # ``c`` instead of the second ``chain()`` replacing ``b`` with ``c``.
    _chained: Tuple["Rewriter", ...] = ()

    def __init__(self) -> None:
        # Per-instance dispatcher holding the user's ``rewrites`` functions.
        self.dispatch = util.Dispatcher()

    def chain(self, other: "Rewriter") -> "Rewriter":
        """Produce a "chain" of this :class:`.Rewriter` to another.

        This allows two or more rewriters to operate serially on a stream,
        e.g.::

            writer1 = autogenerate.Rewriter()
            writer2 = autogenerate.Rewriter()

            @writer1.rewrites(ops.AddColumnOp)
            def add_column_nullable(context, revision, op):
                op.column.nullable = True
                return op

            @writer2.rewrites(ops.AddColumnOp)
            def add_column_idx(context, revision, op):
                idx_op = ops.CreateIndexOp(
                    'ixc', op.table_name, [op.column.name])
                return [
                    op,
                    idx_op
                ]

            writer = writer1.chain(writer2)

        Calling ``chain()`` again on the result appends a further writer;
        all chained writers run in the order they were added.

        :param other: a :class:`.Rewriter` instance
        :return: a new :class:`.Rewriter` that will run the operations
         of this writer, then the "other" writer, in succession.

        """
        # Shallow-copy this writer without running __init__, so the copy
        # shares this writer's ``dispatch`` (and its registered functions).
        wr = self.__class__.__new__(self.__class__)
        wr.__dict__.update(self.__dict__)
        # Build a new tuple so neither ``self`` nor earlier chains are
        # mutated, and previously chained writers are preserved.
        wr._chained = self._chained + (other,)
        return wr

    def rewrites(
        self,
        operator: Union[
            Type["AddColumnOp"],
            Type["MigrateOperation"],
            Type["AlterColumnOp"],
            Type["CreateTableOp"],
            Type["ModifyTableOps"],
        ],
    ) -> Callable:
        """Register a function as rewriter for a given type.

        The function should receive three arguments, which are
        the :class:`.MigrationContext`, a ``revision`` tuple, and
        an op directive of the type indicated.  E.g.::

            @writer1.rewrites(ops.AddColumnOp)
            def add_column_nullable(context, revision, op):
                op.column.nullable = True
                return op

        """
        return self.dispatch.dispatch_for(operator)

    def _rewrite(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrateOperation",
    ) -> Iterator["MigrateOperation"]:
        """Yield the directive(s) that should replace ``directive``.

        If no rewrite function is found for the directive, or this
        rewriter has already processed it, the directive is yielded
        unchanged.  Otherwise the registered function's result (a single
        directive or a list of zero or more) is yielded, each item marked
        as mutated by this rewriter.
        """
        try:
            _rewriter = self.dispatch.dispatch(directive)
        except ValueError:
            # The dispatcher raising ValueError is treated as "no rewrite
            # function for this directive's type": pass it through.
            yield directive
            return

        if self in directive._mutations:
            # Already rewritten by this rewriter; don't apply it twice.
            yield directive
            return

        for r_directive in util.to_list(
            _rewriter(context, revision, directive), []
        ):
            # Record this rewriter so the check above skips it next time.
            r_directive._mutations = r_directive._mutations.union([self])
            yield r_directive

    def __call__(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directives: List["MigrationScript"],
    ) -> None:
        """Process ``directives`` in place, then run any chained writers."""
        self.process_revision_directives(context, revision, directives)
        for chained in self._chained:
            chained(context, revision, directives)

    @_traverse.dispatch_for(ops.MigrationScript)
    def _traverse_script(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrationScript",
    ) -> None:
        """Rewrite each UpgradeOps / DowngradeOps of a MigrationScript.

        :raises ValueError: if rewriting an UpgradeOps or DowngradeOps
         produces anything other than exactly one object.
        """
        upgrade_ops_list = []
        for upgrade_ops in directive.upgrade_ops_list:
            ret = self._traverse_for(context, revision, upgrade_ops)
            if len(ret) != 1:
                raise ValueError(
                    "Can only return single object for UpgradeOps traverse"
                )
            upgrade_ops_list.append(ret[0])
        directive.upgrade_ops = upgrade_ops_list

        downgrade_ops_list = []
        for downgrade_ops in directive.downgrade_ops_list:
            ret = self._traverse_for(context, revision, downgrade_ops)
            if len(ret) != 1:
                raise ValueError(
                    "Can only return single object for DowngradeOps traverse"
                )
            downgrade_ops_list.append(ret[0])
        directive.downgrade_ops = downgrade_ops_list

    @_traverse.dispatch_for(ops.OpContainer)
    def _traverse_op_container(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "OpContainer",
    ) -> None:
        """Recurse into a container's ``ops`` list, rewriting it in place."""
        self._traverse_list(context, revision, directive.ops)

    @_traverse.dispatch_for(ops.MigrateOperation)
    def _traverse_any_directive(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrateOperation",
    ) -> None:
        """Leaf directives have no children to traverse."""
        pass

    def _traverse_for(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrateOperation",
    ) -> List["MigrateOperation"]:
        """Rewrite ``directive``, then traverse each resulting directive.

        :return: the list of directives that replace ``directive``.
        """
        directives = list(self._rewrite(context, revision, directive))
        for result in directives:
            traverser = self._traverse.dispatch(result)
            traverser(self, context, revision, result)
        return directives

    def _traverse_list(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directives: List[Any],
    ) -> None:
        """Replace the contents of ``directives`` in place with the
        rewritten (and recursively traversed) directives."""
        dest = []
        for directive in directives:
            dest.extend(self._traverse_for(context, revision, directive))

        # Slice assignment mutates the caller's list object in place.
        directives[:] = dest

    def process_revision_directives(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directives: List["MigrationScript"],
    ) -> None:
        """Apply this rewriter's functions to ``directives`` in place."""
        self._traverse_list(context, revision, directives)