import contextlib
import re
from typing import Collection
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
9
10from sqlalchemy import __version__
11from sqlalchemy import inspect
12from sqlalchemy import schema
13from sqlalchemy import sql
14from sqlalchemy import types as sqltypes
15from sqlalchemy.engine import url
16from sqlalchemy.ext.compiler import compiles
17from sqlalchemy.schema import CheckConstraint
18from sqlalchemy.schema import Column
19from sqlalchemy.schema import ForeignKeyConstraint
20from sqlalchemy.sql import visitors
21from sqlalchemy.sql.elements import BindParameter
22from sqlalchemy.sql.elements import quoted_name
23from sqlalchemy.sql.elements import TextClause
24from sqlalchemy.sql.visitors import traverse
25
26from . import compat
27
28if TYPE_CHECKING:
29    from sqlalchemy import Index
30    from sqlalchemy import Table
31    from sqlalchemy.engine import Connection
32    from sqlalchemy.engine import Dialect
33    from sqlalchemy.engine import Transaction
34    from sqlalchemy.engine.reflection import Inspector
35    from sqlalchemy.sql.base import ColumnCollection
36    from sqlalchemy.sql.compiler import SQLCompiler
37    from sqlalchemy.sql.dml import Insert
38    from sqlalchemy.sql.elements import ColumnClause
39    from sqlalchemy.sql.elements import ColumnElement
40    from sqlalchemy.sql.schema import Constraint
41    from sqlalchemy.sql.schema import SchemaItem
42    from sqlalchemy.sql.selectable import Select
43    from sqlalchemy.sql.selectable import TableClause
44
# Type variable for helpers that accept a column expression or schema item
# and return an object of the same type (e.g. _copy, _copy_expression).
_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
46
47
def _safe_int(value: str) -> Union[int, str]:
    """Convert a version-string token to ``int`` when it is numeric.

    Non-numeric tokens (such as ``"b1"`` pre-release markers found in
    SQLAlchemy version strings) are returned unchanged.

    Only ``ValueError`` -- what ``int()`` raises for a non-numeric string --
    is caught, so ``KeyboardInterrupt``/``SystemExit`` and unexpected errors
    are not silently swallowed.
    """
    try:
        return int(value)
    except ValueError:
        return value
53
54
# Installed SQLAlchemy version as a tuple, e.g. "1.4.26" -> (1, 4, 26).
# Pre-release tags matched by the regex (a/b/c followed by a digit, such as
# "b1") are kept as strings.
_vers = tuple(
    _safe_int(token) for token in re.findall(r"(\d+|[abc]\d)", __version__)
)
# Feature flags derived from the installed SQLAlchemy version.
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
sqla_14_26 = _vers >= (1, 4, 26)
61
62
# sqla_1x is True before SQLAlchemy 1.4. From 1.4 on, it is True only when
# sqlalchemy.engine.Connection has no ``commit`` attribute.
sqla_1x = True
if sqla_14:
    # when future engine merges, this can be again based on version string
    from sqlalchemy.engine import Connection as legacy_connection

    sqla_1x = not hasattr(legacy_connection, "commit")
70
# ``Computed`` is imported when this SQLAlchemy provides it. Otherwise it is
# bound to ``type(None)`` as a placeholder and both feature flags are False.
# Because ``isinstance(None, type(None))`` is True, callers must check
# ``has_computed`` before using ``Computed`` in isinstance() tests.
try:
    from sqlalchemy import Computed  # noqa
except ImportError:
    Computed = type(None)  # type: ignore
    has_computed = False
    has_computed_reflection = False
else:
    has_computed = True
    # Reflection of computed columns is only treated as available from
    # SQLAlchemy 1.3.16 onward.
    has_computed_reflection = _vers >= (1, 3, 16)
80
# ``Identity`` is imported when this SQLAlchemy provides it. Otherwise it is
# bound to ``type(None)`` as a placeholder and ``has_identity`` is False.
# NOTE(review): ``_identity_options_attrs`` / ``_identity_attrs`` are only
# defined when the import succeeds; confirm every consumer is guarded.
try:
    from sqlalchemy import Identity  # noqa
except ImportError:
    Identity = type(None)  # type: ignore
    has_identity = False
else:
    # attributes common to Identity and Sequence
    _identity_options_attrs = (
        "start",
        "increment",
        "minvalue",
        "maxvalue",
        "nominvalue",
        "nomaxvalue",
        "cycle",
        "cache",
        "order",
    )
    # attributes of Identity: the common options plus "on_null"
    _identity_attrs = _identity_options_attrs + ("on_null",)
    has_identity = True
