from typing import Any
from typing import Dict

from sqlalchemy import CHAR
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import Index
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Numeric
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import UniqueConstraint

from ... import autogenerate
from ... import util
from ...autogenerate import api
from ...ddl.base import _fk_spec
from ...migration import MigrationContext
from ...operations import ops
from ...testing import config
from ...testing import eq_
from ...testing.env import clear_staging_env
from ...testing.env import staging_env

# Names of every Table attached to a parent while this module is loaded;
# filled in by the ``new_table`` listener below and never cleared here.
names_in_this_test = set()


@event.listens_for(Table, "after_parent_attach")
def new_table(table, parent):
    """Record the name of each ``Table`` as it is attached to its parent."""
    names_in_this_test.add(table.name)


def _default_include_object(obj, name, type_, reflected, compare_to):
    """Default ``include_object`` filter used by the fixtures below.

    Tables are included only when their name was recorded in
    ``names_in_this_test``; every other object type is always included.
    """
    if type_ == "table":
        return name in names_in_this_test
    else:
        return True


_default_object_filters = _default_include_object

_default_name_filters = None


class ModelOne:
    """Mixin supplying a "database" schema and a differing "model" schema.

    ``_get_db_schema`` builds the MetaData that ``AutogenTest.setup_class``
    creates in the database; ``_get_model_schema`` builds the MetaData used
    as the autogenerate comparison target.
    """

    # presumably consumed by the testing framework to skip backends lacking
    # unique constraint reflection -- not used anywhere in this module
    __requires__ = ("unique_constraint_reflection",)

    # schema name passed to both MetaData objects; subclasses may override
    schema = None

    @classmethod
    def _get_db_schema(cls):
        """Return MetaData with the ``user``, ``address``, ``order`` and
        ``extra`` tables."""
        schema = cls.schema

        m = MetaData(schema=schema)

        Table(
            "user",
            m,
            Column("id", Integer, primary_key=True),
            Column("name", String(50)),
            Column("a1", Text),
            Column("pw", String(50)),
            Index("pw_idx", "pw"),
        )

        Table(
            "address",
            m,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(100), nullable=False),
        )

        Table(
            "order",
            m,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount",
                Numeric(8, 2),
                nullable=False,
                server_default=text("0"),
            ),
            CheckConstraint("amount >= 0", name="ck_order_amount"),
        )

        Table(
            "extra",
            m,
            Column("x", CHAR),
            Column("uid", Integer, ForeignKey("user.id")),
        )

        return m

    @classmethod
    def _get_model_schema(cls):
        """Return MetaData with the ``user``, ``address``, ``order`` and
        ``item`` tables.

        Relative to ``_get_db_schema`` this drops ``extra`` and
        ``user.pw`` (with its index), adds ``item``, ``address.street``,
        ``order.user_id`` and the ``uq_email`` constraint, and alters
        nullability, server default, numeric precision and the check
        constraint text on existing columns.
        """
        schema = cls.schema

        m = MetaData(schema=schema)

        Table(
            "user",
            m,
            Column("id", Integer, primary_key=True),
            Column("name", String(50), nullable=False),
            Column("a1", Text, server_default="x"),
        )

        Table(
            "address",
            m,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(100), nullable=False),
            Column("street", String(50)),
            UniqueConstraint("email_address", name="uq_email"),
        )

        Table(
            "order",
            m,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount",
                Numeric(10, 2),
                nullable=True,
                server_default=text("0"),
            ),
            Column("user_id", Integer, ForeignKey("user.id")),
            CheckConstraint("amount > -1", name="ck_order_amount"),
        )

        Table(
            "item",
            m,
            Column("id", Integer, primary_key=True),
            Column("description", String(100)),
            Column("order_id", Integer, ForeignKey("order.order_id")),
            CheckConstraint("len(description) > 5"),
        )
        return m


