1from typing import Any
2from typing import Dict
3
4from sqlalchemy import CHAR
5from sqlalchemy import CheckConstraint
6from sqlalchemy import Column
7from sqlalchemy import event
8from sqlalchemy import ForeignKey
9from sqlalchemy import Index
10from sqlalchemy import inspect
11from sqlalchemy import Integer
12from sqlalchemy import MetaData
13from sqlalchemy import Numeric
14from sqlalchemy import String
15from sqlalchemy import Table
16from sqlalchemy import Text
17from sqlalchemy import text
18from sqlalchemy import UniqueConstraint
19
20from ... import autogenerate
21from ... import util
22from ...autogenerate import api
23from ...ddl.base import _fk_spec
24from ...migration import MigrationContext
25from ...operations import ops
26from ...testing import config
27from ...testing import eq_
28from ...testing.env import clear_staging_env
29from ...testing.env import staging_env
30
# Names of every Table seen by the ``new_table`` "after_parent_attach"
# listener since this module was imported.  Nothing in this module clears
# it, so it only grows over the life of the process.
names_in_this_test = set()
32
33
@event.listens_for(Table, "after_parent_attach")
def new_table(table, parent):
    """Record ``table.name`` each time a Table is attached to its parent.

    Registered for SQLAlchemy's ``after_parent_attach`` event on ``Table``
    (for a Table the parent is normally its MetaData); ``parent`` itself is
    not used.  The collected names are consulted by
    ``_default_include_object``.
    """
    recorded = names_in_this_test
    recorded.add(table.name)
37
38
39def _default_include_object(obj, name, type_, reflected, compare_to):
40    if type_ == "table":
41        return name in names_in_this_test
42    else:
43        return True
44
45
# Default ``include_object`` hook handed to the MigrationContext by the
# fixture classes in this module.
_default_object_filters = _default_include_object

# Default ``include_name`` hook for the same fixtures; None, i.e. no
# name-level filter callable is supplied.
_default_name_filters = None
49
50
class ModelOne:
    """Pair of schemas used as an autogenerate comparison fixture.

    ``_get_db_schema`` returns the MetaData meant to be created in the
    database (``AutogenTest.setup_class`` calls ``create_all`` on it), and
    ``_get_model_schema`` returns the "model" MetaData compared against it.
    Relative to the database schema, the model:

    * ``user``: makes ``name`` NOT NULL, gives ``a1`` a server default, and
      drops the ``pw`` column together with the ``pw_idx`` index;
    * ``address``: adds ``street`` and the ``uq_email`` unique constraint;
    * ``order``: changes ``amount`` from ``Numeric(8, 2)`` NOT NULL to a
      nullable ``Numeric(10, 2)``, adds a ``user_id`` foreign key, and
      rewrites the ``ck_order_amount`` check;
    * omits the ``extra`` table and adds a new ``item`` table.
    """

    # Requirement name(s) presumably checked by the test-suite machinery
    # before running subclasses -- confirm against the requirements module.
    __requires__ = ("unique_constraint_reflection",)

    # Schema passed to MetaData(schema=...) for every table; None leaves the
    # tables unqualified.
    schema = None

    @classmethod
    def _get_db_schema(cls):
        """Build the MetaData describing the schema as created in the DB."""
        meta = MetaData(schema=cls.schema)

        Table(
            "user",
            meta,
            Column("id", Integer, primary_key=True),
            Column("name", String(50)),
            Column("a1", Text),
            Column("pw", String(50)),
            Index("pw_idx", "pw"),
        )
        Table(
            "address",
            meta,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(100), nullable=False),
        )
        Table(
            "order",
            meta,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount", Numeric(8, 2), nullable=False,
                server_default=text("0"),
            ),
            CheckConstraint("amount >= 0", name="ck_order_amount"),
        )
        Table(
            "extra",
            meta,
            Column("x", CHAR),
            Column("uid", Integer, ForeignKey("user.id")),
        )
        return meta

    @classmethod
    def _get_model_schema(cls):
        """Build the target ("model") MetaData autogenerate compares to."""
        meta = MetaData(schema=cls.schema)

        Table(
            "user",
            meta,
            Column("id", Integer, primary_key=True),
            Column("name", String(50), nullable=False),
            Column("a1", Text, server_default="x"),
        )
        Table(
            "address",
            meta,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(100), nullable=False),
            Column("street", String(50)),
            UniqueConstraint("email_address", name="uq_email"),
        )
        Table(
            "order",
            meta,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount", Numeric(10, 2), nullable=True,
                server_default=text("0"),
            ),
            Column("user_id", Integer, ForeignKey("user.id")),
            CheckConstraint("amount > -1", name="ck_order_amount"),
        )
        Table(
            "item",
            meta,
            Column("id", Integer, primary_key=True),
            Column("description", String(100)),
            Column("order_id", Integer, ForeignKey("order.order_id")),
            CheckConstraint("len(description) > 5"),
        )
        return meta
