import contextlib
import re
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from sqlalchemy import __version__
from sqlalchemy import inspect
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.visitors import traverse

from . import compat

if TYPE_CHECKING:
    from sqlalchemy import Index
    from sqlalchemy import Table
    from sqlalchemy.engine import Connection
    from sqlalchemy.engine import Dialect
    from sqlalchemy.engine import Transaction
    from sqlalchemy.engine.reflection import Inspector
    from sqlalchemy.sql.base import ColumnCollection
    from sqlalchemy.sql.compiler import SQLCompiler
    from sqlalchemy.sql.dml import Insert
    from sqlalchemy.sql.elements import ColumnClause
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.schema import Constraint
    from sqlalchemy.sql.schema import SchemaItem
    from sqlalchemy.sql.selectable import Select
    from sqlalchemy.sql.selectable import TableClause

_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])


def _safe_int(value: str) -> Union[int, str]:
    """Return ``value`` as an int, or unchanged if it is not numeric.

    Used to turn version-string tokens such as ``"4"`` or ``"b1"`` into
    comparable tuple members.
    """
    try:
        return int(value)
    except ValueError:
        # non-numeric tokens (e.g. pre-release markers like "b1") are kept
        # as strings; a bare ``except`` would also swallow
        # KeyboardInterrupt / SystemExit.
        return value