102
# Literal "auto" value; presumably mirrors SQLAlchemy's default for
# Column(autoincrement=...) -- confirm against its usages.
AUTOINCREMENT_DEFAULT = "auto"
104
105
@contextlib.contextmanager
def _ensure_scope_for_ddl(
    connection: Optional["Connection"],
) -> Iterator[None]:
    """Run the enclosed DDL inside a transaction scope when possible.

    If ``connection`` lacks an ``in_transaction`` attribute (``None`` or a
    mock connection), the body runs unwrapped. If a transaction is already
    in progress, the body runs within it. Otherwise ``connection.begin()``
    is entered for the duration of the body.
    """
    try:
        check_active = connection.in_transaction  # type: ignore[union-attr]
    except AttributeError:
        # catch for MockConnection, None: nothing to wrap
        yield
        return
    if check_active():
        # An enclosing transaction already covers this work.
        yield
        return
    assert connection is not None
    with connection.begin():
        yield
122
123
def _safe_begin_connection_transaction(
    connection: "Connection",
) -> "Transaction":
    """Return the connection's current transaction, beginning one if needed.

    A truthy existing transaction is reused as-is; otherwise the result of
    ``connection.begin()`` is returned.
    """
    current = _get_connection_transaction(connection)
    if not current:
        return connection.begin()
    return current
132
133
def _safe_commit_connection_transaction(
    connection: "Connection",
) -> None:
    """Commit the connection's current transaction, if there is one."""
    current = _get_connection_transaction(connection)
    if not current:
        return
    current.commit()
140
141
def _safe_rollback_connection_transaction(
    connection: "Connection",
) -> None:
    """Roll back the connection's current transaction, if there is one."""
    current = _get_connection_transaction(connection)
    if not current:
        return
    current.rollback()
148
149
def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
    """Return ``connection.in_transaction()``.

    Returns False for objects without that attribute (e.g. a mock
    connection or ``None``).
    """
    try:
        probe = connection.in_transaction  # type: ignore
    except AttributeError:
        # catch for MockConnection
        return False
    return probe()
158
159
def _copy(schema_item: _CE, **kw) -> _CE:
    """Copy ``schema_item`` using whichever copy method it provides.

    ``_copy()`` is used when the object has it; otherwise ``copy()``.
    Keyword arguments are forwarded unchanged.
    """
    method_name = "_copy" if hasattr(schema_item, "_copy") else "copy"
    return getattr(schema_item, method_name)(**kw)
165
166
def _get_connection_transaction(
    connection: "Connection",
) -> Optional["Transaction"]:
    """Return the connection's current transaction object (may be None).

    SQLAlchemy 1.4+ exposes ``get_transaction()``. Older versions are read
    through the root connection's private ``__transaction`` attribute.
    """
    if not sqla_14:
        root = connection._root  # type: ignore[attr-defined]
        return root._Connection__transaction
    return connection.get_transaction()
175
176
def _create_url(*arg, **kw) -> url.URL:
    """Build a ``URL``, preferring the ``URL.create`` factory when present."""
    factory = url.URL.create if hasattr(url.URL, "create") else url.URL
    return factory(*arg, **kw)
182
183
def _connectable_has_table(
    connectable: "Connection", tablename: str, schemaname: Union[str, None]
) -> bool:
    """Return whether ``tablename`` exists in ``schemaname`` on the connectable.

    SQLAlchemy 1.4+ asks the inspector; older versions ask the dialect.
    """
    if not sqla_14:
        return connectable.dialect.has_table(
            connectable, tablename, schemaname
        )
    return inspect(connectable).has_table(tablename, schemaname)
193
194
def _exec_on_inspector(inspector, statement, **params):
    """Execute ``statement`` with ``params`` on the inspector's connection.

    SQLAlchemy 1.4+ obtains the connection from the inspector's operation
    context. Older versions execute on ``inspector.bind``.
    """
    if not sqla_14:
        return inspector.bind.execute(statement, params)
    with inspector._operation_context() as conn:
        return conn.execute(statement, params)
201
202
def _nullability_might_be_unset(metadata_column):
    """Report whether the column's nullability may not have been set explicitly.

    On SQLAlchemy 1.4+ this is True when ``_user_defined_nullable`` is the
    ``NULL_UNSPECIFIED`` marker. On older versions the column's ``nullable``
    value is returned as-is.
    """
    if sqla_14:
        from sqlalchemy.sql import schema as sql_schema

        return (
            metadata_column._user_defined_nullable
            is sql_schema.NULL_UNSPECIFIED
        )
    return metadata_column.nullable
