1import re
2
3from sqlalchemy import __version__
4from sqlalchemy import inspect
5from sqlalchemy import schema
6from sqlalchemy import sql
7from sqlalchemy import types as sqltypes
8from sqlalchemy.engine import url
9from sqlalchemy.ext.compiler import compiles
10from sqlalchemy.schema import CheckConstraint
11from sqlalchemy.schema import Column
12from sqlalchemy.schema import ForeignKeyConstraint
13from sqlalchemy.sql.elements import quoted_name
14from sqlalchemy.sql.expression import _BindParamClause
15from sqlalchemy.sql.expression import _TextClause as TextClause
16from sqlalchemy.sql.visitors import traverse
17
18from . import compat
19
20
def _safe_int(value):
    """Convert *value* to ``int`` when possible, else return it unchanged.

    Used when parsing the SQLAlchemy version string so that numeric
    components become ints while tokens such as ``"b1"`` stay strings.

    :param value: any object; typically a string token.
    :return: ``int(value)`` on success, otherwise *value* itself.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        # Not integer-like (e.g. "b1", None); hand it back untouched.
        # Catching only these avoids swallowing KeyboardInterrupt,
        # SystemExit and genuine bugs, which the former bare except did.
        return value
26
27
# Installed SQLAlchemy version as a tuple: numeric parts become ints and a
# pre-release marker such as "b1" is kept as a string token.
_vers = tuple(
    _safe_int(token)
    for token in re.findall(r"(\d+|[abc]\d)", __version__)
)

# Feature gates keyed on the SQLAlchemy release that introduced them.
sqla_110 = _vers >= (1, 1, 0)
sqla_1115 = _vers >= (1, 1, 15)
sqla_120 = _vers >= (1, 2, 0)
sqla_1216 = _vers >= (1, 2, 16)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)

# Probe for the Computed construct; reflection of computed columns is
# additionally gated on SQLAlchemy 1.3.16.
try:
    from sqlalchemy import Computed  # noqa
except ImportError:
    has_computed = False
    has_computed_reflection = False
else:
    has_computed = True
    has_computed_reflection = _vers >= (1, 3, 16)

AUTOINCREMENT_DEFAULT = "auto"
48
49
def _create_url(*arg, **kw):
    """Construct a ``URL`` object across SQLAlchemy versions.

    Uses the ``URL.create`` factory when the installed ``URL`` class has
    one, and otherwise calls the ``URL`` constructor directly.  All
    arguments are forwarded unchanged.
    """
    factory = getattr(url.URL, "create", None)
    if factory is None:
        factory = url.URL
    return factory(*arg, **kw)
55
56
def _connectable_has_table(connectable, tablename, schemaname):
    """Report whether *tablename* exists in *schemaname* for *connectable*.

    Before SQLAlchemy 1.4 the check goes through the dialect's
    ``has_table``.  From 1.4 on it goes through an inspector.
    """
    if not sqla_14:
        return connectable.dialect.has_table(
            connectable, tablename, schemaname
        )
    return inspect(connectable).has_table(tablename, schemaname)
64
65
def _exec_on_inspector(inspector, statement, **params):
    """Execute *statement* with *params* using the inspector's connection.

    On SQLAlchemy 1.4+ a connection is obtained from the inspector's
    ``_operation_context()``.  Older versions execute on
    ``inspector.bind``.  Keyword arguments are passed as one parameter
    dictionary.
    """
    if not sqla_14:
        return inspector.bind.execute(statement, params)
    with inspector._operation_context() as connection:
        return connection.execute(statement, params)
72
73
def _server_default_is_computed(column):
    """Return True if ``column.computed`` is a ``Computed`` construct.

    Always False when the installed SQLAlchemy has no ``Computed``, in
    which case ``column.computed`` is never consulted.
    """
    return has_computed and isinstance(column.computed, Computed)
79
80
def _table_for_constraint(constraint):
    """Return the table a constraint belongs to.

    ForeignKeyConstraint exposes its owning table as ``.parent``.  Every
    other constraint kind handled here is read via ``.table``.
    """
    is_foreign_key = isinstance(constraint, ForeignKeyConstraint)
    return constraint.parent if is_foreign_key else constraint.table
86
87
def _columns_for_constraint(constraint):
    """Return the local columns a constraint covers.

    - ForeignKeyConstraint: list of each element's ``parent`` column.
    - CheckConstraint: the Column objects found in its SQL expression,
      returned as a *set* (see ``_find_columns``).
    - anything else: ``list(constraint.columns)``.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [element.parent for element in constraint.elements]
    if isinstance(constraint, CheckConstraint):
        return _find_columns(constraint.sqltext)
    return list(constraint.columns)
