1from collections import namedtuple
2import re
3
4from sqlalchemy import cast
5from sqlalchemy import schema
6from sqlalchemy import text
7
8from . import base
9from .. import util
10from ..util import sqla_compat
11from ..util.compat import string_types
12from ..util.compat import text_type
13from ..util.compat import with_metaclass
14
15
class ImplMeta(type):
    """Metaclass that auto-registers :class:`.DefaultImpl` subclasses.

    Any class created with this metaclass that declares a
    ``__dialect__`` attribute *in its own namespace* is recorded in the
    module-level ``_impls`` registry keyed by that dialect name, so that
    :meth:`.DefaultImpl.get_by_dialect` can locate the implementation
    for a given database dialect at runtime.
    """

    def __init__(cls, classname, bases, dict_):
        # type.__init__() always returns None; the original code
        # captured that value and returned it from __init__, which is
        # meaningless dead code — the call is made purely for its
        # side effect here.
        type.__init__(cls, classname, bases, dict_)
        # Only register classes that declare __dialect__ directly
        # (inherited attributes are intentionally not re-registered).
        if "__dialect__" in dict_:
            _impls[dict_["__dialect__"]] = cls
22
23
# Registry of DefaultImpl subclasses keyed by dialect name
# (e.g. "default", "postgresql"); populated as a side effect of
# class creation by ImplMeta whenever a class declares __dialect__.
_impls = {}