212
213
def _server_default_is_computed(*server_default) -> bool:
    """Return True if any given server default is a ``Computed`` construct.

    Always False when ``Computed`` is unavailable (``has_computed`` False).
    """
    if has_computed:
        return any(isinstance(default, Computed) for default in server_default)
    return False
219
220
def _server_default_is_identity(*server_default) -> bool:
    """Return True if any given server default is an ``Identity`` construct.

    Guarded by ``has_identity``, mirroring ``_server_default_is_computed``'s
    use of ``has_computed``. When the ``Identity`` import fails, ``Identity``
    is bound to ``type(None)``, and an unguarded ``isinstance`` check would
    then report a ``None`` server default as an identity.
    """
    if not has_identity:
        return False
    return any(isinstance(sd, Identity) for sd in server_default)
226
227
def _table_for_constraint(constraint: "Constraint") -> "Table":
    """Return the table that owns ``constraint``.

    Foreign key constraints are resolved through ``.parent``. All other
    constraints use ``.table``.
    """
    if not isinstance(constraint, ForeignKeyConstraint):
        return constraint.table
    owner = constraint.parent
    assert owner is not None
    return owner
235
236
def _columns_for_constraint(constraint):
    """Return the local columns that participate in ``constraint``.

    - Foreign keys: a list of each element's parent column.
    - CHECK constraints: whatever ``_find_columns`` collects from the
      constraint's SQL text (a set).
    - Anything else: a list built from ``constraint.columns``.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [element.parent for element in constraint.elements]
    if isinstance(constraint, CheckConstraint):
        return _find_columns(constraint.sqltext)
    return list(constraint.columns)
244
245
def _reflect_table(
    inspector: "Inspector",
    table: "Table",
    include_cols: Optional[Collection[str]] = None,
) -> None:
    """Reflect ``table``'s definition from the database via ``inspector``.

    :param inspector: inspector used to perform the reflection.
    :param table: ``Table`` object to reflect into.
    :param include_cols: optional collection of column names to limit
        reflection to. ``None`` (the default) reflects every column. The
        value is forwarded as SQLAlchemy's ``include_columns`` argument
        rather than being discarded.
    """
    if sqla_14:
        inspector.reflect_table(table, include_cols)
    else:
        # SQLAlchemy < 1.4 spells the method ``reflecttable``.
        inspector.reflecttable(table, include_cols)  # type: ignore[attr-defined]
253
254
def _fk_spec(constraint):
    """Summarize a foreign key constraint as a flat 10-tuple.

    Returns ``(source_schema, source_table, source_columns, target_schema,
    target_table, target_columns, onupdate, ondelete, deferrable,
    initially)``. Column entries are lists of names, and the target table
    is taken from the first element's referenced column.
    """
    local_names = [
        constraint.columns[column_key].name
        for column_key in constraint.column_keys
    ]
    parent_table = constraint.parent
    referred_table = constraint.elements[0].column.table
    return (
        parent_table.schema,
        parent_table.name,
        local_names,
        referred_table.schema,
        referred_table.name,
        [fk_element.column.name for fk_element in constraint.elements],
        constraint.onupdate,
        constraint.ondelete,
        constraint.deferrable,
        constraint.initially,
    )
281
282
def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
    """Return True if the first foreign key element points at its own table.

    The last dot-separated token (the column name) is stripped from the
    element's colspec. The remainder is compared to the parent table's key.
    """
    colspec = constraint.elements[0]._get_colspec()  # type: ignore[attr-defined]
    referred_key = colspec.rpartition(".")[0]
    owner = constraint.parent
    assert owner is not None
    return referred_key == owner.key
290
291
def _is_type_bound(constraint: "Constraint") -> bool:
    """Return whether ``constraint`` is bound to (generated by) a column type.

    Such CHECK constraints come from the type itself (SQLAlchemy #3260), so
    they should not be copied.
    """
    type_bound = constraint._type_bound  # type: ignore[attr-defined]
    return type_bound
297
298
def _find_columns(clause):
    """Locate Column objects within the given expression.

    Returns a set of every element visited as ``"column"`` during traversal.
    """
    found = set()
    visitor_map = {"column": found.add}
    traverse(clause, {}, visitor_map)
    return found
305
306
def _remove_column_from_collection(
    collection: "ColumnCollection", column: Union["Column", "ColumnClause"]
) -> None:
    """Remove from ``collection`` the member stored under ``column.key``.

    The object looked up by key is removed, not ``column`` itself. This is a
    workaround for older SQLAlchemy, which requires the identical object.
    """
    lookup_key = column.key
    assert lookup_key is not None
    collection.remove(collection[lookup_key])
317
318
def _textual_index_column(
    table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
) -> Union["ColumnElement", "Column"]:
    """A workaround for the Index construct's severe lack of flexibility.

    - A plain string becomes a NULLTYPE ``Column`` appended to ``table``.
    - A ``text()`` clause is wrapped in ``_textual_index_element``.
    - A column expression is copied so its columns refer to ``table``.

    Any other input raises ``ValueError``.
    """
    if isinstance(text_, compat.string_types):
        placeholder = Column(text_, sqltypes.NULLTYPE)
        table.append_column(placeholder)
        return placeholder
    if isinstance(text_, TextClause):
        return _textual_index_element(table, text_)
    if isinstance(text_, sql.ColumnElement):
        return _copy_expression(text_, table)
    raise ValueError("String or text() construct expected")
333
334
def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
    """Return a copy of ``expression`` whose columns belong to ``target_table``.

    Each ``Column`` attached to a different table is replaced by the
    same-named column of ``target_table``. If no such column exists, a copy
    of the column is appended to ``target_table`` and used instead.
    """

    def _retarget(element):
        # Returning None tells replacement_traverse to leave ``element``.
        if not isinstance(element, Column):
            return None
        if element.table is None or element.table is target_table:
            return None
        if element.name in target_table.c:
            return target_table.c[element.name]
        duplicate = _copy(element)
        target_table.append_column(duplicate)
        return duplicate

    return visitors.replacement_traverse(expression, {}, _retarget)
352
353
class _textual_index_element(sql.ColumnElement):
    """Wrap around a sqlalchemy text() construct in such a way that
    we appear like a column-oriented SQL expression to an Index
    construct.

    The issue here is that currently the PostgreSQL dialect, the biggest
    recipient of functional indexes, keys all the index expressions to
    the corresponding column expressions when rendering CREATE INDEX,
    so the Index we create here needs to have a .columns collection that
    is the same length as the .expressions collection.  Ultimately
    SQLAlchemy should support text() expressions in indexes.

    See SQLAlchemy issue 3174.

    Construction appends a NULLTYPE placeholder column, named after the
    raw text, to ``table`` to provide that one-column-per-expression layout.
    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table: "Table", text: "TextClause") -> None:
        self.table = table
        self.text = text
        self.key = text.text
        # Placeholder column named with the raw SQL text; it is added to the
        # table and reported as this element's only child.
        self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
        table.append_column(self.fake_column)

    def get_children(self, **kw):
        """Return the placeholder column as this element's only child.

        ``**kw`` accepts (and ignores) traversal options that SQLAlchemy's
        visitor functions may pass to ``get_children()``. Without it, such
        calls would raise ``TypeError``.
        """
        return [self.fake_column]