class _ComparesFKs:
    """Mixin providing an assertion helper for foreign key diff entries."""

    def _assert_fk_diff(
        self,
        diff,
        type_,
        source_table,
        source_columns,
        target_table,
        target_columns,
        name=None,
        conditional_name=None,
        source_schema=None,
        onupdate=None,
        ondelete=None,
        initially=None,
        deferrable=None,
    ):
        """Assert that a ``(type, constraint)`` diff entry matches.

        :param diff: sequence whose element 0 is the diff type and whose
         element 1 is the foreign key constraint object.
        :param type_: expected value of ``diff[0]``.
        :param source_table: expected source table name.
        :param source_columns: expected source column names.
        :param target_table: expected referred table name.
        :param target_columns: expected referred column names, compared
         against ``elem.column.name`` of each constraint element.
        :param name: expected constraint name when ``conditional_name``
         is None.
        :param conditional_name: when ``"servergenerated"``, the expected
         name is that of the first foreign key reflected from
         ``source_table`` via ``self.bind``; any other non-None value is
         compared to the constraint name directly.
        :param source_schema: expected source schema.
        :param onupdate: expected ON UPDATE value.
        :param ondelete: expected ON DELETE value.
        :param initially: expected INITIALLY value.
        :param deferrable: expected DEFERRABLE value.
        """
        # the public API for ForeignKeyConstraint was not very rich
        # in 0.7, 0.8, so here we use the well-known but slightly
        # private API to get at its elements
        (
            fk_source_schema,
            fk_source_table,
            fk_source_columns,
            fk_target_schema,
            fk_target_table,
            fk_target_columns,
            fk_onupdate,
            fk_ondelete,
            fk_deferrable,
            fk_initially,
        ) = _fk_spec(diff[1])

        eq_(diff[0], type_)
        eq_(fk_source_table, source_table)
        eq_(fk_source_columns, source_columns)
        eq_(fk_target_table, target_table)
        eq_(fk_source_schema, source_schema)
        eq_(fk_onupdate, onupdate)
        eq_(fk_ondelete, ondelete)
        eq_(fk_initially, initially)
        eq_(fk_deferrable, deferrable)

        eq_([elem.column.name for elem in diff[1].elements], target_columns)
        if conditional_name is not None:
            if conditional_name == "servergenerated":
                fks = inspect(self.bind).get_foreign_keys(source_table)
                server_fk_name = fks[0]["name"]
                eq_(diff[1].name, server_fk_name)
            else:
                eq_(diff[1].name, conditional_name)
        else:
            eq_(diff[1].name, name)


class AutogenTest(_ComparesFKs):
    """Base for tests that compare one class-level database schema against
    one class-level model schema.

    The class is expected to also provide ``_get_db_schema`` and
    ``_get_model_schema`` (e.g. via ``ModelOne``).
    """

    def _flatten_diffs(self, diffs):
        """Yield every non-list item from an arbitrarily nested list of
        diffs, depth first."""
        for d in diffs:
            if isinstance(d, list):
                yield from self._flatten_diffs(d)
            else:
                yield d

    @classmethod
    def _get_bind(cls):
        """Return the engine to run against; defaults to ``config.db``."""
        return config.db

    # extra MigrationContext options merged over the defaults in setUp()
    configure_opts: Dict[Any, Any] = {}

    @classmethod
    def setup_class(cls):
        """Create the staging environment and the database-side schema."""
        staging_env()
        cls.bind = cls._get_bind()
        cls.m1 = cls._get_db_schema()
        cls.m1.create_all(cls.bind)
        cls.m2 = cls._get_model_schema()

    @classmethod
    def teardown_class(cls):
        """Drop the database-side schema and clear the staging environment.

        The staging environment is cleared even if dropping tables fails.
        """
        try:
            cls.m1.drop_all(cls.bind)
        finally:
            clear_staging_env()

    def setUp(self):
        """Open a connection and build the migration/autogenerate contexts
        for a single test."""
        self.conn = conn = self.bind.connect()
        ctx_opts = {
            "compare_type": True,
            "compare_server_default": True,
            "target_metadata": self.m2,
            "upgrade_token": "upgrades",
            "downgrade_token": "downgrades",
            "alembic_module_prefix": "op.",
            "sqlalchemy_module_prefix": "sa.",
            "include_object": _default_object_filters,
            "include_name": _default_name_filters,
        }
        if self.configure_opts:
            ctx_opts.update(self.configure_opts)
        try:
            self.context = context = MigrationContext.configure(
                connection=conn, opts=ctx_opts
            )

            self.autogen_context = api.AutogenContext(context, self.m2)
        except BaseException:
            # tearDown() may not run when setUp() fails, so release the
            # connection here rather than leaking it
            conn.close()
            raise

    def tearDown(self):
        """Close the per-test connection."""
        self.conn.close()

    def _update_context(
        self, object_filters=None, name_filters=None, include_schemas=None
    ):
        """Override selected settings on ``self.autogen_context``.

        Each argument that is not None replaces the corresponding setting;
        the (mutated) autogenerate context is returned.
        """
        if include_schemas is not None:
            self.autogen_context.opts["include_schemas"] = include_schemas
        if object_filters is not None:
            self.autogen_context._object_filters = [object_filters]
        if name_filters is not None:
            self.autogen_context._name_filters = [name_filters]
        return self.autogen_context


