1import re
2
3from sqlalchemy import schema
4from sqlalchemy import types as sqltypes
5from sqlalchemy.ext.compiler import compiles
6
7from .base import alter_table
8from .base import AlterColumn
9from .base import ColumnDefault
10from .base import ColumnName
11from .base import ColumnNullable
12from .base import ColumnType
13from .base import format_column_name
14from .base import format_server_default
15from .impl import DefaultImpl
16from .. import util
17from ..autogenerate import compare
18from ..util.sqla_compat import _is_mariadb
19from ..util.sqla_compat import _is_type_bound
20
21
class MySQLImpl(DefaultImpl):
    """Alembic migration implementation for MySQL.

    Column alterations are emitted as whole-column ``CHANGE`` / ``MODIFY``
    statements, or as ``ALTER COLUMN ... SET/DROP DEFAULT`` when only the
    server default changes. Several autogenerate comparison hooks are
    adjusted for the way MySQL reflects defaults, indexes and foreign keys.
    """

    __dialect__ = "mysql"

    # DDL for this dialect is not treated as transactional.
    transactional_ddl = False
    # BOOL and TINYINT are considered the same type when comparing types.
    type_synonyms = DefaultImpl.type_synonyms + ({"BOOL", "TINYINT"},)
    # Patterns for type arguments (character set, collation) pulled out of
    # rendered type strings -- presumably consumed by DefaultImpl's type
    # comparison; confirm against DefaultImpl.
    type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]

    def alter_column(
        self,
        table_name,
        column_name,
        nullable=None,
        server_default=False,
        name=None,
        type_=None,
        schema=None,
        existing_type=None,
        existing_server_default=None,
        existing_nullable=None,
        autoincrement=None,
        existing_autoincrement=None,
        comment=False,
        existing_comment=None,
        **kw
    ):
        """Emit the MySQL DDL for an ALTER COLUMN operation.

        ``False`` for ``server_default`` / ``comment`` and ``None`` for the
        other attributes mean "unchanged". ``CHANGE`` and ``MODIFY`` restate
        the entire column, so each unchanged attribute falls back to its
        ``existing_*`` value. Nullability defaults to ``True`` when neither
        value is given.

        * A rename, or setting a server default on a DateTime-affinity
          column, is emitted as ``CHANGE``.
        * A change to nullability, type, autoincrement or comment is
          emitted as ``MODIFY``.
        * A change to only the server default is emitted as
          ``ALTER COLUMN ... SET DEFAULT`` / ``DROP DEFAULT``.
        * Otherwise nothing is emitted.

        :raises util.CommandError: (from ``MySQLChangeColumn``) when a
            CHANGE/MODIFY is needed but neither ``type_`` nor
            ``existing_type`` is given.
        """
        effective_type = type_ if type_ is not None else existing_type

        if nullable is not None:
            effective_nullable = nullable
        elif existing_nullable is not None:
            effective_nullable = existing_nullable
        else:
            effective_nullable = True

        # Keyword arguments shared by CHANGE and MODIFY; both restate the
        # complete column definition.
        column_kw = dict(
            schema=schema,
            newname=name if name is not None else column_name,
            nullable=effective_nullable,
            type_=effective_type,
            default=server_default
            if server_default is not False
            else existing_server_default,
            autoincrement=autoincrement
            if autoincrement is not None
            else existing_autoincrement,
            comment=comment if comment is not False else existing_comment,
        )

        # A new server default on a DateTime column (e.g. a function such
        # as CURRENT_TIMESTAMP) is routed through CHANGE rather than
        # ALTER COLUMN ... SET DEFAULT. server_default=False means "default
        # unchanged" and must not trigger this path; otherwise a no-op
        # alter on a DateTime column would still emit a CHANGE statement.
        sets_functional_default = (
            server_default is not False
            and self._is_mysql_allowed_functional_default(
                effective_type, server_default
            )
        )

        if name is not None or sets_functional_default:
            self._exec(
                MySQLChangeColumn(table_name, column_name, **column_kw)
            )
        elif (
            nullable is not None
            or type_ is not None
            or autoincrement is not None
            or comment is not False
        ):
            self._exec(
                MySQLModifyColumn(table_name, column_name, **column_kw)
            )
        elif server_default is not False:
            self._exec(
                MySQLAlterDefault(
                    table_name, column_name, server_default, schema=schema
                )
            )

    def drop_constraint(self, const):
        """Drop ``const``, silently skipping type-bound CHECK constraints.

        A CHECK constraint for which ``_is_type_bound`` is true is not
        dropped here. Presumably such constraints are generated by a column
        type and go away with the column; confirm against ``_is_type_bound``.
        """
        if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
            return

        super(MySQLImpl, self).drop_constraint(const)

    def _is_mysql_allowed_functional_default(self, type_, server_default):
        """Return True when ``type_`` has DateTime affinity and
        ``server_default`` is not None.

        Note that the "unchanged" sentinel ``False`` also satisfies this
        check; callers that need to tell "unchanged" apart must exclude it
        themselves.
        """
        return (
            type_ is not None
            and type_._type_affinity is sqltypes.DateTime
            and server_default is not None
        )

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        """Return True if the reflected server default differs from the
        metadata one, after MySQL-specific normalization.

        * An Integer primary key that is not autoincrement, has no metadata
          default and reflects ``'0'`` is treated as unchanged.
        * For a reflected Integer column, single quotes around the reflected
          default are stripped before comparing.
        * When both defaults are present, they are compared
          case-insensitively, ignoring a trailing ``()``. A trailing
          ``ON UPDATE`` clause on the reflected default must match the one
          in the metadata default.
        """
        # partially a workaround for SQLAlchemy issue #3023; if the
        # column were created without "NOT NULL", MySQL may have added
        # an implicit default of '0' which we need to skip
        # TODO: this is not really covered anymore ?
        if (
            metadata_column.type._type_affinity is sqltypes.Integer
            and inspector_column.primary_key
            and not inspector_column.autoincrement
            and not rendered_metadata_default
            and rendered_inspector_default == "'0'"
        ):
            return False
        elif inspector_column.type._type_affinity is sqltypes.Integer:
            rendered_inspector_default = (
                re.sub(r"^'|'$", "", rendered_inspector_default)
                if rendered_inspector_default is not None
                else None
            )
            return rendered_inspector_default != rendered_metadata_default
        elif rendered_inspector_default and rendered_metadata_default:
            # adjust for "function()" vs. "FUNCTION" as can occur particularly
            # for the CURRENT_TIMESTAMP function on newer MariaDB versions

            # SQLAlchemy MySQL dialect bundles ON UPDATE into the server
            # default; adjust for this possibly being present.
            onupdate_pattern = r"(.*) (on update.*?)(?:\(\))?$"
            onupdate_ins = re.match(
                onupdate_pattern, rendered_inspector_default.lower()
            )
            onupdate_met = re.match(
                onupdate_pattern, rendered_metadata_default.lower()
            )

            if onupdate_ins:
                if (
                    not onupdate_met
                    or onupdate_ins.group(2) != onupdate_met.group(2)
                ):
                    return True

                # Compare only the portion before ON UPDATE from here on.
                rendered_inspector_default = onupdate_ins.group(1)
                rendered_metadata_default = onupdate_met.group(1)

            # Drop one trailing "()" so "now()" and "NOW" compare equal.
            strip_call_parens = r"(.*?)(?:\(\))?$"
            return re.sub(
                strip_call_parens, r"\1", rendered_inspector_default.lower()
            ) != re.sub(
                strip_call_parens, r"\1", rendered_metadata_default.lower()
            )
        else:
            return rendered_inspector_default != rendered_metadata_default

    def correct_for_autogen_constraints(
        self,
        conn_unique_constraints,
        conn_indexes,
        metadata_unique_constraints,
        metadata_indexes,
    ):
        """Prune indexes that look like MySQL's implicit foreign key indexes.

        A reflected non-unique index is removed from ``conn_indexes`` when
        its name equals the name of one of its columns, or the name of a
        foreign key on one of its columns. Metadata indexes with a removed
        name are then removed from ``metadata_indexes``. Both collections
        are mutated in place.
        """
        # TODO: if SQLA 1.0, make use of "duplicates_index"
        # metadata
        removed = set()
        for idx in list(conn_indexes):
            if idx.unique:
                continue
            # MySQL puts implicit indexes on FK columns, even if
            # composite and even if MyISAM, so can't check this too easily.
            # the name of the index may be the column name or it may
            # be the name of the FK constraint.
            if any(
                idx.name == col.name
                or any(fk.name == idx.name for fk in col.foreign_keys)
                for col in idx.columns
            ):
                conn_indexes.remove(idx)
                removed.add(idx.name)

        # then remove indexes from the "metadata_indexes"
        # that we've removed from reflected, otherwise they come out
        # as adds (see #202)
        for idx in list(metadata_indexes):
            if idx.name in removed:
                metadata_indexes.remove(idx)

    def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
        """Backfill an explicit RESTRICT on reflected foreign keys.

        For each reflected / metadata foreign key pair with the same
        signature: if the metadata FK sets ``ondelete`` or ``onupdate`` to
        "restrict" (any case) and the reflected FK has None there, the
        reflected FK's attribute is set to "RESTRICT" in place.
        """
        conn_fk_by_sig = {
            compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
        }
        metadata_fk_by_sig = {
            compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
        }

        for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
            mdfk = metadata_fk_by_sig[sig]
            cnfk = conn_fk_by_sig[sig]
            # MySQL considers RESTRICT to be the default and doesn't
            # report on it.  if the model has explicit RESTRICT and
            # the conn FK has None, set it to RESTRICT
            for attr in ("ondelete", "onupdate"):
                md_value = getattr(mdfk, attr)
                if (
                    md_value is not None
                    and md_value.lower() == "restrict"
                    and getattr(cnfk, attr) is None
                ):
                    setattr(cnfk, attr, "RESTRICT")
