1from typing import Any
2from typing import Dict
3from typing import List
4from typing import Optional
5from typing import Sequence
6from typing import Tuple
7from typing import TYPE_CHECKING
8from typing import Union
9
10from sqlalchemy import schema as sa_schema
11from sqlalchemy.sql.schema import Column
12from sqlalchemy.sql.schema import Constraint
13from sqlalchemy.sql.schema import Index
14from sqlalchemy.types import Integer
15from sqlalchemy.types import NULLTYPE
16
17from .. import util
18from ..util import sqla_compat
19from ..util.compat import string_types
20
21if TYPE_CHECKING:
22    from sqlalchemy.sql.elements import ColumnElement
23    from sqlalchemy.sql.elements import TextClause
24    from sqlalchemy.sql.schema import CheckConstraint
25    from sqlalchemy.sql.schema import ForeignKey
26    from sqlalchemy.sql.schema import ForeignKeyConstraint
27    from sqlalchemy.sql.schema import MetaData
28    from sqlalchemy.sql.schema import PrimaryKeyConstraint
29    from sqlalchemy.sql.schema import Table
30    from sqlalchemy.sql.schema import UniqueConstraint
31    from sqlalchemy.sql.type_api import TypeEngine
32
33    from ..runtime.migration import MigrationContext
34
35
class SchemaObjects:
    """Factory for throwaway SQLAlchemy schema objects used by migration ops.

    Each object is built against a fresh ``MetaData`` (see :meth:`metadata`)
    that is populated with placeholder ``Table`` / ``Column`` objects, just
    enough for SQLAlchemy to construct the requested constraint, index or
    table.
    """

    def __init__(
        self, migration_context: Optional["MigrationContext"] = None
    ) -> None:
        """
        :param migration_context: optional context. When its ``opts``
          contain ``target_metadata`` and that object has a
          ``naming_convention``, the convention is copied onto every
          ``MetaData`` created by :meth:`metadata`.
        """
        self.migration_context = migration_context

    def primary_key_constraint(
        self,
        name: Optional[str],
        table_name: str,
        cols: Sequence[str],
        schema: Optional[str] = None,
        **dialect_kw
    ) -> "PrimaryKeyConstraint":
        """Build a ``PrimaryKeyConstraint`` over *cols* of *table_name*.

        A placeholder table with one ``NULLTYPE`` column per entry in *cols*
        is created so the constraint references real ``Column`` objects.
        Extra keyword arguments are forwarded to the constraint unchanged.
        """
        m = self.metadata()
        columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
        t = sa_schema.Table(table_name, m, *columns, schema=schema)
        p = sa_schema.PrimaryKeyConstraint(
            *[t.c[n] for n in cols], name=name, **dialect_kw
        )
        return p

    def foreign_key_constraint(
        self,
        name: Optional[str],
        source: str,
        referent: str,
        local_cols: List[str],
        remote_cols: List[str],
        onupdate: Optional[str] = None,
        ondelete: Optional[str] = None,
        deferrable: Optional[bool] = None,
        source_schema: Optional[str] = None,
        referent_schema: Optional[str] = None,
        initially: Optional[str] = None,
        match: Optional[str] = None,
        **dialect_kw
    ) -> "ForeignKeyConstraint":
        """Build a ``ForeignKeyConstraint`` from *source* to *referent*.

        Placeholder tables are created for the source table (holding
        *local_cols*) and for the referent table (holding *remote_cols*).
        When source and referent are the same table in the same schema,
        a single placeholder table carries both column sets.

        :param local_cols: column names on the source table.
        :param remote_cols: referenced column names on the referent table,
          positionally paired with *local_cols*.
        :param onupdate/ondelete/deferrable/initially/match: forwarded to
          ``ForeignKeyConstraint``.
        :param dialect_kw: extra keyword arguments forwarded to
          ``ForeignKeyConstraint``.
        :return: the constraint, already appended to the source table.
        """
        m = self.metadata()
        if source == referent and source_schema == referent_schema:
            # Self-referential: both column sets live on one placeholder
            # table. A composite key may name the same column on both sides
            # (e.g. a shared ``tenant_id``), and SQLAlchemy rejects or warns
            # about a Table that repeats a column name, so de-duplicate while
            # keeping first-seen order. Unpacking also lets callers mix list
            # and tuple arguments.
            t1_cols = self._unique_names([*local_cols, *remote_cols])
        else:
            t1_cols = list(local_cols)
            # Registered on ``m`` so the string column references below can
            # be resolved against it.
            sa_schema.Table(
                referent,
                m,
                *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
                schema=referent_schema
            )

        t1 = sa_schema.Table(
            source,
            m,
            *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
            schema=source_schema
        )

        tname = (
            "%s.%s" % (referent_schema, referent)
            if referent_schema
            else referent
        )

        f = sa_schema.ForeignKeyConstraint(
            local_cols,
            ["%s.%s" % (tname, n) for n in remote_cols],
            name=name,
            onupdate=onupdate,
            ondelete=ondelete,
            deferrable=deferrable,
            initially=initially,
            match=match,
            **dialect_kw
        )
        t1.append_constraint(f)

        return f

    @staticmethod
    def _unique_names(names: Sequence[str]) -> List[str]:
        """Return *names* as a list without repeats, keeping first-seen order."""
        return list(dict.fromkeys(names))

    def unique_constraint(
        self,
        name: Optional[str],
        source: str,
        local_cols: Sequence[str],
        schema: Optional[str] = None,
        **kw
    ) -> "UniqueConstraint":
        """Build a ``UniqueConstraint`` over *local_cols* of table *source*.

        A placeholder table with one ``NULLTYPE`` column per name is
        created, and the constraint is appended to it. Extra keyword
        arguments are forwarded to ``UniqueConstraint``.
        """
        t = sa_schema.Table(
            source,
            self.metadata(),
            *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
            schema=schema
        )
        kw["name"] = name
        uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
        # TODO: need event tests to ensure the event
        # is fired off here
        t.append_constraint(uq)
        return uq

    def check_constraint(
        self,
        name: Optional[str],
        source: str,
        condition: Union[str, "TextClause", "ColumnElement[Any]"],
        schema: Optional[str] = None,
        **kw
    ) -> "CheckConstraint":
        """Build a ``CheckConstraint`` with *condition* on table *source*.

        The placeholder table gets a single dummy ``Integer`` column ``x``.
        NOTE(review): presumably this is there so the table is non-empty;
        the condition itself is not bound to it. Extra keyword arguments are
        forwarded to ``CheckConstraint``.
        """
        t = sa_schema.Table(
            source,
            self.metadata(),
            sa_schema.Column("x", Integer),
            schema=schema,
        )
        ck = sa_schema.CheckConstraint(condition, name=name, **kw)
        t.append_constraint(ck)
        return ck

    def generic_constraint(
        self,
        name: Optional[str],
        table_name: str,
        type_: Optional[str],
        schema: Optional[str] = None,
        **kw
    ) -> Any:
        """Build a named, column-less constraint of the given kind.

        :param type_: one of ``"foreignkey"``, ``"primary"``, ``"unique"``,
          ``"check"`` or ``None`` (plain ``Constraint``).
        :raises TypeError: if *type_* is not one of the accepted values.
        :return: the constraint, appended to a placeholder table
          *table_name* with no columns.
        """
        t = self.table(table_name, schema=schema)
        types: Dict[Optional[str], Any] = {
            "foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
                [], [], name=name
            ),
            "primary": sa_schema.PrimaryKeyConstraint,
            "unique": sa_schema.UniqueConstraint,
            "check": lambda name: sa_schema.CheckConstraint("", name=name),
            None: sa_schema.Constraint,
        }
        try:
            const = types[type_]
        except KeyError as ke:
            raise TypeError(
                "'type' can be one of %s"
                % ", ".join(sorted(repr(x) for x in types))
            ) from ke
        else:
            const = const(name=name)
            t.append_constraint(const)
            return const

    def metadata(self) -> "MetaData":
        """Return a new, empty ``MetaData``.

        If the migration context's ``target_metadata`` option has a
        ``naming_convention``, that convention is applied, so generated
        constraint names follow the target's rules.
        """
        kw: Dict[str, Any] = {}
        if (
            self.migration_context is not None
            and "target_metadata" in self.migration_context.opts
        ):
            mt = self.migration_context.opts["target_metadata"]
            if hasattr(mt, "naming_convention"):
                kw["naming_convention"] = mt.naming_convention
        return sa_schema.MetaData(**kw)

    def table(self, name: str, *columns, **kw) -> "Table":
        """Build a ``Table`` *name* on fresh metadata from schema items.

        :param columns: a mix of ``Column``, ``Constraint`` and ``Index``
          objects. Columns and constraints already attached elsewhere are
          copied first; positional items of any other type are ignored.
        :param kw: forwarded to ``Table``, except the private
          ``_constraints_included`` flag handled below.
        :return: the table. Placeholder tables are also created on the same
          metadata for any string-specified foreign key targets.
        """
        m = self.metadata()

        cols = [
            sqla_compat._copy(c) if c.table is not None else c
            for c in columns
            if isinstance(c, Column)
        ]
        # these flags have already added their UniqueConstraint /
        # Index objects to the table, so flip them off here.
        # SQLAlchemy tometadata() avoids this instead by preserving the
        # flags and skipping the constraints that have _type_bound on them,
        # but for a migration we'd rather list out the constraints
        # explicitly.
        _constraints_included = kw.pop("_constraints_included", False)
        if _constraints_included:
            for c in cols:
                c.unique = c.index = False

        t = sa_schema.Table(name, m, *cols, **kw)

        constraints = [
            sqla_compat._copy(elem, target_table=t)
            if getattr(elem, "parent", None) is not None
            else elem
            for elem in columns
            if isinstance(elem, (Constraint, Index))
        ]

        for const in constraints:
            t.append_constraint(const)

        for f in t.foreign_keys:
            self._ensure_table_for_fk(m, f)
        return t

    def column(self, name: str, type_: "TypeEngine", **kw) -> "Column":
        """Return ``Column(name, type_, **kw)``."""
        return sa_schema.Column(name, type_, **kw)

    def index(
        self,
        name: str,
        tablename: Optional[str],
        columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
        schema: Optional[str] = None,
        **kw
    ) -> "Index":
        """Build an ``Index`` *name* over *columns* on a placeholder table.

        The placeholder table is named *tablename*, or ``"no_table"`` when
        *tablename* is falsy. Each entry of *columns* is passed through
        ``sqla_compat._textual_index_column`` against that table. Extra
        keyword arguments are forwarded to ``Index``.
        """
        t = sa_schema.Table(
            tablename or "no_table",
            self.metadata(),
            schema=schema,
        )
        kw["_table"] = t
        idx = sa_schema.Index(
            name,
            *[sqla_compat._textual_index_column(t, n) for n in columns],
            **kw
        )
        return idx

    def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
        """Split ``"schema.table"`` into ``(schema, table)``.

        Everything before the last dot is treated as the schema. Without a
        dot, the schema is ``None``.
        """
        if "." in table_key:
            tokens = table_key.split(".")
            sname: Optional[str] = ".".join(tokens[0:-1])
            tname = tokens[-1]
        else:
            tname = table_key
            sname = None
        return (sname, tname)

    def _ensure_table_for_fk(
        self, metadata: "MetaData", fk: "ForeignKey"
    ) -> None:
        """create a placeholder Table object for the referent of a
        ForeignKey.

        Only string column specifications (``"[schema.]table.column"``) are
        handled. The referenced table is created on *metadata* if missing,
        and the referenced column is added to it as a ``NULLTYPE`` column if
        absent.
        """
        if isinstance(fk._colspec, string_types):  # type:ignore[attr-defined]
            table_key, cname = fk._colspec.rsplit(  # type:ignore[attr-defined]
                ".", 1
            )
            sname, tname = self._parse_table_key(table_key)
            if table_key not in metadata.tables:
                rel_t = sa_schema.Table(tname, metadata, schema=sname)
            else:
                rel_t = metadata.tables[table_key]
            if cname not in rel_t.c:
                rel_t.append_column(sa_schema.Column(cname, NULLTYPE))