147
148
class _ComparesFKs:
    """Mixin providing an assertion helper for foreign-key diff entries."""

    def _assert_fk_diff(
        self,
        diff,
        type_,
        source_table,
        source_columns,
        target_table,
        target_columns,
        name=None,
        conditional_name=None,
        source_schema=None,
        onupdate=None,
        ondelete=None,
        initially=None,
        deferrable=None,
    ):
        """Assert that *diff* describes the expected foreign key.

        ``diff[0]`` must equal ``type_``.  ``diff[1]`` is presumably a
        ForeignKeyConstraint (it is passed to ``_fk_spec`` and its
        ``elements``/``name`` are read).  Its source table, columns and
        schema, target table, ON UPDATE/ON DELETE, INITIALLY and DEFERRABLE
        values are checked.  Target columns are checked through
        ``elements[i].column.name``.

        Name check: when ``conditional_name`` is ``"servergenerated"``, the
        expected name is that of the first foreign key reflected from
        ``source_table`` via ``inspect(self.bind)``.  When it is any other
        non-None value, it is the expected name itself.  Otherwise ``name``
        is expected.
        """
        constraint = diff[1]

        # the public API for ForeignKeyConstraint was not very rich
        # in 0.7, 0.8, so the well-known but slightly private _fk_spec
        # helper is used to get at its elements
        (
            src_schema,
            src_table,
            src_cols,
            _tgt_schema,
            tgt_table,
            _tgt_cols,
            on_update,
            on_delete,
            is_deferrable,
            init_mode,
        ) = _fk_spec(constraint)

        eq_(diff[0], type_)
        # Checked in a fixed order so the first mismatch reported is stable.
        for actual, expected in (
            (src_table, source_table),
            (src_cols, source_columns),
            (tgt_table, target_table),
            (src_schema, source_schema),
            (on_update, onupdate),
            (on_delete, ondelete),
            (init_mode, initially),
            (is_deferrable, deferrable),
        ):
            eq_(actual, expected)

        eq_(
            [element.column.name for element in constraint.elements],
            target_columns,
        )

        if conditional_name is None:
            expected_name = name
        elif conditional_name == "servergenerated":
            reflected = inspect(self.bind).get_foreign_keys(source_table)
            expected_name = reflected[0]["name"]
        else:
            expected_name = conditional_name
        eq_(constraint.name, expected_name)
202
203
class AutogenTest(_ComparesFKs):
    """Autogenerate test base whose database schema is built once per class.

    Subclasses provide ``_get_db_schema`` / ``_get_model_schema``
    classmethods (``ModelOne`` in this module defines both).  The first is
    created in the database in ``setup_class`` and dropped in
    ``teardown_class``.  The second becomes the ``target_metadata`` of a
    fresh MigrationContext in every ``setUp``.
    """

    def _flatten_diffs(self, diffs):
        """Yield the non-list entries of *diffs*, descending into sublists."""
        for entry in diffs:
            if not isinstance(entry, list):
                yield entry
                continue
            yield from self._flatten_diffs(entry)

    @classmethod
    def _get_bind(cls):
        """Return the bind the suite runs against (``config.db``)."""
        return config.db

    # Extra MigrationContext options layered over the defaults in setUp.
    configure_opts: Dict[Any, Any] = {}

    @classmethod
    def setup_class(cls):
        """Enter the staging env and create the database-side schema."""
        staging_env()
        bind = cls.bind = cls._get_bind()
        db_meta = cls.m1 = cls._get_db_schema()
        db_meta.create_all(bind)
        cls.m2 = cls._get_model_schema()

    @classmethod
    def teardown_class(cls):
        """Drop the database-side schema and clear the staging env."""
        db_meta, bind = cls.m1, cls.bind
        db_meta.drop_all(bind)
        clear_staging_env()

    def setUp(self):
        """Open a connection and build the per-test migration contexts."""
        conn = self.bind.connect()
        self.conn = conn
        ctx_opts = {
            "compare_type": True,
            "compare_server_default": True,
            "target_metadata": self.m2,
            "upgrade_token": "upgrades",
            "downgrade_token": "downgrades",
            "alembic_module_prefix": "op.",
            "sqlalchemy_module_prefix": "sa.",
            "include_object": _default_object_filters,
            "include_name": _default_name_filters,
        }
        # Subclass overrides win over the defaults above.
        ctx_opts.update(self.configure_opts or {})
        self.context = MigrationContext.configure(
            connection=conn, opts=ctx_opts
        )
        self.autogen_context = api.AutogenContext(self.context, self.m2)

    def tearDown(self):
        """Close the connection opened in setUp."""
        self.conn.close()

    def _update_context(
        self, object_filters=None, name_filters=None, include_schemas=None
    ):
        """Patch the per-test AutogenContext in place and return it.

        Only arguments that are not None are applied.  A given filter
        replaces the existing filter list with a one-element list.
        """
        ctx = self.autogen_context
        if include_schemas is not None:
            ctx.opts["include_schemas"] = include_schemas
        if object_filters is not None:
            ctx._object_filters = [object_filters]
        if name_filters is not None:
            ctx._name_filters = [name_filters]
        return ctx
