1from typing import Any
2from typing import Callable
3from typing import Iterator
4from typing import List
5from typing import Type
6from typing import TYPE_CHECKING
7from typing import Union
8
9from alembic import util
10from alembic.operations import ops
11
12if TYPE_CHECKING:
13    from alembic.operations.ops import AddColumnOp
14    from alembic.operations.ops import AlterColumnOp
15    from alembic.operations.ops import CreateTableOp
16    from alembic.operations.ops import MigrateOperation
17    from alembic.operations.ops import MigrationScript
18    from alembic.operations.ops import ModifyTableOps
19    from alembic.operations.ops import OpContainer
20    from alembic.runtime.migration import MigrationContext
21    from alembic.script.revision import Revision
22
23
class Rewriter:
    """A helper object that allows easy 'rewriting' of ops streams.

    The :class:`.Rewriter` object is intended to be passed along
    to the
    :paramref:`.EnvironmentContext.configure.process_revision_directives`
    parameter in an ``env.py`` script.  Once constructed, any number
    of "rewrite" functions can be associated with it, each of which will
    be given the opportunity to modify the structure without needing
    explicit knowledge of the overall structure.

    Each function is passed the :class:`.MigrationContext` object and
    ``revision`` tuple that are passed to the
    :paramref:`.EnvironmentContext.configure.process_revision_directives`
    function normally, and the third argument is an individual directive
    of the type noted in the decorator.  The function has the choice of
    returning a single op directive, which normally can be the directive
    that was actually passed, or a new directive to replace it, or a list
    of zero or more directives to replace it.

    .. seealso::

        :ref:`autogen_rewriter` - usage example

    """

    # Class-level dispatcher shared by every Rewriter: maps directive types
    # to the internal traversal methods registered below via
    # ``@_traverse.dispatch_for(...)``.
    _traverse = util.Dispatcher()

    # Rewriters to run, in order, after this one; extended by chain().
    # Stored as a tuple (rather than a single slot) so that repeated
    # ``chain()`` calls accumulate writers instead of replacing the
    # previously chained one.
    _chained: tuple = ()

    def __init__(self) -> None:
        # Per-instance dispatcher holding the user's rewrite functions,
        # registered through ``rewrites()``.
        self.dispatch = util.Dispatcher()

    def chain(self, other: "Rewriter") -> "Rewriter":
        """Produce a "chain" of this :class:`.Rewriter` to another.

        This allows two rewriters to operate serially on a stream,
        e.g.::

            writer1 = autogenerate.Rewriter()
            writer2 = autogenerate.Rewriter()

            @writer1.rewrites(ops.AddColumnOp)
            def add_column_nullable(context, revision, op):
                op.column.nullable = True
                return op

            @writer2.rewrites(ops.AddColumnOp)
            def add_column_idx(context, revision, op):
                idx_op = ops.CreateIndexOp(
                    'ixc', op.table_name, [op.column.name])
                return [
                    op,
                    idx_op
                ]

            writer = writer1.chain(writer2)

        Chaining may be repeated; ``writer1.chain(writer2).chain(writer3)``
        runs all three writers, in that order.

        :param other: a :class:`.Rewriter` instance
        :return: a new :class:`.Rewriter` that will run the operations
         of this writer, then the "other" writer, in succession.

        """
        wr = self.__class__.__new__(self.__class__)
        # Shallow copy: ``wr`` shares this writer's ``dispatch`` object, so
        # rewrite functions registered later on either are seen by both.
        wr.__dict__.update(self.__dict__)
        # Build a new tuple so ``self`` and its existing chain are left
        # untouched, while any writers chained earlier are preserved.
        wr._chained = self._chained + (other,)
        return wr

    def rewrites(
        self,
        operator: Union[
            Type["AddColumnOp"],
            Type["MigrateOperation"],
            Type["AlterColumnOp"],
            Type["CreateTableOp"],
            Type["ModifyTableOps"],
        ],
    ) -> Callable:
        """Register a function as rewriter for a given type.

        The function should receive three arguments, which are
        the :class:`.MigrationContext`, a ``revision`` tuple, and
        an op directive of the type indicated.  E.g.::

            @writer1.rewrites(ops.AddColumnOp)
            def add_column_nullable(context, revision, op):
                op.column.nullable = True
                return op

        :param operator: the directive class the decorated function
         handles.
        :return: a decorator that registers the function on this writer's
         dispatcher.

        """
        return self.dispatch.dispatch_for(operator)

    def _rewrite(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrateOperation",
    ) -> Iterator["MigrateOperation"]:
        """Yield the directive(s) that replace ``directive`` for this writer.

        ``directive`` is yielded unchanged in two cases: the dispatcher
        finds no rewrite function for it, or this writer is already
        recorded in ``directive._mutations``.  Otherwise the registered
        function is called.  Each directive it returns is tagged as
        mutated by this writer before being yielded.
        """
        try:
            _rewriter = self.dispatch.dispatch(directive)
        except ValueError:
            # The dispatcher signals "no function registered for this
            # directive's type" with ValueError; pass the directive through.
            yield directive
            return

        if self in directive._mutations:
            # Already rewritten by this writer; don't apply it twice.
            yield directive
            return

        for r_directive in util.to_list(
            _rewriter(context, revision, directive), []
        ):
            # Record this writer so a later traversal of the same
            # structure does not rewrite the directive again.
            r_directive._mutations = r_directive._mutations.union([self])
            yield r_directive

    def __call__(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directives: List["MigrationScript"],
    ) -> None:
        """Apply this writer, then each chained writer, to ``directives``.

        The signature matches the
        :paramref:`.EnvironmentContext.configure.process_revision_directives`
        hook, so an instance can be passed there directly.
        ``directives`` is modified in place.
        """
        # NOTE(review): the docstrings describe ``revision`` as a tuple,
        # while the annotation says ``Revision`` -- confirm which is right.
        self.process_revision_directives(context, revision, directives)
        for chained in self._chained:
            chained(context, revision, directives)

    @_traverse.dispatch_for(ops.MigrationScript)
    def _traverse_script(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrationScript",
    ) -> None:
        """Traverse a script's upgrade and downgrade op containers.

        Each container may be rewritten, but must map to exactly one
        replacement object; otherwise ``ValueError`` is raised.
        """
        directive.upgrade_ops = self._traverse_ops_list(
            context, revision, directive.upgrade_ops_list, "UpgradeOps"
        )
        directive.downgrade_ops = self._traverse_ops_list(
            context, revision, directive.downgrade_ops_list, "DowngradeOps"
        )

    def _traverse_ops_list(
        self,
        context: "MigrationContext",
        revision: "Revision",
        ops_list: Any,
        label: str,
    ) -> List["MigrateOperation"]:
        """Traverse each container in ``ops_list``; return the replacements.

        ``label`` names the container kind in the error raised when a
        traversal yields anything other than a single object.
        """
        result = []
        for container in ops_list:
            ret = self._traverse_for(context, revision, container)
            if len(ret) != 1:
                raise ValueError(
                    "Can only return single object for %s traverse" % label
                )
            result.append(ret[0])
        return result

    @_traverse.dispatch_for(ops.OpContainer)
    def _traverse_op_container(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "OpContainer",
    ) -> None:
        """Recurse into a container's ``ops`` list, rewriting it in place."""
        self._traverse_list(context, revision, directive.ops)

    @_traverse.dispatch_for(ops.MigrateOperation)
    def _traverse_any_directive(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrateOperation",
    ) -> None:
        """Fallback traversal for leaf directives: nothing to recurse into.

        Presumably the dispatcher prefers the more specific registrations
        above for container types -- confirm against ``util.Dispatcher``.
        """
        pass

    def _traverse_for(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directive: "MigrateOperation",
    ) -> List["MigrateOperation"]:
        """Rewrite ``directive``, traverse each result, and return the results.

        :return: the list of directives that replace ``directive`` (may be
         empty, one, or many).
        """
        rewritten = list(self._rewrite(context, revision, directive))
        for item in rewritten:
            traverser = self._traverse.dispatch(item)
            # Dispatcher stores the plain functions, so pass ``self``.
            traverser(self, context, revision, item)
        return rewritten

    def _traverse_list(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directives: Any,
    ) -> None:
        """Rewrite every element of ``directives`` and replace it in place.

        Slice assignment keeps the caller's list object (e.g. a container's
        ``ops``) so existing references observe the rewritten contents.
        """
        dest = []
        for directive in directives:
            dest.extend(self._traverse_for(context, revision, directive))

        directives[:] = dest

    def process_revision_directives(
        self,
        context: "MigrationContext",
        revision: "Revision",
        directives: List["MigrationScript"],
    ) -> None:
        """Apply this writer (but not chained writers) to ``directives``."""
        self._traverse_list(context, revision, directives)
225