class AutogenFixtureTest(_ComparesFKs):
    """Base for tests that supply their own database/model MetaData pair
    per test through ``_fixture()``."""

    def _fixture(
        self,
        m1,
        m2,
        include_schemas=False,
        opts=None,
        object_filters=_default_object_filters,
        name_filters=_default_name_filters,
        return_ops=False,
        max_identifier_length=None,
    ):
        """Create ``m1`` in the database and autogenerate against ``m2``.

        :param m1: MetaData (or list of MetaData) to create on
         ``self.bind``; remembered so that ``tearDown`` can drop it.
        :param m2: MetaData used as ``target_metadata``.
        :param include_schemas: value for the ``include_schemas`` option.
        :param opts: optional dict merged over the default context options.
        :param object_filters: value for the ``include_object`` option.
        :param name_filters: value for the ``include_name`` option.
        :param return_ops: when true, return the ``UpgradeOps`` object
         itself; otherwise return its ``as_diffs()`` result.
        :param max_identifier_length: when truthy, temporarily applied to
         the dialect for the duration of the call.
        """

        if max_identifier_length:
            dialect = self.bind.dialect
            # remember both attributes separately so each one is put back
            # to its own prior value, not to the other's
            existing_length = dialect.max_identifier_length
            existing_user_defined_length = getattr(
                dialect, "_user_defined_max_identifier_length", None
            )
            dialect.max_identifier_length = (
                dialect._user_defined_max_identifier_length
            ) = max_identifier_length
        try:
            self._alembic_metadata, model_metadata = m1, m2
            for m in util.to_list(self._alembic_metadata):
                m.create_all(self.bind)

            with self.bind.connect() as conn:
                ctx_opts = {
                    "compare_type": True,
                    "compare_server_default": True,
                    "target_metadata": model_metadata,
                    "upgrade_token": "upgrades",
                    "downgrade_token": "downgrades",
                    "alembic_module_prefix": "op.",
                    "sqlalchemy_module_prefix": "sa.",
                    "include_object": object_filters,
                    "include_name": name_filters,
                    "include_schemas": include_schemas,
                }
                if opts:
                    ctx_opts.update(opts)
                self.context = context = MigrationContext.configure(
                    connection=conn, opts=ctx_opts
                )

                autogen_context = api.AutogenContext(context, model_metadata)
                uo = ops.UpgradeOps(ops=[])
                autogenerate._produce_net_changes(autogen_context, uo)

                if return_ops:
                    return uo
                else:
                    return uo.as_diffs()
        finally:
            if max_identifier_length:
                dialect = self.bind.dialect
                dialect.max_identifier_length = existing_length
                dialect._user_defined_max_identifier_length = (
                    existing_user_defined_length
                )

    # not read within this module; presumably overridden by subclasses
    reports_unnamed_constraints = False

    def setUp(self):
        """Create the staging environment and bind to ``config.db``."""
        staging_env()
        self.bind = config.db

    def tearDown(self):
        """Drop whatever ``_fixture()`` created and clear the staging
        environment.

        The staging environment is cleared even if dropping tables fails.
        """
        try:
            if hasattr(self, "_alembic_metadata"):
                for m in util.to_list(self._alembic_metadata):
                    m.drop_all(self.bind)
        finally:
            clear_staging_env()