# Tokenized representation of a rendered column type, produced by
# DefaultImpl._tokenize_column_type() and consumed by the
# _column_types_match() / _column_args_match() comparison helpers:
#   token0 - the leading (base) type token, lowercased, e.g. "numeric"
#   tokens - remaining non-parenthesized terms, e.g. ["unsigned"]
#   args   - positional terms found inside parentheses, e.g. ["10", "5"]
#   kwargs - key=value terms found inside parentheses
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
27
28
class DefaultImpl(with_metaclass(ImplMeta)):

    """Provide the entrypoint for major migration operations,
    including database-specific behavioral variances.

    While individual SQL/DDL constructs already provide
    for database-specific implementations, variances here
    allow for entirely different sequences of operations
    to take place for a particular migration, such as
    SQL Server's special 'IDENTITY INSERT' step for
    bulk inserts.

    """

    __dialect__ = "default"

    # Whether DDL runs inside a transaction on this backend; may be
    # overridden per-instance via the ``transactional_ddl`` constructor
    # argument.
    transactional_ddl = False
    # String appended to every statement emitted in offline (as_sql)
    # mode, and to BEGIN/COMMIT in emit_begin()/emit_commit().
    command_terminator = ";"
    # Sets of type names treated as equivalent when comparing a
    # reflected type against a metadata type in compare_type().
    type_synonyms = ({"NUMERIC", "DECIMAL"},)
    # Regexes applied by _column_args_match() to extract comparable
    # values (via group(1)) from the token strings of both types;
    # empty by default, populated by dialect-specific subclasses.
    type_arg_extract = ()

    def __init__(
        self,
        dialect,
        connection,
        as_sql,
        transactional_ddl,
        output_buffer,
        context_opts,
    ):
        """Construct a new implementation.

        :param dialect: the SQLAlchemy dialect in use.
        :param connection: database connection, or presumably None in
         pure offline mode — TODO confirm against callers.
        :param as_sql: when True, statements are rendered to
         ``output_buffer`` instead of being executed.
        :param transactional_ddl: optional override for the
         class-level :attr:`.transactional_ddl` flag; ignored if None.
        :param output_buffer: writable text buffer used by
         :meth:`.static_output` in offline mode.
        :param context_opts: dict of migration-context options; only
         ``literal_binds`` is read here, the rest is stored as-is.
        """
        self.dialect = dialect
        self.connection = connection
        self.as_sql = as_sql
        self.literal_binds = context_opts.get("literal_binds", False)

        self.output_buffer = output_buffer
        # scratch storage for arbitrary per-run state
        self.memo = {}
        self.context_opts = context_opts
        if transactional_ddl is not None:
            self.transactional_ddl = transactional_ddl

        # literal_binds renders bound parameters inline, which only
        # makes sense when emitting SQL text rather than executing.
        if self.literal_binds:
            if not self.as_sql:
                raise util.CommandError(
                    "Can't use literal_binds setting without as_sql mode"
                )

    @classmethod
    def get_by_dialect(cls, dialect):
        """Return the DefaultImpl subclass registered for *dialect*.

        Raises ``KeyError`` if no implementation was registered under
        ``dialect.name`` (registration happens via ImplMeta).
        """
        return _impls[dialect.name]

    def static_output(self, text):
        """Write *text* plus a blank line to the offline output buffer
        and flush immediately."""
        self.output_buffer.write(text_type(text + "\n\n"))
        self.output_buffer.flush()

    def requires_recreate_in_batch(self, batch_op):
        """Return True if the given :class:`.BatchOperationsImpl`
        would need the table to be recreated and copied in order to
        proceed.

        Normally, only returns True on SQLite when operations other
        than add_column are present.

        """
        return False

    def prep_table_for_batch(self, table):
        """perform any operations needed on a table before a new
        one is created to replace it in batch mode.

        the PG dialect uses this to drop constraints on the table
        before the new one uses those same names.

        """

    @property
    def bind(self):
        """The connection this impl was constructed with."""
        return self.connection

    def _exec(
        self,
        construct,
        execution_options=None,
        multiparams=(),
        params=util.immutabledict(),
    ):
        """Execute or render a SQL/DDL construct.

        In ``as_sql`` (offline) mode the construct is compiled against
        the current dialect and written to the output buffer; otherwise
        it is executed on the connection and the result returned.

        :param construct: a SQLAlchemy executable, or a plain string
         which is wrapped in :func:`sqlalchemy.text`.
        :param execution_options: if given, applied to the connection
         via ``execution_options()`` before executing (online only).
        :param multiparams: positional execute() parameters; not
         allowed in offline mode.
        :param params: keyword execute() parameters; not allowed in
         offline mode.
        """
        if isinstance(construct, string_types):
            construct = text(construct)
        if self.as_sql:
            if multiparams or params:
                # TODO: coverage
                raise Exception("Execution arguments not allowed with as_sql")

            # DDL elements have no bound parameters to render, so
            # literal_binds is only passed for non-DDL constructs.
            if self.literal_binds and not isinstance(
                construct, schema.DDLElement
            ):
                compile_kw = dict(compile_kwargs={"literal_binds": True})
            else:
                compile_kw = {}

            # normalize tabs and surrounding whitespace before
            # appending the statement terminator
            self.static_output(
                text_type(
                    construct.compile(dialect=self.dialect, **compile_kw)
                )
                .replace("\t", "    ")
                .strip()
                + self.command_terminator
            )
        else:
            conn = self.connection
            if execution_options:
                conn = conn.execution_options(**execution_options)
            return conn.execute(construct, *multiparams, **params)

    def execute(self, sql, execution_options=None):
        """Public execute entrypoint; delegates to :meth:`._exec`,
        discarding any result."""
        self._exec(sql, execution_options)

    def alter_column(
        self,
        table_name,
        column_name,
        nullable=None,
        server_default=False,
        name=None,
        type_=None,
        schema=None,
        autoincrement=None,
        comment=False,
        existing_comment=None,
        existing_type=None,
        existing_server_default=None,
        existing_nullable=None,
        existing_autoincrement=None,
    ):
        """Emit ALTER COLUMN operations for each requested change.

        Each aspect (nullability, server default, type, comment, name)
        is emitted as a separate DDL construct.  ``False`` (not None)
        is the "no change" sentinel for ``server_default`` and
        ``comment``, since ``None`` is a meaningful value for both.
        The ``existing_*`` parameters carry current column state for
        dialects whose ALTER forms must restate it.
        """
        if autoincrement is not None or existing_autoincrement is not None:
            util.warn(
                "autoincrement and existing_autoincrement "
                "only make sense for MySQL",
                stacklevel=3,
            )
        if nullable is not None:
            self._exec(
                base.ColumnNullable(
                    table_name,
                    column_name,
                    nullable,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                )
            )
        if server_default is not False:
            self._exec(
                base.ColumnDefault(
                    table_name,
                    column_name,
                    server_default,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                )
            )
        if type_ is not None:
            self._exec(
                base.ColumnType(
                    table_name,
                    column_name,
                    type_,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                )
            )

        if comment is not False:
            self._exec(
                base.ColumnComment(
                    table_name,
                    column_name,
                    comment,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                )
            )

        # do the new name last ;)
        if name is not None:
            self._exec(
                base.ColumnName(
                    table_name,
                    column_name,
                    name,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                )
            )

    def add_column(self, table_name, column, schema=None):
        """Emit an ADD COLUMN operation."""
        self._exec(base.AddColumn(table_name, column, schema=schema))

    def drop_column(self, table_name, column, schema=None, **kw):
        """Emit a DROP COLUMN operation.  Extra keyword arguments are
        accepted for dialect subclasses but ignored here."""
        self._exec(base.DropColumn(table_name, column, schema=schema))

    def add_constraint(self, const):
        """Emit ADD CONSTRAINT, honoring the constraint's
        ``_create_rule`` hook (skipped if the rule returns falsy)."""
        if const._create_rule is None or const._create_rule(self):
            self._exec(schema.AddConstraint(const))

    def drop_constraint(self, const):
        """Emit DROP CONSTRAINT."""
        self._exec(schema.DropConstraint(const))

    def rename_table(self, old_table_name, new_table_name, schema=None):
        """Emit a table rename operation."""
        self._exec(
            base.RenameTable(old_table_name, new_table_name, schema=schema)
        )

    def create_table(self, table):
        """Emit CREATE TABLE plus the table's indexes and, when the
        dialect supports standalone comment DDL, its table and column
        comments.

        before_create/after_create dispatch hooks are fired around the
        CREATE so event listeners (e.g. for sequences) still run.
        """
        table.dispatch.before_create(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
        self._exec(schema.CreateTable(table))
        table.dispatch.after_create(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
        for index in table.indexes:
            self._exec(schema.CreateIndex(index))

        # comments are emitted separately only if the dialect supports
        # comment DDL and does not render comments inline in CREATE
        with_comment = (
            sqla_compat._dialect_supports_comments(self.dialect)
            and not self.dialect.inline_comments
        )
        comment = sqla_compat._comment_attribute(table)
        if comment and with_comment:
            self.create_table_comment(table)

        for column in table.columns:
            comment = sqla_compat._comment_attribute(column)
            if comment and with_comment:
                self.create_column_comment(column)

    def drop_table(self, table):
        """Emit DROP TABLE."""
        self._exec(schema.DropTable(table))

    def create_index(self, index):
        """Emit CREATE INDEX."""
        self._exec(schema.CreateIndex(index))

    def create_table_comment(self, table):
        """Emit DDL setting the table-level comment."""
        self._exec(schema.SetTableComment(table))

    def drop_table_comment(self, table):
        """Emit DDL removing the table-level comment."""
        self._exec(schema.DropTableComment(table))

    def create_column_comment(self, column):
        """Emit DDL setting a column-level comment."""
        self._exec(schema.SetColumnComment(column))

    def drop_index(self, index):
        """Emit DROP INDEX."""
        self._exec(schema.DropIndex(index))

    def bulk_insert(self, table, rows, multiinsert=True):
        """Insert *rows* (a list of dicts) into *table*.

        In offline mode each row is rendered as its own INSERT with
        values wrapped as literal bind parameters so they can be
        rendered inline.  Online, rows are either passed as
        multiparams in a single execute (``multiinsert=True``) or
        executed one statement per row.
        """
        if not isinstance(rows, list):
            raise TypeError("List expected")
        elif rows and not isinstance(rows[0], dict):
            raise TypeError("List of dictionaries expected")
        if self.as_sql:
            for row in rows:
                self._exec(
                    table.insert(inline=True).values(
                        **dict(
                            (
                                k,
                                sqla_compat._literal_bindparam(
                                    k, v, type_=table.c[k].type
                                )
                                if not isinstance(
                                    v, sqla_compat._literal_bindparam
                                )
                                else v,
                            )
                            for k, v in row.items()
                        )
                    )
                )
        else:
            # work around http://www.sqlalchemy.org/trac/ticket/2461
            if not hasattr(table, "_autoincrement_column"):
                table._autoincrement_column = None
            if rows:
                if multiinsert:
                    self._exec(table.insert(inline=True), multiparams=rows)
                else:
                    for row in rows:
                        self._exec(table.insert(inline=True).values(**row))

    def _tokenize_column_type(self, column):
        """Render *column*'s type via the dialect's type compiler and
        split the lowercased string into a :class:`.Params` tuple of
        base token, extra tokens, and parenthesized args/kwargs."""
        definition = self.dialect.type_compiler.process(column.type).lower()

        # tokenize the SQLAlchemy-generated version of a type, so that
        # the two can be compared.
        #
        # examples:
        # NUMERIC(10, 5)
        # TIMESTAMP WITH TIMEZONE
        # INTEGER UNSIGNED
        # INTEGER (10) UNSIGNED
        # INTEGER(10) UNSIGNED
        # varchar character set utf8
        #

        tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)

        term_tokens = []
        paren_term = None

        # separate the single parenthesized argument group, if any,
        # from the word-like terms
        for token in tokens:
            if re.match(r"^\(.*\)$", token):
                paren_term = token
            else:
                term_tokens.append(token)

        params = Params(term_tokens[0], term_tokens[1:], [], {})

        if paren_term:
            # split "(10, 5)" or "(fill=zero)" into individual terms;
            # key=value terms go to kwargs, the rest to args
            for term in re.findall("[^(),]+", paren_term):
                if "=" in term:
                    key, val = term.split("=")
                    params.kwargs[key.strip()] = val.strip()
                else:
                    params.args.append(term.strip())

        return params

    def _column_types_match(self, inspector_params, metadata_params):
        """Return True if two tokenized types name the same type,
        either exactly (same base token) or via an entry in
        :attr:`.type_synonyms` matching their full term strings or
        base tokens."""
        if inspector_params.token0 == metadata_params.token0:
            return True

        synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms]
        inspector_all_terms = " ".join(
            [inspector_params.token0] + inspector_params.tokens
        )
        metadata_all_terms = " ".join(
            [metadata_params.token0] + metadata_params.tokens
        )

        for batch in synonyms:
            if {inspector_all_terms, metadata_all_terms}.issubset(batch) or {
                inspector_params.token0,
                metadata_params.token0,
            }.issubset(batch):
                return True
        return False

    def _column_args_match(self, inspected_params, meta_params):
        """We want to compare column parameters. However, we only want
        to compare parameters that are set. If they both have `collation`,
        we want to make sure they are the same. However, if only one
        specifies it, dont flag it for being less specific
        """

        # tokens/args only count as a mismatch when both sides specify
        # the same number of them but the values differ
        if (
            len(meta_params.tokens) == len(inspected_params.tokens)
            and meta_params.tokens != inspected_params.tokens
        ):
            return False

        if (
            len(meta_params.args) == len(inspected_params.args)
            and meta_params.args != inspected_params.args
        ):
            return False

        insp = " ".join(inspected_params.tokens).lower()
        meta = " ".join(meta_params.tokens).lower()

        # dialect-supplied regexes extract specific values (group 1)
        # that must agree when present on both sides
        for reg in self.type_arg_extract:
            mi = re.search(reg, insp)
            mm = re.search(reg, meta)

            if mi and mm and mi.group(1) != mm.group(1):
                return False

        return True

    def compare_type(self, inspector_column, metadata_column):
        """Returns True if there ARE differences between the types of the two
        columns. Takes impl.type_synonyms into account between retrospected
        and metadata types
        """
        inspector_params = self._tokenize_column_type(inspector_column)
        metadata_params = self._tokenize_column_type(metadata_column)

        if not self._column_types_match(inspector_params, metadata_params,):
            return True
        if not self._column_args_match(inspector_params, metadata_params):
            return True
        return False

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        """Return True if the rendered server defaults differ.

        The default implementation is a plain string comparison of the
        two rendered values; dialects override for normalization.
        """
        return rendered_inspector_default != rendered_metadata_default

    def correct_for_autogen_constraints(
        self,
        conn_uniques,
        conn_indexes,
        metadata_unique_constraints,
        metadata_indexes,
    ):
        """Hook allowing dialects to adjust reflected vs. metadata
        constraint/index collections before autogenerate comparison.
        No-op by default."""
        pass

    def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
        """During batch migration, wrap the column-transfer expression
        in a CAST when the old and new types differ in type affinity."""
        if existing.type._type_affinity is not new_type._type_affinity:
            existing_transfer["expr"] = cast(
                existing_transfer["expr"], new_type
            )

    def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
        """Render a SQL expression that is typically a server default,
        index expression, etc.

        .. versionadded:: 1.0.11

        """

        # literal_binds renders bound values inline; include_table=False
        # keeps column references unqualified
        compile_kw = dict(
            compile_kwargs={"literal_binds": True, "include_table": False}
        )
        return text_type(expr.compile(dialect=self.dialect, **compile_kw))

    def _compat_autogen_column_reflect(self, inspector):
        """Return the column_reflect hook; indirection point presumably
        kept for older SQLAlchemy compatibility — TODO confirm."""
        return self.autogen_column_reflect

    def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
        """Hook allowing dialects to adjust reflected vs. metadata
        foreign-key collections before comparison.  No-op by default."""
        pass

    def autogen_column_reflect(self, inspector, table, column_info):
        """A hook that is attached to the 'column_reflect' event for when
        a Table is reflected from the database during the autogenerate
        process.

        Dialects can elect to modify the information gathered here.

        """

    def start_migrations(self):
        """A hook called when :meth:`.EnvironmentContext.run_migrations`
        is called.

        Implementations can set up per-migration-run state here.

        """

    def emit_begin(self):
        """Emit the string ``BEGIN``, or the backend-specific
        equivalent, on the current connection context.

        This is used in offline mode and typically
        via :meth:`.EnvironmentContext.begin_transaction`.

        """
        self.static_output("BEGIN" + self.command_terminator)

    def emit_commit(self):
        """Emit the string ``COMMIT``, or the backend-specific
        equivalent, on the current connection context.

        This is used in offline mode and typically
        via :meth:`.EnvironmentContext.begin_transaction`.

        """
        self.static_output("COMMIT" + self.command_terminator)

    def render_type(self, type_obj, autogen_context):
        """Hook for dialects to render a type in autogenerate output;
        returning False means "no custom rendering"."""
        return False
517