1# sql/base.py
2# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Foundational utilities common to many sql modules.
9
10"""
11
12
13import itertools
14import re
15
16from .visitors import ClauseVisitor
17from .. import exc
18from .. import util
19
# Named sentinel symbols created through util.symbol(); their meaning is
# defined by the code that compares against them elsewhere in the package.
PARSE_AUTOCOMMIT, NO_ARG = (
    util.symbol("PARSE_AUTOCOMMIT"),
    util.symbol("NO_ARG"),
)
22
23
class Immutable(object):
    """mark a ClauseElement as 'immutable' when expressions are cloned.

    Cloning an immutable element hands back the very same object, and the
    parameter-rewriting copy operations are refused outright.
    """

    def _clone(self):
        """Return ``self``; immutable elements are shared, never copied."""
        return self

    def params(self, *optionaldict, **kwargs):
        """Refuse to produce a parameter-substituted copy."""
        raise NotImplementedError("Immutable objects do not support copying")

    def unique_params(self, *optionaldict, **kwargs):
        """Refuse to produce a uniquely-parameterized copy."""
        raise NotImplementedError("Immutable objects do not support copying")
35
36
37def _from_objects(*elements):
38    return itertools.chain(*[element._from_objects for element in elements])
39
40
@util.decorator
def _generative(fn, *args, **kw):
    """Mark a method as generative.

    The wrapped method runs against a copy of the instance obtained from
    ``_generate()``.  That copy is returned; whatever the method itself
    returns is discarded.
    """
    clone = args[0]._generate()
    fn(clone, *args[1:], **kw)
    return clone
48
49
class _DialectArgView(util.collections_abc.MutableMapping):
    """A dictionary view of dialect-level arguments in the form
    <dialectname>_<argument_name>.

    Reads, writes and deletes are routed into the two-level
    ``dialect_options`` registry of the wrapped object.  Length and
    iteration only cover the user-specified (non-default) values.

    """

    def __init__(self, obj):
        self.obj = obj

    def _key(self, key):
        """Split ``key`` at its first underscore into
        ``(dialect, value_key)``.

        :raises KeyError: if ``key`` contains no underscore.
        """
        try:
            dialect, value_key = key.split("_", 1)
        except ValueError:
            raise KeyError(key)
        else:
            return dialect, value_key

    def __getitem__(self, key):
        dialect, value_key = self._key(key)

        try:
            opt = self.obj.dialect_options[dialect]
        except exc.NoSuchModuleError:
            # an unknown dialect simply means "key not present"
            raise KeyError(key)
        else:
            return opt[value_key]

    def __setitem__(self, key, value):
        try:
            dialect, value_key = self._key(key)
        except KeyError:
            raise exc.ArgumentError(
                "Keys must be of the form <dialectname>_<argname>"
            )
        else:
            self.obj.dialect_options[dialect][value_key] = value

    def __delitem__(self, key):
        dialect, value_key = self._key(key)

        try:
            opt = self.obj.dialect_options[dialect]
        except exc.NoSuchModuleError:
            # mirror __getitem__: the MutableMapping contract (and helpers
            # such as pop()/popitem()) expect KeyError for a missing key,
            # not a dialect-loading error
            raise KeyError(key)
        else:
            del opt[value_key]

    def __len__(self):
        return sum(
            len(args._non_defaults)
            for args in self.obj.dialect_options.values()
        )

    def __iter__(self):
        return (
            util.safe_kwarg("%s_%s" % (dialect_name, value_name))
            for dialect_name in self.obj.dialect_options
            for value_name in self.obj.dialect_options[
                dialect_name
            ]._non_defaults
        )
105
106
class _DialectArgDict(util.collections_abc.MutableMapping):
    """Mapping of dialect-level arguments for one specific dialect.

    User-specified values are kept in ``_non_defaults`` and shadow the
    dialect-declared values kept in ``_defaults``.  Assignment and deletion
    only ever touch the user-specified side.

    """

    def __init__(self):
        self._non_defaults = {}
        self._defaults = {}

    def _merged_keys(self):
        """Return the set of keys present on either side."""
        return set(self._non_defaults).union(self._defaults)

    def __len__(self):
        return len(self._merged_keys())

    def __iter__(self):
        return iter(self._merged_keys())

    def __getitem__(self, key):
        # user-specified values win; otherwise fall back to the defaults,
        # which raise KeyError for an unknown key
        source = (
            self._non_defaults if key in self._non_defaults else self._defaults
        )
        return source[key]

    def __setitem__(self, key, value):
        self._non_defaults[key] = value

    def __delitem__(self, key):
        del self._non_defaults[key]
137
138
class DialectKWArgs(object):
    """Establish the ability for a class to have dialect-specific arguments
    with defaults and constructor validation.

    The :class:`.DialectKWArgs` interacts with the
    :attr:`.DefaultDialect.construct_arguments` present on a dialect.

    .. seealso::

        :attr:`.DefaultDialect.construct_arguments`

    """

    @classmethod
    def argument_for(cls, dialect_name, argument_name, default):
        """Add a new kind of dialect-specific keyword argument for this class.

        E.g.::

            Index.argument_for("mydialect", "length", None)

            some_index = Index('a', 'b', mydialect_length=5)

        The :meth:`.DialectKWArgs.argument_for` method is a per-argument
        way of adding extra arguments to the
        :attr:`.DefaultDialect.construct_arguments` dictionary. This
        dictionary provides a list of argument names accepted by various
        schema-level constructs on behalf of a dialect.

        New dialects should typically specify this dictionary all at once as a
        data member of the dialect class.  The use case for ad-hoc addition of
        argument names is typically for end-user code that is also using
        a custom compilation scheme which consumes the additional arguments.

        :param dialect_name: name of a dialect.  The dialect must be
         locatable, else a :class:`.NoSuchModuleError` is raised.   The
         dialect must also include an existing
         :attr:`.DefaultDialect.construct_arguments` collection, indicating
         that it participates in the keyword-argument validation and default
         system, else :class:`.ArgumentError` is raised.  If the dialect does
         not include this collection, then any keyword argument can be
         specified on behalf of this dialect already.  All dialects packaged
         within SQLAlchemy include this collection, however for third party
         dialects, support may vary.

        :param argument_name: name of the parameter.

        :param default: default value of the parameter.

        .. versionadded:: 0.9.4

        """

        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        if construct_arg_dictionary is None:
            # the dialect's construct_arguments is None, i.e. it does NOT
            # take part in validation, so there is no registry to extend
            raise exc.ArgumentError(
                "Dialect '%s' does not have keyword-argument "
                "validation and defaults enabled" % dialect_name
            )
        if cls not in construct_arg_dictionary:
            construct_arg_dictionary[cls] = {}
        construct_arg_dictionary[cls][argument_name] = default

    @util.memoized_property
    def dialect_kwargs(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        The arguments are present here in their original ``<dialect>_<kwarg>``
        format.  Only arguments that were actually passed are included;
        unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
        contains all options known by this dialect including defaults.

        The collection is also writable; keys are accepted of the
        form ``<dialect>_<kwarg>`` where the value will be assembled
        into the list of options.

        .. versionadded:: 0.9.2

        .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
           collection is now writable.

        .. seealso::

            :attr:`.DialectKWArgs.dialect_options` - nested dictionary form

        """
        return _DialectArgView(self)

    @property
    def kwargs(self):
        """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
        return self.dialect_kwargs

    # Class-body helper (it takes no ``self``): used only to build
    # ``_kw_registry`` below.  Loads the dialect class and returns a shallow
    # dict copy of its ``construct_arguments`` (keyed by construct class),
    # or None when the dialect has no such collection.  Because the copy is
    # shallow, argument_for() mutating an existing per-class dict also
    # mutates the dialect's own collection.
    @util.dependencies("sqlalchemy.dialects")
    def _kw_reg_for_dialect(dialects, dialect_name):
        dialect_cls = dialects.registry.load(dialect_name)
        if dialect_cls.construct_arguments is None:
            return None
        return dict(dialect_cls.construct_arguments)

    # lazily populated, process-wide: dialect name -> construct-argument dict
    _kw_registry = util.PopulateDict(_kw_reg_for_dialect)

    def _kw_reg_for_dialect_cls(self, dialect_name):
        """Build the :class:`._DialectArgDict` of defaults for this object's
        class under ``dialect_name``."""
        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        d = _DialectArgDict()

        if construct_arg_dictionary is None:
            # no validation for this dialect: the "*" wildcard accepts
            # any argument name
            d._defaults.update({"*": None})
        else:
            # walk base classes first so subclass defaults override them
            for cls in reversed(self.__class__.__mro__):
                if cls in construct_arg_dictionary:
                    d._defaults.update(construct_arg_dictionary[cls])
        return d

    @util.memoized_property
    def dialect_options(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        This is a two-level nested registry, keyed to ``<dialect_name>``
        and ``<argument_name>``.  For example, the ``postgresql_where``
        argument would be locatable as::

            arg = my_object.dialect_options['postgresql']['where']

        .. versionadded:: 0.9.2

        .. seealso::

            :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form

        """

        return util.PopulateDict(
            util.portable_instancemethod(self._kw_reg_for_dialect_cls)
        )

    def _validate_dialect_kwargs(self, kwargs):
        """Validate and store ``<dialect>_<arg>`` keyword arguments.

        :raises TypeError: a key is not of the form ``<dialect>_<arg>``.
        :raises ArgumentError: the dialect validates arguments and does not
         accept this one for this class.

        Unknown dialects only produce a warning; their arguments are stored
        under a wildcard ("*") entry.
        """
        # validate remaining kwargs that they all specify DB prefixes

        if not kwargs:
            return

        for k in kwargs:
            m = re.match("^(.+?)_(.+)$", k)
            if not m:
                raise TypeError(
                    "Additional arguments should be "
                    "named <dialectname>_<argument>, got '%s'" % k
                )
            dialect_name, arg_name = m.group(1, 2)

            try:
                construct_arg_dictionary = self.dialect_options[dialect_name]
            except exc.NoSuchModuleError:
                util.warn(
                    "Can't validate argument %r; can't "
                    "locate any SQLAlchemy dialect named %r"
                    % (k, dialect_name)
                )
                self.dialect_options[dialect_name] = d = _DialectArgDict()
                d._defaults.update({"*": None})
                d._non_defaults[arg_name] = kwargs[k]
            else:
                if (
                    "*" not in construct_arg_dictionary
                    and arg_name not in construct_arg_dictionary
                ):
                    raise exc.ArgumentError(
                        "Argument %r is not accepted by "
                        "dialect %r on behalf of %r"
                        % (k, dialect_name, self.__class__)
                    )
                else:
                    construct_arg_dictionary[arg_name] = kwargs[k]