266
267
class AutogenFixtureTest(_ComparesFKs):
    """Autogenerate test base that builds fresh schemas inside each test.

    Tests call ``_fixture`` with a database-side MetaData, which is created
    in the database, and a model MetaData, which is the autogenerate
    target.  Whatever ``_fixture`` created is dropped again in ``tearDown``.
    """

    def _fixture(
        self,
        m1,
        m2,
        include_schemas=False,
        opts=None,
        object_filters=_default_object_filters,
        name_filters=_default_name_filters,
        return_ops=False,
        max_identifier_length=None,
    ):
        """Create ``m1`` in the database and autogenerate against ``m2``.

        :param m1: MetaData whose tables are created with ``create_all``.
            It goes through ``util.to_list``, so presumably a single MetaData
            or a list of them.  It is remembered for ``tearDown``.
        :param m2: target MetaData handed to autogenerate.
        :param include_schemas: value for the ``include_schemas`` option.
        :param opts: extra MigrationContext options merged over the defaults.
        :param object_filters: ``include_object`` hook.
        :param name_filters: ``include_name`` hook (may be None).
        :param return_ops: when true, return the ``UpgradeOps`` object.
            Otherwise return ``uo.as_diffs()``.
        :param max_identifier_length: when truthy, temporarily force the
            bind dialect's identifier length limit for the duration of the
            call.  The previous values are restored afterwards.
        """
        if max_identifier_length:
            dialect = self.bind.dialect
            # Save both attributes individually.  The user-defined value is
            # normally None unless the engine was created with an explicit
            # max_identifier_length.  Restoring it to the *effective* length
            # would leave the shared dialect looking explicitly configured
            # for every later test.
            saved_lengths = (
                dialect.max_identifier_length,
                dialect._user_defined_max_identifier_length,
            )
            dialect.max_identifier_length = (
                dialect._user_defined_max_identifier_length
            ) = max_identifier_length
        try:
            self._alembic_metadata, model_metadata = m1, m2
            for m in util.to_list(self._alembic_metadata):
                m.create_all(self.bind)

            with self.bind.connect() as conn:
                ctx_opts = {
                    "compare_type": True,
                    "compare_server_default": True,
                    "target_metadata": model_metadata,
                    "upgrade_token": "upgrades",
                    "downgrade_token": "downgrades",
                    "alembic_module_prefix": "op.",
                    "sqlalchemy_module_prefix": "sa.",
                    "include_object": object_filters,
                    "include_name": name_filters,
                    "include_schemas": include_schemas,
                }
                if opts:
                    ctx_opts.update(opts)
                self.context = context = MigrationContext.configure(
                    connection=conn, opts=ctx_opts
                )

                autogen_context = api.AutogenContext(context, model_metadata)
                uo = ops.UpgradeOps(ops=[])
                autogenerate._produce_net_changes(autogen_context, uo)

                return uo if return_ops else uo.as_diffs()
        finally:
            if max_identifier_length:
                (
                    dialect.max_identifier_length,
                    dialect._user_defined_max_identifier_length,
                ) = saved_lengths

    # Not read in this module; presumably overridden or consulted by the
    # concrete test suites -- confirm at the use sites.
    reports_unnamed_constraints = False

    def setUp(self):
        """Enter the staging env and bind to ``config.db``."""
        staging_env()
        self.bind = config.db

    def tearDown(self):
        """Drop whatever ``_fixture`` created (if it ran); clear staging."""
        if hasattr(self, "_alembic_metadata"):
            for m in util.to_list(self._alembic_metadata):
                m.drop_all(self.bind)
        clear_staging_env()
337