95
96
def _fk_spec(constraint):
    """Flatten a foreign key constraint into a 10-tuple.

    The tuple is ``(source_schema, source_table, source_columns,
    target_schema, target_table, target_columns, onupdate, ondelete,
    deferrable, initially)``.  Note that ``onupdate`` comes before
    ``ondelete``.  The target table is read from the first entry of
    ``constraint.elements``.
    """
    local_cols = [constraint.columns[k].name for k in constraint.column_keys]
    parent = constraint.parent
    remote_table = constraint.elements[0].column.table
    remote_cols = [fk.column.name for fk in constraint.elements]
    return (
        parent.schema,
        parent.name,
        local_cols,
        remote_table.schema,
        remote_table.name,
        remote_cols,
        constraint.onupdate,
        constraint.ondelete,
        constraint.deferrable,
        constraint.initially,
    )
123
124
def _fk_is_self_referential(constraint):
    """Return True if the first FK element points back at its parent table.

    The element's column spec has its last dot-separated token (the
    column name) removed.  The remainder is compared with
    ``constraint.parent.key``.
    """
    colspec = constraint.elements[0]._get_colspec()
    # everything before the final "." ("" when there is no dot)
    target_table_key = colspec.rpartition(".")[0]
    return target_table_key == constraint.parent.key
131
132
def _is_type_bound(constraint):
    """Return the constraint's ``_type_bound`` flag.

    Callers use this so they do not copy CHECK constraints that the
    column's type will generate on its own.  The flag is the feature
    added for SQLAlchemy #3260.
    """
    type_bound = constraint._type_bound
    return type_bound
138
139
def _find_columns(clause):
    """Return the set of Column objects referenced within *clause*."""
    found = set()
    # SQLAlchemy's visitor calls the "column" hook for each column node.
    visitors = {"column": found.add}
    traverse(clause, {}, visitors)
    return found
146
147
def _remove_column_from_collection(collection, column):
    """Remove the entry keyed by ``column.key`` from a ColumnCollection.

    The object currently stored under that key is what gets removed, not
    *column* itself.  This is a workaround for older SQLAlchemy, where
    removal must be given the exact object present in the collection.
    """
    collection.remove(collection[column.key])
155
156
def _textual_index_column(table, text_):
    """Turn a string or ``text()`` into something an Index can hold.

    This works around the Index construct's lack of flexibility:

    - a plain string becomes a NULLTYPE Column appended to *table*;
    - a ``text()`` construct is wrapped in ``_textual_index_element``.

    :raises ValueError: for any other input type.
    """
    if isinstance(text_, compat.string_types):
        placeholder = Column(text_, sqltypes.NULLTYPE)
        table.append_column(placeholder)
        return placeholder
    if isinstance(text_, TextClause):
        return _textual_index_element(table, text_)
    raise ValueError("String or text() construct expected")
167
168
class _textual_index_element(sql.ColumnElement):
    """Present a ``text()`` construct to an Index as a column expression.

    The PostgreSQL dialect, the main consumer of functional indexes, keys
    each index expression to a corresponding column expression when it
    renders CREATE INDEX.  The Index built from this element therefore
    needs a ``.columns`` collection as long as its ``.expressions``
    collection.  To satisfy that, a placeholder NULLTYPE column named after
    the SQL text is appended to *table*.  Ideally SQLAlchemy would accept
    text() expressions in indexes directly.

    See SQLAlchemy issue 3174.
    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table, text):
        self.table = table
        self.text = text
        # the raw SQL string doubles as this element's key
        self.key = text.text
        placeholder = schema.Column(text.text, sqltypes.NULLTYPE)
        self.fake_column = placeholder
        table.append_column(placeholder)

    def get_children(self):
        # expose only the placeholder column to visitors
        return [self.fake_column]
196
197
@compiles(_textual_index_element)
def _render_textual_index_column(element, compiler, **kw):
    """Compile a ``_textual_index_element`` as its wrapped text() SQL."""
    wrapped_text = element.text
    return compiler.process(wrapped_text, **kw)
201
202
class _literal_bindparam(_BindParamClause):
    """Bind parameter subclass compiled via ``render_literal_bindparam``.

    It adds no behavior of its own.  Its distinct type lets a ``@compiles``
    hook in this module route it through the compiler's literal rendering.
    """
205
206
@compiles(_literal_bindparam)
def _render_literal_bindparam(element, compiler, **kw):
    """Compile a ``_literal_bindparam`` using the literal-rendering path."""
    render = compiler.render_literal_bindparam
    return render(element, **kw)
210
211
def _get_index_expressions(idx):
    """Return a new list containing ``idx.expressions``."""
    return [*idx.expressions]
214
215
def _get_index_column_names(idx):
    """Return each index expression's ``name``, or None where it has none."""
    expressions = _get_index_expressions(idx)
    return [getattr(expr, "name", None) for expr in expressions]
