"""Helpers that adapt to differences between SQLAlchemy versions."""
import re

from sqlalchemy import __version__
from sqlalchemy import inspect
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.expression import _BindParamClause
from sqlalchemy.sql.expression import _TextClause as TextClause
from sqlalchemy.sql.visitors import traverse

from . import compat


def _safe_int(value):
    """Return ``int(value)``, or ``value`` unchanged if conversion fails.

    Used on version tokens; a token such as ``"b1"`` cannot be converted
    and is passed through as-is.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        # only conversion failures are expected here; a bare ``except``
        # would also swallow KeyboardInterrupt / SystemExit.
        return value


# components of the SQLAlchemy version string: runs of digits become ints,
# while tokens made of a/b/c followed by one digit are kept as strings.
_vers = tuple(
    [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
)
sqla_110 = _vers >= (1, 1, 0)
sqla_1115 = _vers >= (1, 1, 15)
sqla_120 = _vers >= (1, 2, 0)
sqla_1216 = _vers >= (1, 2, 16)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
try:
    from sqlalchemy import Computed  # noqa

    has_computed = True

    has_computed_reflection = _vers >= (1, 3, 16)
except ImportError:
    # the Computed construct is not importable from this SQLAlchemy
    has_computed = False
    has_computed_reflection = False

AUTOINCREMENT_DEFAULT = "auto"


def _create_url(*arg, **kw):
    """Build a URL object from the given arguments.

    Uses ``URL.create()`` when that attribute exists on the URL class,
    otherwise calls the ``URL`` constructor directly.
    """
    if hasattr(url.URL, "create"):
        return url.URL.create(*arg, **kw)
    else:
        return url.URL(*arg, **kw)


def _connectable_has_table(connectable, tablename, schemaname):
    """Return the ``has_table()`` result for the given table and schema.

    On SQLAlchemy 1.4+ the check goes through ``inspect(connectable)``;
    on older versions it goes through ``connectable.dialect``.
    """
    if sqla_14:
        return inspect(connectable).has_table(tablename, schemaname)
    else:
        return connectable.dialect.has_table(
            connectable, tablename, schemaname
        )


def _exec_on_inspector(inspector, statement, **params):
    """Execute ``statement`` with ``params`` and return the result.

    On SQLAlchemy 1.4+ the connection is obtained from the inspector's
    ``_operation_context()``; on older versions ``inspector.bind`` is used.
    """
    if sqla_14:
        with inspector._operation_context() as conn:
            return conn.execute(statement, params)
    else:
        return inspector.bind.execute(statement, params)


def _server_default_is_computed(column):
    """Return True if ``column.computed`` is a ``Computed`` construct.

    Always returns False when ``Computed`` could not be imported.
    """
    if not has_computed:
        return False
    else:
        return isinstance(column.computed, Computed)


def _table_for_constraint(constraint):
    """Return ``constraint.parent`` for a ForeignKeyConstraint, otherwise
    ``constraint.table``."""
    if isinstance(constraint, ForeignKeyConstraint):
        return constraint.parent
    else:
        return constraint.table


def _columns_for_constraint(constraint):
    """Return the columns associated with the given constraint.

    * ForeignKeyConstraint: a list of the ``.parent`` of each element.
    * CheckConstraint: the set of columns found within its ``sqltext``.
    * anything else: a list of ``constraint.columns``.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [fk.parent for fk in constraint.elements]
    elif isinstance(constraint, CheckConstraint):
        return _find_columns(constraint.sqltext)
    else:
        return list(constraint.columns)


def _fk_spec(constraint):
    """Return a 10-tuple describing a foreign key constraint.

    The tuple is ``(source_schema, source_table, source_columns,
    target_schema, target_table, target_columns, onupdate, ondelete,
    deferrable, initially)``.  The target schema and table are taken
    from the first element of the constraint.
    """
    source_columns = [
        constraint.columns[key].name for key in constraint.column_keys
    ]

    source_table = constraint.parent.name
    source_schema = constraint.parent.schema
    target_schema = constraint.elements[0].column.table.schema
    target_table = constraint.elements[0].column.table.name
    target_columns = [element.column.name for element in constraint.elements]
    ondelete = constraint.ondelete
    onupdate = constraint.onupdate
    deferrable = constraint.deferrable
    initially = constraint.initially
    return (
        source_schema,
        source_table,
        source_columns,
        target_schema,
        target_table,
        target_columns,
        onupdate,
        ondelete,
        deferrable,
        initially,
    )


def _fk_is_self_referential(constraint):
    """Return True if the first element's column spec, minus its final
    dotted token, equals the key of the constraint's parent."""
    spec = constraint.elements[0]._get_colspec()
    tokens = spec.split(".")
    tokens.pop(-1)  # colname
    tablekey = ".".join(tokens)
    return tablekey == constraint.parent.key


def _is_type_bound(constraint):
    """Return the ``_type_bound`` flag of the given constraint."""
    # this deals with SQLAlchemy #3260, don't copy CHECK constraints
    # that will be generated by the type.
    # new feature added for #3260
    return constraint._type_bound


def _find_columns(clause):
    """locate Column objects within the given expression.

    Returns a set of the objects visited as "column" during traversal.
    """

    cols = set()
    traverse(clause, {}, {"column": cols.add})
    return cols


def _remove_column_from_collection(collection, column):
    """remove a column from a ColumnCollection."""

    # workaround for older SQLAlchemy, remove the
    # same object that's present
    to_remove = collection[column.key]
    collection.remove(to_remove)