315
316
class Generative(object):
    """Allow a ClauseElement to generate itself via the
    @_generative decorator.

    """

    def _generate(self):
        """Return a shallow copy of this object.

        ``__init__`` is not called.  The copy gets its own ``__dict__``,
        whose values are shared with the original.
        """
        klass = self.__class__
        clone = klass.__new__(klass)
        clone.__dict__ = dict(self.__dict__)
        return clone
327
328
class Executable(Generative):
    """Mark a ClauseElement as supporting execution.

    :class:`.Executable` is a superclass for all "statement" types
    of objects, including :func:`select`, :func:`delete`, :func:`update`,
    :func:`insert`, :func:`text`.

    """

    supports_execution = True
    _execution_options = util.immutabledict()
    _bind = None

    @_generative
    def execution_options(self, **kw):
        """Set non-SQL options for the statement which take effect during
        execution.

        Execution options can be set on a per-statement or
        per :class:`.Connection` basis.   Additionally, the
        :class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide
        access to execution options which they in turn configure upon
        connections.

        The :meth:`execution_options` method is generative.  A new
        instance of this statement is returned that contains the options::

            statement = select([table.c.x, table.c.y])
            statement = statement.execution_options(autocommit=True)

        Note that only a subset of possible execution options can be applied
        to a statement - these include "autocommit" and "stream_results",
        but not "isolation_level" or "compiled_cache"; passing either of
        those raises :class:`.ArgumentError`.
        See :meth:`.Connection.execution_options` for a full list of
        possible options.

        .. seealso::

            :meth:`.Connection.execution_options()`

            :meth:`.Query.execution_options()`

        """
        # options that are only meaningful at the connection / engine level,
        # checked in this order
        for forbidden, complaint in (
            (
                "isolation_level",
                "'isolation_level' execution option may only be specified "
                "on Connection.execution_options(), or "
                "per-engine using the isolation_level "
                "argument to create_engine().",
            ),
            (
                "compiled_cache",
                "'compiled_cache' execution option may only be specified "
                "on Connection.execution_options(), not per statement.",
            ),
        ):
            if forbidden in kw:
                raise exc.ArgumentError(complaint)
        self._execution_options = self._execution_options.union(kw)

    def execute(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`."""
        connectable = self.bind
        if connectable is not None:
            return connectable._execute_clauseelement(
                self, multiparams, params
            )
        label = getattr(self, "description", self.__class__.__name__)
        raise exc.UnboundExecutionError(
            "This %s is not directly bound to a Connection or Engine. "
            "Use the .execute() method of a Connection or Engine "
            "to execute this construct." % label
        )

    def scalar(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`, returning the
        result's scalar representation.

        """
        result = self.execute(*multiparams, **params)
        return result.scalar()

    @property
    def bind(self):
        """Returns the :class:`.Engine` or :class:`.Connection` to
        which this :class:`.Executable` is bound, or None if none found.

        The local ``_bind`` is consulted first; failing that, the "from"
        objects of this element (skipping the element itself) are asked in
        turn and the first non-None ``bind`` wins.

        """
        if self._bind is not None:
            return self._bind

        for from_obj in _from_objects(self):
            if from_obj is self:
                continue
            connectable = from_obj.bind
            if connectable is not None:
                return connectable
        return None
427
428
class SchemaEventTarget(object):
    """Base class for elements that are the targets of :class:`.DDLEvents`
    events.

    This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.

    """

    def _set_parent(self, parent):
        """Associate with this SchemaEvent's parent object.

        The base implementation does nothing; subclasses override it.
        """
        return None

    def _set_parent_with_dispatch(self, parent):
        """Run ``_set_parent`` bracketed by the ``before_parent_attach`` and
        ``after_parent_attach`` dispatch hooks."""
        self.dispatch.before_parent_attach(self, parent)
        self._set_parent(parent)
        self.dispatch.after_parent_attach(self, parent)
444
445
class SchemaVisitor(ClauseVisitor):
    """Define the visiting for ``SchemaItem`` objects.

    Identical to :class:`.ClauseVisitor` except that the ``schema_visitor``
    traversal option is switched on.
    """

    __traverse_options__ = dict(schema_visitor=True)
450
451
class ColumnCollection(util.OrderedProperties):
    """An ordered dictionary that stores a list of ColumnElement
    instances.

    Overrides the ``__eq__()`` method to produce SQL clauses between
    sets of correlated columns.

    Two structures are kept: ``_data`` (inherited) maps keys to columns, and
    ``_all_columns`` lists columns in insertion order.  ``__setitem__``
    appends to ``_all_columns`` even when it replaces an existing key, so the
    list may hold columns that are no longer reachable by key.

    """

    __slots__ = "_all_columns"

    def __init__(self, *columns):
        super(ColumnCollection, self).__init__()
        # __setattr__ is disabled on this class, hence object.__setattr__
        object.__setattr__(self, "_all_columns", [])
        for c in columns:
            self.add(c)

    def __str__(self):
        return repr([str(c) for c in self])

    def replace(self, column):
        """add the given column to this collection, removing unaliased
           versions of this column  as well as existing columns with the
           same key.

            e.g.::

                t = Table('sometable', metadata, Column('col1', Integer))
                t.columns.replace(Column('col1', Integer, key='columnone'))

            will remove the original 'col1' from the collection, and add
            the new column under the name 'columnone'.

           Used by schema.Column to override columns during table reflection.

        """
        # an unaliased column (name == key) stored under the new column's
        # name, which the new, differently-keyed column supersedes
        unaliased = None
        if column.name in self and column.key != column.name:
            other = self[column.name]
            if other.name == other.key:
                unaliased = other
                del self._data[other.key]

        # a column already stored under the new column's key
        same_key = None
        if column.key in self._data:
            same_key = self._data[column.key]

        self._data[column.key] = column

        # the new column takes over the same-key column's position when
        # there is one, otherwise the unaliased column's position
        remove_col = same_key if same_key is not None else unaliased

        if remove_col is None:
            self._all_columns.append(column)
            return

        # if BOTH kinds of column were displaced, the unaliased one has
        # already left _data and must also leave _all_columns; previously it
        # was left behind, so contains_column()/__eq__ still reported it
        if unaliased is not None and unaliased is not remove_col:
            stale = unaliased
        else:
            stale = None

        self._all_columns[:] = [
            column if c is remove_col else c
            for c in self._all_columns
            if stale is None or c is not stale
        ]

    def add(self, column):
        """Add a column to this collection.

        The key attribute of the column will be used as the hash key
        for this dictionary.

        """
        if not column.key:
            raise exc.ArgumentError(
                "Can't add unnamed column to column collection"
            )
        self[column.key] = column

    def __delitem__(self, key):
        raise NotImplementedError()

    def __setattr__(self, key, obj):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        if key in self:

            # this warning is primarily to catch select() statements
            # which have conflicting column names in their exported
            # columns collection

            existing = self[key]

            if existing is value:
                return

            if not existing.shares_lineage(value):
                util.warn(
                    "Column %r on table %r being replaced by "
                    "%r, which has the same key.  Consider "
                    "use_labels for select() statements."
                    % (key, getattr(existing, "table", None), value)
                )

            # pop out memoized proxy_set as this
            # operation may very well be occurring
            # in a _make_proxy operation
            util.memoized_property.reset(value, "proxy_set")

        # note: a replaced column deliberately stays in _all_columns
        self._all_columns.append(value)
        self._data[key] = value

    def clear(self):
        raise NotImplementedError()

    def remove(self, column):
        del self._data[column.key]
        self._all_columns[:] = [
            c for c in self._all_columns if c is not column
        ]

    def update(self, iter_):
        """Add ``(label, column)`` pairs; columns already present in
        ``_all_columns`` are not appended again."""
        cols = list(iter_)
        all_col_set = set(self._all_columns)
        self._all_columns.extend(
            c for label, c in cols if c not in all_col_set
        )
        self._data.update((label, c) for label, c in cols)

    def extend(self, iter_):
        """Add columns keyed by their ``.key``; columns already present in
        ``_all_columns`` are not appended again."""
        cols = list(iter_)
        all_col_set = set(self._all_columns)
        self._all_columns.extend(c for c in cols if c not in all_col_set)
        self._data.update((c.key, c) for c in cols)

    # __eq__ builds SQL expressions, so the collection is unhashable
    __hash__ = None

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        l = []
        for c in getattr(other, "_all_columns", other):
            for local in self._all_columns:
                if c.shares_lineage(local):
                    l.append(c == local)
        return elements.and_(*l)

    def __contains__(self, other):
        if not isinstance(other, util.string_types):
            raise exc.ArgumentError("__contains__ requires a string argument")
        return util.OrderedProperties.__contains__(self, other)

    def __getstate__(self):
        return {"_data": self._data, "_all_columns": self._all_columns}

    def __setstate__(self, state):
        object.__setattr__(self, "_data", state["_data"])
        object.__setattr__(self, "_all_columns", state["_all_columns"])

    def contains_column(self, col):
        # set membership (hash-based) rather than list membership, which
        # would invoke the SQL-producing __eq__ of column objects
        return col in set(self._all_columns)

    def as_immutable(self):
        # the immutable view shares (does not copy) _data and _all_columns
        return ImmutableColumnCollection(self._data, self._all_columns)
604
605
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
    """Read-only view over a :class:`.ColumnCollection`.

    The view shares ``data`` and ``all_columns`` with the collection it was
    created from, so every mutating entry point must be blocked here or it
    would write through to the underlying collection.
    """

    def __init__(self, data, all_columns):
        util.ImmutableProperties.__init__(self, data)
        object.__setattr__(self, "_all_columns", all_columns)

    # ColumnCollection.update() and .replace() write to _data/_all_columns
    # directly (bypassing the blocked __setitem__), so they are blocked too
    extend = remove = update = replace = util.ImmutableProperties._immutable
612
613
class ColumnSet(util.ordered_column_set):
    """Ordered set of columns whose ``==`` yields a SQL conjunction."""

    def contains_column(self, col):
        return col in self

    def extend(self, cols):
        for column in cols:
            self.add(column)

    def __add__(self, other):
        # concatenation yields a plain list, not a ColumnSet
        combined = list(self)
        combined.extend(other)
        return combined

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        # one "c == local" clause per pair of columns sharing lineage,
        # iterating other's columns in the outer loop
        clauses = [
            theirs == ours
            for theirs in other
            for ours in self
            if theirs.shares_lineage(ours)
        ]
        return elements.and_(*clauses)

    def __hash__(self):
        return hash(tuple(self))
636
637
638def _bind_or_error(schemaitem, msg=None):
639    bind = schemaitem.bind
640    if not bind:
641        name = schemaitem.__class__.__name__
642        label = getattr(
643            schemaitem, "fullname", getattr(schemaitem, "name", None)
644        )
645        if label:
646            item = "%s object %r" % (name, label)
647        else:
648            item = "%s object" % name
649        if msg is None:
650            msg = (
651                "%s is not bound to an Engine or Connection.  "
652                "Execution can not proceed without a database to execute "
653                "against." % item
654            )
655        raise exc.UnboundExecutionError(msg)
656    return bind
657