# Parsed SQLAlchemy version, e.g. "1.4.0b1" -> (1, 4, 0, "b1").
_vers = tuple(
    [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
sqla_14_26 = _vers >= (1, 4, 26)


if sqla_14:
    # when future engine merges, this can be again based on version string
    from sqlalchemy.engine import Connection as legacy_connection

    # the legacy (1.x-style) Connection has no ``commit()`` method
    sqla_1x = not hasattr(legacy_connection, "commit")
else:
    sqla_1x = True

try:
    from sqlalchemy import Computed  # noqa
except ImportError:
    # placeholder so isinstance() checks still work; callers must guard on
    # ``has_computed`` because isinstance(None, Computed) would be True.
    Computed = type(None)  # type: ignore
    has_computed = False
    has_computed_reflection = False
else:
    has_computed = True
    has_computed_reflection = _vers >= (1, 3, 16)

try:
    from sqlalchemy import Identity  # noqa
except ImportError:
    # placeholder so isinstance() checks still work; callers must guard on
    # ``has_identity`` because isinstance(None, Identity) would be True.
    Identity = type(None)  # type: ignore
    has_identity = False
else:
    # attributes common to Identity and Sequence
    _identity_options_attrs = (
        "start",
        "increment",
        "minvalue",
        "maxvalue",
        "nominvalue",
        "nomaxvalue",
        "cycle",
        "cache",
        "order",
    )
    # attributes of Identity
    _identity_attrs = _identity_options_attrs + ("on_null",)
    has_identity = True

AUTOINCREMENT_DEFAULT = "auto"


@contextlib.contextmanager
def _ensure_scope_for_ddl(
    connection: Optional["Connection"],
) -> Iterator[None]:
    """Run the enclosed block inside a transaction on ``connection``.

    If ``connection`` is not already in a transaction, one is begun with
    ``connection.begin()`` for the duration of the block. Connections that
    are already in a transaction, and objects lacking ``in_transaction``
    (mock connections, ``None``), are yielded to without any change.
    """
    try:
        in_transaction = connection.in_transaction  # type: ignore[union-attr]
    except AttributeError:
        # catch for MockConnection, None
        yield
    else:
        if not in_transaction():
            assert connection is not None
            with connection.begin():
                yield
        else:
            yield


def _safe_begin_connection_transaction(
    connection: "Connection",
) -> "Transaction":
    """Return the current transaction of ``connection``, or begin one."""
    transaction = _get_connection_transaction(connection)
    if transaction:
        return transaction
    else:
        return connection.begin()


def _safe_commit_connection_transaction(
    connection: "Connection",
) -> None:
    """Commit the current transaction of ``connection``, if there is one."""
    transaction = _get_connection_transaction(connection)
    if transaction:
        transaction.commit()


def _safe_rollback_connection_transaction(
    connection: "Connection",
) -> None:
    """Roll back the current transaction of ``connection``, if any."""
    transaction = _get_connection_transaction(connection)
    if transaction:
        transaction.rollback()


def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
    """Return whether ``connection`` is currently in a transaction.

    Returns ``False`` for objects without an ``in_transaction`` attribute,
    which includes mock connections and ``None``.
    """
    try:
        in_transaction = connection.in_transaction  # type: ignore
    except AttributeError:
        # catch for MockConnection
        return False
    else:
        return in_transaction()


def _copy(schema_item: _CE, **kw) -> _CE:
    """Copy ``schema_item`` via ``_copy()`` if present, else ``copy()``."""
    if hasattr(schema_item, "_copy"):
        return schema_item._copy(**kw)
    else:
        return schema_item.copy(**kw)


def _get_connection_transaction(
    connection: "Connection",
) -> Optional["Transaction"]:
    """Return the active transaction of ``connection``, or ``None``.

    On SQLAlchemy < 1.4 this reads a private attribute of the root
    connection because no public accessor is used there.
    """
    if sqla_14:
        return connection.get_transaction()
    else:
        r = connection._root  # type: ignore[attr-defined]
        return r._Connection__transaction


def _create_url(*arg, **kw) -> url.URL:
    """Build a ``URL`` via ``URL.create()`` when it exists, else the
    ``URL`` constructor."""
    if hasattr(url.URL, "create"):
        return url.URL.create(*arg, **kw)
    else:
        return url.URL(*arg, **kw)


def _connectable_has_table(
    connectable: "Connection", tablename: str, schemaname: Union[str, None]
) -> bool:
    """Return whether ``tablename`` exists in ``schemaname``."""
    if sqla_14:
        return inspect(connectable).has_table(tablename, schemaname)
    else:
        return connectable.dialect.has_table(
            connectable, tablename, schemaname
        )


def _exec_on_inspector(inspector, statement, **params):
    """Execute ``statement`` with ``params`` on the inspector's connection
    and return the result."""
    if sqla_14:
        with inspector._operation_context() as conn:
            return conn.execute(statement, params)
    else:
        return inspector.bind.execute(statement, params)


def _nullability_might_be_unset(metadata_column):
    """On SQLAlchemy 1.4+, return whether the column's nullability was left
    unspecified by the user; on older versions return ``column.nullable``.
    """
    if not sqla_14:
        return metadata_column.nullable
    else:
        # local import: ``sqlalchemy.sql.schema`` (not the module-level
        # ``sqlalchemy.schema``) provides NULL_UNSPECIFIED.
        from sqlalchemy.sql import schema

        return (
            metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
        )


def _server_default_is_computed(*server_default) -> bool:
    """Return True if any of the given server defaults is ``Computed``."""
    if not has_computed:
        return False
    else:
        return any(isinstance(sd, Computed) for sd in server_default)


def _server_default_is_identity(*server_default) -> bool:
    """Return True if any of the given server defaults is ``Identity``."""
    # Guard on ``has_identity`` (not the version flag) so that the
    # ``type(None)`` placeholder can never classify a ``None`` default as
    # an identity; mirrors ``_server_default_is_computed``.
    if not has_identity:
        return False
    else:
        return any(isinstance(sd, Identity) for sd in server_default)


def _table_for_constraint(constraint: "Constraint") -> "Table":
    """Return the table a constraint belongs to.

    Foreign key constraints expose their owning table as ``.parent``;
    other constraints as ``.table``.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        table = constraint.parent
        assert table is not None
        return table
    else:
        return constraint.table


def _columns_for_constraint(constraint):
    """Return the columns referenced by ``constraint``.

    Foreign keys yield their local (parent) columns, CHECK constraints the
    set of columns found in their SQL expression, and all other
    constraints a list of ``constraint.columns``.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [fk.parent for fk in constraint.elements]
    elif isinstance(constraint, CheckConstraint):
        return _find_columns(constraint.sqltext)
    else:
        return list(constraint.columns)


def _reflect_table(
    inspector: "Inspector", table: "Table", include_cols: None
) -> None:
    """Reflect ``table`` using the inspector.

    ``include_cols`` is accepted but not used; ``None`` is always passed
    through, so all columns are reflected.
    """
    if sqla_14:
        return inspector.reflect_table(table, None)
    else:
        return inspector.reflecttable(table, None)


def _fk_spec(constraint):
    """Return a tuple describing a foreign key constraint.

    The tuple is ``(source_schema, source_table, source_columns,
    target_schema, target_table, target_columns, onupdate, ondelete,
    deferrable, initially)``. Note that ``onupdate`` precedes ``ondelete``.
    """
    source_columns = [
        constraint.columns[key].name for key in constraint.column_keys
    ]

    source_table = constraint.parent.name
    source_schema = constraint.parent.schema
    target_schema = constraint.elements[0].column.table.schema
    target_table = constraint.elements[0].column.table.name
    target_columns = [element.column.name for element in constraint.elements]
    ondelete = constraint.ondelete
    onupdate = constraint.onupdate
    deferrable = constraint.deferrable
    initially = constraint.initially
    return (
        source_schema,
        source_table,
        source_columns,
        target_schema,
        target_table,
        target_columns,
        onupdate,
        ondelete,
        deferrable,
        initially,
    )


def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
    """Return whether the foreign key targets its own parent table.

    Compares the table portion of the first element's column spec (the
    spec with its trailing column name removed) to the parent table's key.
    """
    spec = constraint.elements[0]._get_colspec()  # type: ignore[attr-defined]
    tokens = spec.split(".")
    tokens.pop(-1)  # colname
    tablekey = ".".join(tokens)
    assert constraint.parent is not None
    return tablekey == constraint.parent.key


def _is_type_bound(constraint: "Constraint") -> bool:
    """Return whether the constraint was generated by a column type."""
    # this deals with SQLAlchemy #3260, don't copy CHECK constraints
    # that will be generated by the type.
    # new feature added for #3260
    return constraint._type_bound  # type: ignore[attr-defined]


def _find_columns(clause):
    """locate Column objects within the given expression."""

    cols = set()
    traverse(clause, {}, {"column": cols.add})
    return cols


def _remove_column_from_collection(
    collection: "ColumnCollection", column: Union["Column", "ColumnClause"]
) -> None:
    """remove a column from a ColumnCollection."""

    # workaround for older SQLAlchemy, remove the
    # same object that's present
    assert column.key is not None
    to_remove = collection[column.key]
    collection.remove(to_remove)


def _textual_index_column(
    table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
) -> Union["ColumnElement", "Column"]:
    """a workaround for the Index construct's severe lack of flexibility

    - a string becomes a NULLTYPE ``Column`` appended to ``table``;
    - a ``text()`` construct is wrapped in ``_textual_index_element``;
    - a column expression is copied onto ``table``.

    Raises ``ValueError`` for any other input.
    """
    if isinstance(text_, compat.string_types):
        c = Column(text_, sqltypes.NULLTYPE)
        table.append_column(c)
        return c
    elif isinstance(text_, TextClause):
        return _textual_index_element(table, text_)
    elif isinstance(text_, sql.ColumnElement):
        return _copy_expression(text_, table)
    else:
        raise ValueError("String or text() construct expected")


def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
    """Return a copy of ``expression`` whose columns refer to
    ``target_table``.

    Each ``Column`` bound to a different table is replaced by the
    same-named column of ``target_table``; if none exists, a copy of the
    column is appended to ``target_table`` and used instead.
    """

    def replace(col):
        if (
            isinstance(col, Column)
            and col.table is not None
            and col.table is not target_table
        ):
            if col.name in target_table.c:
                return target_table.c[col.name]
            else:
                c = _copy(col)
                target_table.append_column(c)
                return c
        else:
            # None tells replacement_traverse to keep the element as-is
            return None

    return visitors.replacement_traverse(expression, {}, replace)


class _textual_index_element(sql.ColumnElement):
    """Wrap around a sqlalchemy text() construct in such a way that
    we appear like a column-oriented SQL expression to an Index
    construct.

    The issue here is that currently the PostgreSQL dialect, the biggest
    recipient of functional indexes, keys all the index expressions to
    the corresponding column expressions when rendering CREATE INDEX,
    so the Index we create here needs to have a .columns collection that
    is the same length as the .expressions collection. Ultimately
    SQLAlchemy should support text() expressions in indexes.

    See SQLAlchemy issue 3174.

    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table: "Table", text: "TextClause") -> None:
        self.table = table
        self.text = text
        self.key = text.text
        # placeholder column (named after the raw SQL text) so the Index
        # sees one column per expression
        self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
        table.append_column(self.fake_column)

    def get_children(self):
        return [self.fake_column]


@compiles(_textual_index_element)
def _render_textual_index_column(
    element: _textual_index_element, compiler: "SQLCompiler", **kw
) -> str:
    """Render the wrapped ``text()`` construct, not the fake column."""
    return compiler.process(element.text, **kw)


class _literal_bindparam(BindParameter):
    """A bound parameter that is always rendered inline as a literal."""

    pass


@compiles(_literal_bindparam)
def _render_literal_bindparam(
    element: _literal_bindparam, compiler: "SQLCompiler", **kw
) -> str:
    """Render the parameter's value as a literal in the SQL string."""
    return compiler.render_literal_bindparam(element, **kw)


def _get_index_expressions(idx):
    """Return the index's expressions as a list."""
    return list(idx.expressions)


def _get_index_column_names(idx):
    """Return each index expression's ``name``, or ``None`` if it has none."""
    return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]


def _column_kwargs(col: "Column") -> Mapping:
    """Return ``col.kwargs`` on SQLAlchemy 1.3+, else an empty dict."""
    if sqla_13:
        return col.kwargs
    else:
        return {}


def _get_constraint_final_name(
    constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"]
) -> Optional[str]:
    """Return the unquoted name the dialect would render for ``constraint``.

    Returns ``None`` when the constraint has no name.
    """
    if constraint.name is None:
        return None
    assert dialect is not None
    if sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
    else:

        # prior to SQLAlchemy 1.4, work around quoting logic to get at the
        # final compiled name without quotes.
        if hasattr(constraint.name, "quote"):
            # might be quoted_name, might be truncated_name, keep it the
            # same
            quoted_name_cls: type = type(constraint.name)
        else:
            quoted_name_cls = quoted_name

        new_name = quoted_name_cls(str(constraint.name), quote=False)
        constraint = constraint.__class__(name=new_name)

        if isinstance(constraint, schema.Index):
            # name should not be quoted.
            d = dialect.ddl_compiler(dialect, None)
            return d._prepared_index_name(  # type: ignore[attr-defined]
                constraint
            )
        else:
            # name should not be quoted.
            return dialect.identifier_preparer.format_constraint(constraint)


def _constraint_is_named(
    constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"]
) -> bool:
    """Return whether ``constraint`` has a name.

    On SQLAlchemy 1.4+ the dialect's formatted name is consulted, since
    formatting may yield ``None`` even for a non-``None`` name attribute.
    """
    if sqla_14:
        if constraint.name is None:
            return False
        assert dialect is not None
        name = dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
        return name is not None
    else:
        return constraint.name is not None


def _is_mariadb(mysql_dialect: "Dialect") -> bool:
    """Return whether the MySQL-family dialect is connected to MariaDB."""
    if sqla_14:
        return mysql_dialect.is_mariadb  # type: ignore[attr-defined]
    else:
        return bool(
            mysql_dialect.server_version_info
            and mysql_dialect._is_mariadb  # type: ignore[attr-defined]
        )


def _mariadb_normalized_version_info(mysql_dialect):
    """Return the dialect's normalized MariaDB version info."""
    return mysql_dialect._mariadb_normalized_version_info


def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
    """Return an ``INSERT`` construct for ``table`` in inline mode."""
    if sqla_14:
        return table.insert().inline()
    else:
        return table.insert(inline=True)


if sqla_14:
    from sqlalchemy import create_mock_engine
    from sqlalchemy import select as _select
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor, **kw):  # type: ignore[misc]
        """Pre-1.4 stand-in for ``sqlalchemy.create_mock_engine``."""
        # NOTE(review): ``url`` and ``**kw`` are ignored; a PostgreSQL URL
        # is always used for the mock engine here. Confirm callers do not
        # rely on the dialect of ``url``.
        return create_engine(
            "postgresql://", strategy="mock", executor=executor
        )

    def _select(*columns, **kw) -> "Select":
        """1.4-style ``select(*columns)`` on top of the legacy list form."""
        return sql.select(list(columns), **kw)