381
382
@compiles(_textual_index_element)
def _render_textual_index_column(
    element: _textual_index_element, compiler: "SQLCompiler", **kw
) -> str:
    """Compile a ``_textual_index_element`` as its wrapped text() clause."""
    wrapped_text = element.text
    return compiler.process(wrapped_text, **kw)
388
389
class _literal_bindparam(BindParameter):
    """``BindParameter`` marker subclass.

    Its compilation is routed, via the ``@compiles`` hook registered for it
    in this module, to ``compiler.render_literal_bindparam``.
    """
392
393
@compiles(_literal_bindparam)
def _render_literal_bindparam(
    element: _literal_bindparam, compiler: "SQLCompiler", **kw
) -> str:
    """Compile a ``_literal_bindparam`` by rendering it as a literal value
    through ``compiler.render_literal_bindparam``."""
    rendered = compiler.render_literal_bindparam(element, **kw)
    return rendered
399
400
def _get_index_expressions(idx):
    """Return a new list holding the index's ``expressions``."""
    return [*idx.expressions]
403
404
def _get_index_column_names(idx):
    """Return the ``name`` of each expression in ``idx.expressions``.

    Expressions without a ``name`` attribute yield ``None`` in their slot.
    """
    names = []
    for expression in idx.expressions:
        names.append(getattr(expression, "name", None))
    return names