def _textual_index_column(table, text_):
    """a workaround for the Index construct's severe lack of flexibility

    A string is turned into a new NULLTYPE Column that is appended to
    ``table`` and returned; a text() construct is wrapped in a
    ``_textual_index_element``.  Anything else raises ``ValueError``.
    """
    if isinstance(text_, compat.string_types):
        c = Column(text_, sqltypes.NULLTYPE)
        table.append_column(c)
        return c
    elif isinstance(text_, TextClause):
        return _textual_index_element(table, text_)
    else:
        raise ValueError("String or text() construct expected")


class _textual_index_element(sql.ColumnElement):
    """Wrap around a sqlalchemy text() construct in such a way that
    we appear like a column-oriented SQL expression to an Index
    construct.

    The issue here is that currently the Postgresql dialect, the biggest
    recipient of functional indexes, keys all the index expressions to
    the corresponding column expressions when rendering CREATE INDEX,
    so the Index we create here needs to have a .columns collection that
    is the same length as the .expressions collection.  Ultimately
    SQLAlchemy should support text() expressions in indexes.

    See SQLAlchemy issue 3174.

    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table, text):
        self.table = table
        self.text = text
        self.key = text.text
        # a placeholder column named after the SQL text is appended to the
        # table and exposed through get_children()
        self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
        table.append_column(self.fake_column)

    def get_children(self):
        return [self.fake_column]


@compiles(_textual_index_element)
def _render_textual_index_column(element, compiler, **kw):
    """Compile the element by compiling its wrapped text() construct."""
    return compiler.process(element.text, **kw)


class _literal_bindparam(_BindParamClause):
    """A bind parameter subclass compiled via ``render_literal_bindparam``."""

    pass


@compiles(_literal_bindparam)
def _render_literal_bindparam(element, compiler, **kw):
    """Compile the element using the compiler's
    ``render_literal_bindparam()`` method."""
    return compiler.render_literal_bindparam(element, **kw)


def _get_index_expressions(idx):
    """Return ``idx.expressions`` as a list."""
    return list(idx.expressions)


def _get_index_column_names(idx):
    """Return the ``name`` of each index expression, using None for
    expressions that have no ``name`` attribute."""
    return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]


def _column_kwargs(col):
    """Return ``col.kwargs`` on SQLAlchemy 1.3+, otherwise an empty dict."""
    if sqla_13:
        return col.kwargs
    else:
        return {}


def _get_constraint_final_name(constraint, dialect):
    """Return the constraint's name as formatted for ``dialect``, without
    quoting, or None if the constraint has no name."""
    if constraint.name is None:
        return None
    elif sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
    else:
        # prior to SQLAlchemy 1.4, work around quoting logic to get at the
        # final compiled name without quotes.
        if hasattr(constraint.name, "quote"):
            # might be quoted_name, might be truncated_name, keep it the
            # same
            quoted_name_cls = type(constraint.name)
        else:
            quoted_name_cls = quoted_name

        new_name = quoted_name_cls(str(constraint.name), quote=False)
        # a throwaway object of the same class carrying only the
        # non-quoting name is what gets formatted below
        constraint = constraint.__class__(name=new_name)

        if isinstance(constraint, schema.Index):
            # name should not be quoted.
            return dialect.ddl_compiler(dialect, None)._prepared_index_name(
                constraint
            )
        else:
            # name should not be quoted.
            return dialect.identifier_preparer.format_constraint(constraint)


def _constraint_is_named(constraint, dialect):
    """Return True if the constraint has a name.

    On SQLAlchemy 1.4+ the name as formatted by the dialect must also be
    non-None; on older versions only ``constraint.name`` is checked.
    """
    if sqla_14:
        if constraint.name is None:
            return False
        name = dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
        return name is not None
    else:
        return constraint.name is not None


def _dialect_supports_comments(dialect):
    """Return ``dialect.supports_comments`` on SQLAlchemy 1.2+, otherwise
    False."""
    if sqla_120:
        return dialect.supports_comments
    else:
        return False


def _comment_attribute(obj):
    """return the .comment attribute from a Table or Column

    Returns None on SQLAlchemy versions before 1.2.
    """

    if sqla_120:
        return obj.comment
    else:
        return None


def _is_mariadb(mysql_dialect):
    """Return the dialect's MariaDB indicator.

    On SQLAlchemy 1.4+ this is ``is_mariadb``.  On older versions it is
    ``server_version_info and _is_mariadb``, so a falsy
    ``server_version_info`` is returned as-is.
    """
    if sqla_14:
        return mysql_dialect.is_mariadb
    else:
        return mysql_dialect.server_version_info and mysql_dialect._is_mariadb


def _mariadb_normalized_version_info(mysql_dialect):
    """Return ``mysql_dialect._mariadb_normalized_version_info``."""
    return mysql_dialect._mariadb_normalized_version_info


if sqla_14:
    from sqlalchemy import create_mock_engine
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor):
        """Stand-in for ``create_mock_engine`` on SQLAlchemy before 1.4,
        built on ``create_engine(..., strategy="mock")``.

        NOTE(review): ``url`` is accepted but not used; the engine is
        always created for "postgresql://".  Confirm whether callers rely
        on this before passing ``url`` through.
        """
        return create_engine(
            "postgresql://", strategy="mock", executor=executor
        )