244
245
class MariaDBImpl(MySQLImpl):
    """MySQL implementation reused unchanged under the ``mariadb`` name."""

    __dialect__ = "mariadb"
248
249
class MySQLAlterDefault(AlterColumn):
    """Construct compiled to ``ALTER COLUMN ... SET DEFAULT`` or
    ``DROP DEFAULT`` (when ``default`` is None) for MySQL / MariaDB."""

    def __init__(self, name, column_name, default, schema=None):
        # Deliberately bypasses AlterColumn.__init__: initialization starts
        # at the next class after AlterColumn in the MRO, and only the
        # attributes below are set here.
        super(AlterColumn, self).__init__(name, schema=schema)
        self.default = default
        self.column_name = column_name
255
256
class MySQLChangeColumn(AlterColumn):
    """Whole-column redefinition, compiled as ``CHANGE <old> <new> <spec>``.

    The full column spec is restated, so a type is mandatory.

    :raises util.CommandError: if ``type_`` is None.
    """

    def __init__(
        self,
        name,
        column_name,
        schema=None,
        newname=None,
        type_=None,
        nullable=None,
        default=False,
        autoincrement=None,
        comment=False,
    ):
        # Bypass AlterColumn.__init__; all needed attributes are set below.
        super(AlterColumn, self).__init__(name, schema=schema)
        self.column_name = column_name
        self.newname = newname
        self.nullable = nullable
        self.default = default
        self.autoincrement = autoincrement
        self.comment = comment

        if type_ is None:
            raise util.CommandError(
                "All MySQL CHANGE/MODIFY COLUMN operations "
                "require the existing type."
            )
        # Accept either a type class or an instance.
        self.type_ = sqltypes.to_instance(type_)
