from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import schema as sa_schema
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.types import Integer
from sqlalchemy.types import NULLTYPE

from .. import util
from ..util import sqla_compat
from ..util.compat import string_types

if TYPE_CHECKING:
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.elements import TextClause
    from sqlalchemy.sql.schema import CheckConstraint
    from sqlalchemy.sql.schema import ForeignKey
    from sqlalchemy.sql.schema import ForeignKeyConstraint
    from sqlalchemy.sql.schema import MetaData
    from sqlalchemy.sql.schema import PrimaryKeyConstraint
    from sqlalchemy.sql.schema import Table
    from sqlalchemy.sql.schema import UniqueConstraint
    from sqlalchemy.sql.type_api import TypeEngine

    from ..runtime.migration import MigrationContext


class SchemaObjects:
    """Factory for lightweight SQLAlchemy schema objects.

    Each factory method builds a fresh ``MetaData`` (see :meth:`metadata`)
    and placeholder ``Table`` objects whose columns are typed ``NULLTYPE``
    (unless real ``Column`` objects are supplied), so that constraints and
    indexes can be constructed from plain table/column names.
    """

    def __init__(
        self, migration_context: Optional["MigrationContext"] = None
    ) -> None:
        # Optional; only consulted by metadata() to pick up the target
        # metadata's naming_convention.
        self.migration_context = migration_context

    def primary_key_constraint(
        self,
        name: Optional[str],
        table_name: str,
        cols: Sequence[str],
        schema: Optional[str] = None,
        **dialect_kw
    ) -> "PrimaryKeyConstraint":
        """Build a ``PrimaryKeyConstraint`` over ``cols`` of a placeholder
        table named ``table_name``.

        Extra keyword arguments are passed through to the constraint.
        """
        m = self.metadata()
        columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
        t = sa_schema.Table(table_name, m, *columns, schema=schema)
        p = sa_schema.PrimaryKeyConstraint(
            *[t.c[n] for n in cols], name=name, **dialect_kw
        )
        return p

    def foreign_key_constraint(
        self,
        name: Optional[str],
        source: str,
        referent: str,
        local_cols: List[str],
        remote_cols: List[str],
        onupdate: Optional[str] = None,
        ondelete: Optional[str] = None,
        deferrable: Optional[bool] = None,
        source_schema: Optional[str] = None,
        referent_schema: Optional[str] = None,
        initially: Optional[str] = None,
        match: Optional[str] = None,
        **dialect_kw
    ) -> "ForeignKeyConstraint":
        """Build a ``ForeignKeyConstraint`` from ``source.local_cols`` to
        ``referent.remote_cols``.

        Placeholder tables are created for both sides; when source and
        referent are the same table (same name and schema), a single table
        carrying both the local and remote columns is used.  The constraint
        is attached to the source table before being returned.

        ``match`` is forwarded via the dialect keyword arguments; all other
        extra keyword arguments are passed to the constraint as well.
        """
        m = self.metadata()
        if source == referent and source_schema == referent_schema:
            # Self-referential FK: local and remote columns live on one
            # placeholder table.  Deduplicate while preserving order so a
            # column named on both sides (or referencing itself) is declared
            # only once -- newer SQLAlchemy versions reject a repeated column
            # name on a single Table.
            t1_cols = list(dict.fromkeys([*local_cols, *remote_cols]))
        else:
            t1_cols = list(local_cols)
            sa_schema.Table(
                referent,
                m,
                *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
                schema=referent_schema
            )

        t1 = sa_schema.Table(
            source,
            m,
            *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
            schema=source_schema
        )

        # String target spec is "[schema.]table"; column names are appended
        # below to form "[schema.]table.column" references.
        tname = (
            "%s.%s" % (referent_schema, referent)
            if referent_schema
            else referent
        )

        dialect_kw["match"] = match

        f = sa_schema.ForeignKeyConstraint(
            local_cols,
            ["%s.%s" % (tname, n) for n in remote_cols],
            name=name,
            onupdate=onupdate,
            ondelete=ondelete,
            deferrable=deferrable,
            initially=initially,
            **dialect_kw
        )
        t1.append_constraint(f)

        return f

    def unique_constraint(
        self,
        name: Optional[str],
        source: str,
        local_cols: Sequence[str],
        schema: Optional[str] = None,
        **kw
    ) -> "UniqueConstraint":
        """Build a ``UniqueConstraint`` over ``local_cols`` of a placeholder
        table ``source``, attached to that table."""
        t = sa_schema.Table(
            source,
            self.metadata(),
            *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
            schema=schema
        )
        kw["name"] = name
        uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
        # TODO: need event tests to ensure the event
        # is fired off here
        t.append_constraint(uq)
        return uq

    def check_constraint(
        self,
        name: Optional[str],
        source: str,
        condition: Union[str, "TextClause", "ColumnElement[Any]"],
        schema: Optional[str] = None,
        **kw
    ) -> "CheckConstraint":
        """Build a ``CheckConstraint`` with ``condition`` attached to a
        placeholder table ``source``."""
        # The placeholder table gets a single dummy Integer column "x";
        # the condition itself is not bound to any of its columns here.
        t = sa_schema.Table(
            source,
            self.metadata(),
            sa_schema.Column("x", Integer),
            schema=schema,
        )
        ck = sa_schema.CheckConstraint(condition, name=name, **kw)
        t.append_constraint(ck)
        return ck

    def generic_constraint(
        self,
        name: Optional[str],
        table_name: str,
        type_: Optional[str],
        schema: Optional[str] = None,
        **kw
    ) -> Any:
        """Build an empty, named constraint of the given kind on a
        placeholder table.

        ``type_`` is one of ``"foreignkey"``, ``"primary"``, ``"unique"``,
        ``"check"`` or ``None`` (a plain ``Constraint``).

        :raises TypeError: if ``type_`` is not one of the accepted values.
        """
        t = self.table(table_name, schema=schema)
        types: Dict[Optional[str], Any] = {
            "foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
                [], [], name=name
            ),
            "primary": sa_schema.PrimaryKeyConstraint,
            "unique": sa_schema.UniqueConstraint,
            "check": lambda name: sa_schema.CheckConstraint("", name=name),
            None: sa_schema.Constraint,
        }
        try:
            const = types[type_]
        except KeyError as ke:
            raise TypeError(
                "'type' can be one of %s"
                % ", ".join(sorted(repr(x) for x in types))
            ) from ke
        else:
            const = const(name=name)
            t.append_constraint(const)
            return const

    def metadata(self) -> "MetaData":
        """Return a new ``MetaData``, inheriting ``naming_convention`` from
        the migration context's ``target_metadata`` option when present."""
        kw = {}
        if (
            self.migration_context is not None
            and "target_metadata" in self.migration_context.opts
        ):
            mt = self.migration_context.opts["target_metadata"]
            if hasattr(mt, "naming_convention"):
                kw["naming_convention"] = mt.naming_convention
        return sa_schema.MetaData(**kw)

    def table(self, name: str, *columns, **kw) -> "Table":
        """Build a ``Table`` in fresh metadata from a mix of ``Column``,
        ``Constraint`` and ``Index`` objects.

        Columns/constraints already attached to another table are copied
        rather than moved.  Placeholder tables are created for any
        string-specified foreign key targets.  The private keyword
        ``_constraints_included`` is consumed here and not passed to
        ``Table``.
        """
        m = self.metadata()

        cols = [
            sqla_compat._copy(c) if c.table is not None else c
            for c in columns
            if isinstance(c, Column)
        ]
        # these flags have already added their UniqueConstraint /
        # Index objects to the table, so flip them off here.
        # SQLAlchemy tometadata() avoids this instead by preserving the
        # flags and skipping the constraints that have _type_bound on them,
        # but for a migration we'd rather list out the constraints
        # explicitly.
        _constraints_included = kw.pop("_constraints_included", False)
        if _constraints_included:
            for c in cols:
                c.unique = c.index = False

        t = sa_schema.Table(name, m, *cols, **kw)

        constraints = [
            sqla_compat._copy(elem, target_table=t)
            if getattr(elem, "parent", None) is not None
            else elem
            for elem in columns
            if isinstance(elem, (Constraint, Index))
        ]

        for const in constraints:
            t.append_constraint(const)

        for f in t.foreign_keys:
            self._ensure_table_for_fk(m, f)
        return t

    def column(self, name: str, type_: "TypeEngine", **kw) -> "Column":
        """Thin wrapper constructing a ``Column``."""
        return sa_schema.Column(name, type_, **kw)

    def index(
        self,
        name: str,
        tablename: Optional[str],
        columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
        schema: Optional[str] = None,
        **kw
    ) -> "Index":
        """Build an ``Index`` bound to a placeholder table.

        ``tablename`` falls back to ``"no_table"`` when not given; each
        entry of ``columns`` is resolved through
        ``sqla_compat._textual_index_column``.
        """
        t = sa_schema.Table(
            tablename or "no_table",
            self.metadata(),
            schema=schema,
        )
        kw["_table"] = t
        idx = sa_schema.Index(
            name,
            *[util.sqla_compat._textual_index_column(t, n) for n in columns],
            **kw
        )
        return idx

    def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
        """Split ``"[schema.]table"`` into ``(schema_or_None, table)``.

        Everything before the last dot is treated as the schema, so dotted
        schema names are preserved.
        """
        if "." in table_key:
            tokens = table_key.split(".")
            sname: Optional[str] = ".".join(tokens[0:-1])
            tname = tokens[-1]
        else:
            tname = table_key
            sname = None
        return (sname, tname)

    def _ensure_table_for_fk(
        self, metadata: "MetaData", fk: "ForeignKey"
    ) -> None:
        """create a placeholder Table object for the referent of a
        ForeignKey.

        Only string column specs (``"[schema.]table.column"``) are handled;
        the referenced table is created in ``metadata`` if absent and the
        referenced column is added to it (typed ``NULLTYPE``) if missing.
        """
        if isinstance(fk._colspec, string_types):  # type:ignore[attr-defined]
            table_key, cname = fk._colspec.rsplit(  # type:ignore[attr-defined]
                ".", 1
            )
            sname, tname = self._parse_table_key(table_key)
            if table_key not in metadata.tables:
                rel_t = sa_schema.Table(tname, metadata, schema=sname)
            else:
                rel_t = metadata.tables[table_key]
            if cname not in rel_t.c:
                rel_t.append_column(sa_schema.Column(cname, NULLTYPE))