407
408
def _column_kwargs(col: "Column") -> Mapping:
    """Return ``col.kwargs`` on SQLAlchemy 1.3+, otherwise an empty dict."""
    return col.kwargs if sqla_13 else {}
414
415
def _get_constraint_final_name(
    constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"]
) -> Optional[str]:
    """Return the name ``dialect`` would render for ``constraint``, unquoted.

    Returns ``None`` when the constraint has no name; otherwise ``dialect``
    must not be ``None``.

    On SQLAlchemy 1.4+ the dialect's identifier preparer is asked directly
    via ``format_constraint(..., _alembic_quote=False)``. On older versions
    the name is re-wrapped with ``quote=False`` on a freshly constructed
    object of the same class, and that object is formatted instead.
    """
    if constraint.name is None:
        return None
    assert dialect is not None
    if sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
    else:

        # prior to SQLAlchemy 1.4, work around quoting logic to get at the
        # final compiled name without quotes.
        if hasattr(constraint.name, "quote"):
            # might be quoted_name, might be truncated_name, keep it the
            # same
            quoted_name_cls: type = type(constraint.name)
        else:
            quoted_name_cls = quoted_name

        new_name = quoted_name_cls(str(constraint.name), quote=False)
        # NOTE(review): builds a new instance passing only ``name=``;
        # constraint classes whose constructors require positional arguments
        # would fail here -- confirm which types reach this branch.
        constraint = constraint.__class__(name=new_name)

        if isinstance(constraint, schema.Index):
            # name should not be quoted.
            # Indexes go through the DDL compiler's private
            # _prepared_index_name() rather than format_constraint().
            d = dialect.ddl_compiler(dialect, None)
            return d._prepared_index_name(  # type: ignore[attr-defined]
                constraint
            )
        else:
            # name should not be quoted.
            return dialect.identifier_preparer.format_constraint(constraint)
454
455
def _constraint_is_named(
    constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"]
) -> bool:
    """Return whether ``constraint`` ends up with a name.

    Before SQLAlchemy 1.4 this simply checks ``constraint.name``. From 1.4
    on, a named constraint also requires the dialect's
    ``format_constraint(..., _alembic_quote=False)`` to return a non-None
    value.
    """
    if not sqla_14:
        return constraint.name is not None
    if constraint.name is None:
        return False
    assert dialect is not None
    final_name = dialect.identifier_preparer.format_constraint(
        constraint, _alembic_quote=False
    )
    return final_name is not None
469
470
def _is_mariadb(mysql_dialect: "Dialect") -> bool:
    """Return whether the MySQL-family dialect reports MariaDB.

    SQLAlchemy 1.4+ exposes ``is_mariadb``. Older versions read the private
    ``_is_mariadb`` and yield False while ``server_version_info`` is falsy.
    """
    if not sqla_14:
        return bool(
            mysql_dialect.server_version_info
            and mysql_dialect._is_mariadb  # type: ignore[attr-defined]
        )
    return mysql_dialect.is_mariadb  # type: ignore[attr-defined]
479
480
def _mariadb_normalized_version_info(mysql_dialect):
    """Return the dialect's ``_mariadb_normalized_version_info`` attribute."""
    version_info = mysql_dialect._mariadb_normalized_version_info
    return version_info
483
484
def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
    """Build an ``insert()`` for ``table`` in inline mode.

    SQLAlchemy 1.4+ uses the generative ``.inline()`` method; older versions
    pass ``inline=True`` to ``insert()``.
    """
    if not sqla_14:
        return table.insert(inline=True)
    return table.insert().inline()
490
491
# Version-dependent aliases: on SQLAlchemy 1.4+ create_mock_engine and select
# are taken from SQLAlchemy directly; older versions get local shims.
if sqla_14:
    from sqlalchemy import create_mock_engine
    from sqlalchemy import select as _select
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor, **kw):  # type: ignore[misc]
        """Pre-1.4 stand-in built on ``create_engine(..., strategy="mock")``.

        NOTE(review): ``url`` and ``**kw`` are ignored and the engine is
        always created for ``"postgresql://"`` -- confirm callers do not
        depend on the dialect of the returned engine.
        """
        return create_engine(
            "postgresql://", strategy="mock", executor=executor
        )

    def _select(*columns, **kw) -> "Select":
        """Pre-1.4 stand-in accepting columns positionally; they are passed
        to ``sql.select`` as a single list."""
        return sql.select(list(columns), **kw)
505