284
285
class MySQLModifyColumn(MySQLChangeColumn):
    """Same data as MySQLChangeColumn, but compiled as ``MODIFY`` (the
    column keeps its name)."""
288
289
@compiles(ColumnNullable, "mysql", "mariadb")
@compiles(ColumnName, "mysql", "mariadb")
@compiles(ColumnDefault, "mysql", "mariadb")
@compiles(ColumnType, "mysql", "mariadb")
def _mysql_doesnt_support_individual(element, compiler, **kw):
    """Refuse to compile single-attribute ALTER COLUMN constructs on
    MySQL / MariaDB; MySQLImpl emits CHANGE / MODIFY instead."""
    message = "Individual alter column constructs not supported by MySQL"
    raise NotImplementedError(message)
298
299
@compiles(MySQLAlterDefault, "mysql", "mariadb")
def _mysql_alter_default(element, compiler, **kw):
    """Render ``<alter_table prefix> ALTER COLUMN <col> SET DEFAULT <d>``,
    or ``... DROP DEFAULT`` when ``element.default`` is None."""
    table_sql = alter_table(compiler, element.table_name, element.schema)
    column_sql = format_column_name(compiler, element.column_name)
    if element.default is None:
        default_sql = "DROP DEFAULT"
    else:
        default_sql = "SET DEFAULT %s" % format_server_default(
            compiler, element.default
        )
    return "%s ALTER COLUMN %s %s" % (table_sql, column_sql, default_sql)