218
219
def _column_kwargs(col):
    """Return ``col.kwargs`` on SQLAlchemy 1.3+, else an empty dict."""
    return col.kwargs if sqla_13 else {}
225
226
def _get_constraint_final_name(constraint, dialect):
    """Return the unquoted name *constraint* would get in DDL for *dialect*.

    Returns ``None`` when ``constraint.name`` is ``None``.  On SQLAlchemy
    1.4+ the dialect's identifier preparer is asked directly for the final
    name with ``_alembic_quote=False``.  On older versions, a throwaway
    instance of the constraint's class is built with a non-quoting name,
    so the dialect's own formatting produces the name without quotes.
    """
    if constraint.name is None:
        return None
    elif sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
    else:

        # prior to SQLAlchemy 1.4, work around quoting logic to get at the
        # final compiled name without quotes.
        if hasattr(constraint.name, "quote"):
            # might be quoted_name, might be truncated_name, keep it the
            # same
            quoted_name_cls = type(constraint.name)
        else:
            quoted_name_cls = quoted_name

        # Re-wrap the name in the same string class with quote=False so the
        # formatting below does not add quotes.
        new_name = quoted_name_cls(str(constraint.name), quote=False)
        # NOTE(review): this builds a new, bare instance of the constraint's
        # class passing only ``name=``.  Confirm every constraint type routed
        # here accepts that call; some SQLAlchemy constraint constructors
        # take required positional arguments.
        constraint = constraint.__class__(name=new_name)

        if isinstance(constraint, schema.Index):
            # name should not be quoted.
            # Index names go through the DDL compiler's index-name
            # preparation instead of format_constraint().
            return dialect.ddl_compiler(dialect, None)._prepared_index_name(
                constraint
            )
        else:
            # name should not be quoted.
            return dialect.identifier_preparer.format_constraint(constraint)
261
262
def _constraint_is_named(constraint, dialect):
    """Return True if *constraint* ends up with a name for *dialect*.

    Before SQLAlchemy 1.4 this only checks whether ``constraint.name`` is
    set.  On 1.4+ a set name must also format to a non-None final name
    via the dialect's identifier preparer.
    """
    if not sqla_14:
        return constraint.name is not None
    if constraint.name is None:
        return False
    final_name = dialect.identifier_preparer.format_constraint(
        constraint, _alembic_quote=False
    )
    return final_name is not None
273
274
def _dialect_supports_comments(dialect):
    """Return ``dialect.supports_comments`` on SQLAlchemy 1.2+, else False."""
    return dialect.supports_comments if sqla_120 else False
280
281
def _comment_attribute(obj):
    """Return a Table's or Column's ``.comment``.

    Returns None on SQLAlchemy older than 1.2, where the attribute is not
    read.
    """
    return obj.comment if sqla_120 else None
289
290
def _is_mariadb(mysql_dialect):
    """Report whether a MySQL-family dialect is talking to MariaDB.

    On SQLAlchemy 1.4+ this returns ``is_mariadb``.  Earlier versions
    return ``server_version_info`` itself when it is falsy, and otherwise
    the private ``_is_mariadb`` flag.
    """
    if not sqla_14:
        return (
            mysql_dialect.server_version_info and mysql_dialect._is_mariadb
        )
    return mysql_dialect.is_mariadb
296
297
def _mariadb_normalized_version_info(mysql_dialect):
    """Return the dialect's private ``_mariadb_normalized_version_info``."""
    version_info = mysql_dialect._mariadb_normalized_version_info
    return version_info
300
301
if sqla_14:
    from sqlalchemy import create_mock_engine
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor):
        """Backport of SQLAlchemy 1.4's ``create_mock_engine(url, executor)``.

        Builds the engine with ``create_engine(..., strategy="mock")`` and
        forwards *executor* as the mock strategy's ``executor`` callable.

        :param url: database URL (string or ``URL``) whose dialect the mock
          engine should use.
        :param executor: callable handed to the mock strategy.
        """
        # Honor the caller's URL so the mock engine uses the requested
        # dialect, matching create_mock_engine() on SQLAlchemy 1.4+.  This
        # was previously hard-coded to "postgresql://", silently ignoring
        # ``url``.
        return create_engine(url, strategy="mock", executor=executor)
311