309
310
@compiles(MySQLModifyColumn, "mysql", "mariadb")
def _mysql_modify_column(element, compiler, **kw):
    """Render ``<alter_table prefix> MODIFY <col> <spec>``; the column name
    is kept, so ``element.newname`` is not used."""
    prefix = alter_table(compiler, element.table_name, element.schema)
    column = format_column_name(compiler, element.column_name)
    colspec = _mysql_colspec(
        compiler,
        nullable=element.nullable,
        server_default=element.default,
        type_=element.type_,
        autoincrement=element.autoincrement,
        comment=element.comment,
    )
    return "%s MODIFY %s %s" % (prefix, column, colspec)
325
326
@compiles(MySQLChangeColumn, "mysql", "mariadb")
def _mysql_change_column(element, compiler, **kw):
    """Render ``<alter_table prefix> CHANGE <old> <new> <spec>``, which can
    rename the column while restating its full definition."""
    prefix = alter_table(compiler, element.table_name, element.schema)
    old_name = format_column_name(compiler, element.column_name)
    new_name = format_column_name(compiler, element.newname)
    colspec = _mysql_colspec(
        compiler,
        nullable=element.nullable,
        server_default=element.default,
        type_=element.type_,
        autoincrement=element.autoincrement,
        comment=element.comment,
    )
    return "%s CHANGE %s %s %s" % (prefix, old_name, new_name, colspec)
342
343
def _mysql_colspec(
    compiler, nullable, server_default, type_, autoincrement, comment
):
    """Build the column spec used by CHANGE / MODIFY.

    The result is ``<type> NULL|NOT NULL``, optionally followed by
    ``AUTO_INCREMENT``, ``DEFAULT <d>`` (skipped when ``server_default`` is
    False or None) and ``COMMENT '<c>'`` (skipped when ``comment`` is falsy).
    """
    pieces = [
        "%s %s"
        % (
            compiler.dialect.type_compiler.process(type_),
            "NULL" if nullable else "NOT NULL",
        )
    ]
    if autoincrement:
        pieces.append(" AUTO_INCREMENT")
    # Identity checks on purpose: SQL expression defaults overload ``==``.
    if not (server_default is False or server_default is None):
        pieces.append(
            " DEFAULT %s" % format_server_default(compiler, server_default)
        )
    if comment:
        pieces.append(
            " COMMENT %s"
            % compiler.sql_compiler.render_literal_value(
                comment, sqltypes.String()
            )
        )
    return "".join(pieces)
361
362
@compiles(schema.DropConstraint, "mysql", "mariadb")
def _mysql_drop_constraint(element, compiler, **kw):
    """Compile DropConstraint for MySQL / MariaDB.

    Foreign key, primary key and unique constraints are delegated to
    ``compiler.visit_drop_constraint``. CHECK constraints are rendered
    here: ``DROP CONSTRAINT`` on MariaDB, ``DROP CHECK`` on MySQL. Any other
    constraint type raises NotImplementedError.
    """
    constraint = element.element
    delegated = (
        schema.ForeignKeyConstraint,
        schema.PrimaryKeyConstraint,
        schema.UniqueConstraint,
    )
    if isinstance(constraint, delegated):
        return compiler.visit_drop_constraint(element, **kw)

    if not isinstance(constraint, schema.CheckConstraint):
        raise NotImplementedError(
            "No generic 'DROP CONSTRAINT' in MySQL - "
            "please specify constraint type"
        )

    # note that SQLAlchemy as of 1.2 does not yet support
    # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
    # here.
    if _is_mariadb(compiler.dialect):
        template = "ALTER TABLE %s DROP CONSTRAINT %s"
    else:
        template = "ALTER TABLE %s DROP CHECK %s"
    return template % (
        compiler.preparer.format_table(constraint.table),
        compiler.preparer.format_constraint(constraint),
    )
397