1# sql/base.py
2# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
"""Foundational utilities common to many sql modules.

Contents include the ``Immutable``, ``Generative`` and ``Executable``
mixins, dialect-specific keyword-argument handling (``DialectKWArgs``),
and the ``ColumnCollection`` / ``ColumnSet`` column containers.

"""
11
12
import collections
import itertools
import re

from .. import util, exc
from .visitors import ClauseVisitor

try:
    # The container ABCs live in collections.abc on Python 3.3+; the
    # aliases on the top-level ``collections`` module were removed in 3.10.
    import collections.abc as collections_abc
except ImportError:  # Python 2
    collections_abc = collections
18
# Module-level sentinel symbols created with util.symbol().
# NOTE(review): PARSE_AUTOCOMMIT presumably means "decide autocommit by
# inspecting the statement" and NO_ARG an "argument was not passed" marker;
# their consumers live outside this module -- confirm there.
PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT')
NO_ARG = util.symbol('NO_ARG')
21
22
class Immutable(object):
    """Mixin marking a ClauseElement as 'immutable' when expressions
    are cloned.

    Cloning hands back the very same instance, and the parameter
    substitution methods refuse to run.
    """

    def _clone(self):
        """Return ``self``: an immutable element is its own clone."""
        return self

    def params(self, *optionaldict, **kwargs):
        """Unsupported on immutable elements; always raises
        NotImplementedError."""
        raise NotImplementedError("Immutable objects do not support copying")

    def unique_params(self, *optionaldict, **kwargs):
        """Unsupported on immutable elements; always raises
        NotImplementedError."""
        raise NotImplementedError("Immutable objects do not support copying")
34
35
def _from_objects(*elements):
    """Chain together the ``_from_objects`` iterables of every element.

    The ``_from_objects`` attribute of each element is read immediately;
    only the chaining of their contents is lazy.
    """
    collected = [el._from_objects for el in elements]
    return itertools.chain.from_iterable(collected)
38
39
@util.decorator
def _generative(fn, *args, **kw):
    """Mark a method as generative.

    The wrapped method is invoked on a copy produced by ``_generate()``
    rather than on the original instance; that copy is returned to the
    caller and the method's own return value is discarded.
    """
    instance, remaining = args[0], args[1:]
    duplicate = instance._generate()
    fn(duplicate, *remaining, **kw)
    return duplicate
47
48
class _DialectArgView(collections_abc.MutableMapping):
    """A dictionary view of dialect-level arguments in the form
    <dialectname>_<argument_name>.

    Reads, writes and deletes are routed to
    ``obj.dialect_options[<dialectname>][<argument_name>]``.  Iteration and
    ``len()`` only cover each per-dialect collection's ``_non_defaults``,
    i.e. arguments that were set explicitly.

    """

    def __init__(self, obj):
        # obj is expected to expose a ``dialect_options`` mapping keyed by
        # dialect name (see DialectKWArgs.dialect_options).
        self.obj = obj

    def _key(self, key):
        """Split ``"<dialect>_<arg>"`` at the first underscore.

        :raises KeyError: if ``key`` contains no underscore.
        """
        try:
            dialect, value_key = key.split("_", 1)
        except ValueError:
            raise KeyError(key)
        else:
            return dialect, value_key

    def __getitem__(self, key):
        dialect, value_key = self._key(key)

        try:
            opt = self.obj.dialect_options[dialect]
        except exc.NoSuchModuleError:
            # an unknown dialect simply means "no such key" for a mapping
            raise KeyError(key)
        else:
            return opt[value_key]

    def __setitem__(self, key, value):
        try:
            dialect, value_key = self._key(key)
        except KeyError:
            raise exc.ArgumentError(
                "Keys must be of the form <dialectname>_<argname>")
        else:
            self.obj.dialect_options[dialect][value_key] = value

    def __delitem__(self, key):
        dialect, value_key = self._key(key)

        try:
            opt = self.obj.dialect_options[dialect]
        except exc.NoSuchModuleError:
            # mirror __getitem__: deleting a key of an unknown dialect is a
            # missing key, which the MutableMapping contract reports as
            # KeyError (e.g. for pop()/popitem() callers)
            raise KeyError(key)
        else:
            del opt[value_key]

    def __len__(self):
        return sum(len(args._non_defaults) for args in
                   self.obj.dialect_options.values())

    def __iter__(self):
        return (
            util.safe_kwarg("%s_%s" % (dialect_name, value_name))
            for dialect_name in self.obj.dialect_options
            for value_name in
            self.obj.dialect_options[dialect_name]._non_defaults
        )
100
101
class _DialectArgDict(collections_abc.MutableMapping):
    """A dictionary view of dialect-level arguments for a specific
    dialect.

    Maintains a separate collection of user-specified arguments
    (``_non_defaults``) and dialect-specified default arguments
    (``_defaults``).  Lookups prefer the user-specified value; writes and
    deletes only ever touch ``_non_defaults``, so a key that exists only
    as a default cannot be deleted (``KeyError``).

    """

    def __init__(self):
        self._non_defaults = {}
        self._defaults = {}

    def __len__(self):
        # a key present in both collections counts once
        return len(set(self._non_defaults).union(self._defaults))

    def __iter__(self):
        return iter(set(self._non_defaults).union(self._defaults))

    def __getitem__(self, key):
        if key in self._non_defaults:
            return self._non_defaults[key]
        else:
            return self._defaults[key]

    def __setitem__(self, key, value):
        self._non_defaults[key] = value

    def __delitem__(self, key):
        del self._non_defaults[key]
132
133
class DialectKWArgs(object):
    """Establish the ability for a class to have dialect-specific arguments
    with defaults and constructor validation.

    The :class:`.DialectKWArgs` interacts with the
    :attr:`.DefaultDialect.construct_arguments` present on a dialect.

    .. seealso::

        :attr:`.DefaultDialect.construct_arguments`

    """

    @classmethod
    def argument_for(cls, dialect_name, argument_name, default):
        """Add a new kind of dialect-specific keyword argument for this class.

        E.g.::

            Index.argument_for("mydialect", "length", None)

            some_index = Index('a', 'b', mydialect_length=5)

        The :meth:`.DialectKWArgs.argument_for` method is a per-argument
        way of adding extra arguments to the
        :attr:`.DefaultDialect.construct_arguments` dictionary. This
        dictionary provides a list of argument names accepted by various
        schema-level constructs on behalf of a dialect.

        New dialects should typically specify this dictionary all at once as a
        data member of the dialect class.  The use case for ad-hoc addition of
        argument names is typically for end-user code that is also using
        a custom compilation scheme which consumes the additional arguments.

        :param dialect_name: name of a dialect.  The dialect must be
         locatable, else a :class:`.NoSuchModuleError` is raised.   The
         dialect must also include an existing
         :attr:`.DefaultDialect.construct_arguments` collection, indicating
         that it participates in the keyword-argument validation and default
         system, else :class:`.ArgumentError` is raised.  If the dialect does
         not include this collection, then any keyword argument can be
         specified on behalf of this dialect already.  All dialects packaged
         within SQLAlchemy include this collection, however for third party
         dialects, support may vary.

        :param argument_name: name of the parameter.

        :param default: default value of the parameter.

        .. versionadded:: 0.9.4

        """

        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        if construct_arg_dictionary is None:
            # the dialect opted out of validation (construct_arguments is
            # None), so there is no registry to add the argument to
            raise exc.ArgumentError(
                "Dialect '%s' does not have keyword-argument "
                "validation and defaults enabled" %
                dialect_name)
        if cls not in construct_arg_dictionary:
            construct_arg_dictionary[cls] = {}
        construct_arg_dictionary[cls][argument_name] = default

    @util.memoized_property
    def dialect_kwargs(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        The arguments are present here in their original ``<dialect>_<kwarg>``
        format.  Only arguments that were actually passed are included;
        unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
        contains all options known by this dialect including defaults.

        The collection is also writable; keys are accepted of the
        form ``<dialect>_<kwarg>`` where the value will be assembled
        into the list of options.

        .. versionadded:: 0.9.2

        .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
           collection is now writable.

        .. seealso::

            :attr:`.DialectKWArgs.dialect_options` - nested dictionary form

        """
        return _DialectArgView(self)

    @property
    def kwargs(self):
        """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
        return self.dialect_kwargs

    # Plain function (no ``self``) evaluated at class-creation time to build
    # the shared, lazily populated per-dialect registry below.  Returns None
    # when the dialect's construct_arguments is None (validation disabled),
    # otherwise a copy of that mapping (class -> {argname: default}).
    @util.dependencies("sqlalchemy.dialects")
    def _kw_reg_for_dialect(dialects, dialect_name):
        dialect_cls = dialects.registry.load(dialect_name)
        if dialect_cls.construct_arguments is None:
            return None
        return dict(dialect_cls.construct_arguments)
    _kw_registry = util.PopulateDict(_kw_reg_for_dialect)

    def _kw_reg_for_dialect_cls(self, dialect_name):
        """Build the :class:`._DialectArgDict` for one dialect on this object.

        Defaults are merged from least- to most-derived class in the MRO, so
        subclass entries override base-class entries.  A dialect without
        validation gets the wildcard default ``"*"``, which accepts any
        argument name.
        """
        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        d = _DialectArgDict()

        if construct_arg_dictionary is None:
            d._defaults.update({"*": None})
        else:
            for cls in reversed(self.__class__.__mro__):
                if cls in construct_arg_dictionary:
                    d._defaults.update(construct_arg_dictionary[cls])
        return d

    @util.memoized_property
    def dialect_options(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        This is a two-level nested registry, keyed to ``<dialect_name>``
        and ``<argument_name>``.  For example, the ``postgresql_where``
        argument would be locatable as::

            arg = my_object.dialect_options['postgresql']['where']

        .. versionadded:: 0.9.2

        .. seealso::

            :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form

        """

        return util.PopulateDict(
            util.portable_instancemethod(self._kw_reg_for_dialect_cls)
        )

    def _validate_dialect_kwargs(self, kwargs):
        """Validate and store ``<dialect>_<argument>`` keyword arguments.

        :raises TypeError: if a key is not of the form
         ``<dialectname>_<argument>``.
        :raises ArgumentError: if the dialect validates its arguments and
         does not accept the given one.

        A dialect that cannot be located only triggers a warning; its
        arguments are stored under a wildcard collection.
        """
        # validate remaining kwargs that they all specify DB prefixes

        if not kwargs:
            return

        for k in kwargs:
            m = re.match(r'^(.+?)_(.+)$', k)
            if not m:
                raise TypeError(
                    "Additional arguments should be "
                    "named <dialectname>_<argument>, got '%s'" % k)
            dialect_name, arg_name = m.group(1, 2)

            try:
                construct_arg_dictionary = self.dialect_options[dialect_name]
            except exc.NoSuchModuleError:
                util.warn(
                    "Can't validate argument %r; can't "
                    "locate any SQLAlchemy dialect named %r" %
                    (k, dialect_name))
                self.dialect_options[dialect_name] = d = _DialectArgDict()
                d._defaults.update({"*": None})
                d._non_defaults[arg_name] = kwargs[k]
            else:
                if "*" not in construct_arg_dictionary and \
                        arg_name not in construct_arg_dictionary:
                    raise exc.ArgumentError(
                        "Argument %r is not accepted by "
                        "dialect %r on behalf of %r" % (
                            k,
                            dialect_name, self.__class__
                        ))
                else:
                    construct_arg_dictionary[arg_name] = kwargs[k]
306
307
class Generative(object):
    """Mixin that lets a ClauseElement produce modified copies of itself,
    as driven by the ``@_generative`` decorator.

    """

    def _generate(self):
        """Return a shallow copy of ``self`` built without running
        ``__init__``; attribute values are shared with the original."""
        klass = self.__class__
        duplicate = klass.__new__(klass)
        duplicate.__dict__ = dict(self.__dict__)
        return duplicate
318
319
class Executable(Generative):
    """Mark a ClauseElement as supporting execution.

    :class:`.Executable` is a superclass for all "statement" types
    of objects, including :func:`select`, :func:`delete`, :func:`update`,
    :func:`insert`, :func:`text`.

    """

    supports_execution = True
    _execution_options = util.immutabledict()
    _bind = None

    @_generative
    def execution_options(self, **kw):
        """ Set non-SQL options for the statement which take effect during
        execution.

        Execution options can be set on a per-statement or
        per :class:`.Connection` basis.   Additionally, the
        :class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide
        access to execution options which they in turn configure upon
        connections.

        The :meth:`execution_options` method is generative.  A new
        instance of this statement is returned that contains the options::

            statement = select([table.c.x, table.c.y])
            statement = statement.execution_options(autocommit=True)

        Note that only a subset of possible execution options can be applied
        to a statement - these include "autocommit" and "stream_results",
        but not "isolation_level" or "compiled_cache".
        See :meth:`.Connection.execution_options` for a full list of
        possible options.

        .. seealso::

            :meth:`.Connection.execution_options()`

            :meth:`.Query.execution_options()`

        """
        if 'isolation_level' in kw:
            raise exc.ArgumentError(
                "'isolation_level' execution option may only be specified "
                "on Connection.execution_options(), or "
                "per-engine using the isolation_level "
                "argument to create_engine()."
            )
        if 'compiled_cache' in kw:
            raise exc.ArgumentError(
                "'compiled_cache' execution option may only be specified "
                "on Connection.execution_options(), not per statement."
            )
        self._execution_options = self._execution_options.union(kw)

    def execute(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`.

        :raises UnboundExecutionError: if :attr:`bind` resolves to None.
        """
        e = self.bind
        if e is None:
            label = getattr(self, 'description', self.__class__.__name__)
            # note the trailing space: the two sentences are joined by
            # implicit string-literal concatenation
            msg = ('This %s is not directly bound to a Connection or Engine. '
                   'Use the .execute() method of a Connection or Engine '
                   'to execute this construct.' % label)
            raise exc.UnboundExecutionError(msg)
        return e._execute_clauseelement(self, multiparams, params)

    def scalar(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`, returning the
        result's scalar representation.

        """
        return self.execute(*multiparams, **params).scalar()

    @property
    def bind(self):
        """Returns the :class:`.Engine` or :class:`.Connection` to
        which this :class:`.Executable` is bound, or None if none found.

        This is a traversal which checks locally, then
        checks among the "from" clauses of associated objects
        until a bound engine or connection is found.

        """
        if self._bind is not None:
            return self._bind

        for f in _from_objects(self):
            # a construct may list itself among its own FROM objects; skip
            # it to avoid recursing back into this property
            if f is self:
                continue
            engine = f.bind
            if engine is not None:
                return engine
        return None
416
417
class SchemaEventTarget(object):
    """Base class for elements that are the targets of :class:`.DDLEvents`
    events.

    This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.

    """

    def _set_parent(self, parent):
        """Associate with this SchemaEvent's parent object.

        Subclasses must override this; the base version always raises
        NotImplementedError.
        """
        raise NotImplementedError()

    def _set_parent_with_dispatch(self, parent):
        """Attach to ``parent`` via :meth:`_set_parent`, firing the
        ``before_parent_attach`` hook first and ``after_parent_attach``
        once attachment has succeeded."""
        self.dispatch.before_parent_attach(self, parent)
        self._set_parent(parent)
        self.dispatch.after_parent_attach(self, parent)
435
436
class SchemaVisitor(ClauseVisitor):
    """A ClauseVisitor specialised for traversing ``SchemaItem`` objects.

    It differs from its base only in the ``schema_visitor`` traversal
    option it declares.
    """

    __traverse_options__ = dict(schema_visitor=True)
441
442
class ColumnCollection(util.OrderedProperties):
    """An ordered dictionary that stores a list of ColumnElement
    instances.

    Overrides the ``__eq__()`` method to produce SQL clauses between
    sets of correlated columns.

    Alongside the key -> column mapping in ``_data``, a parallel
    ``_all_columns`` list records added columns in insertion order; it is
    what ``__eq__``, ``contains_column`` and pickling use.  Note that
    ``__setitem__`` appends even when it replaces an existing key, so the
    list may retain columns no longer reachable through ``_data``.

    """

    __slots__ = '_all_columns'

    def __init__(self, *columns):
        super(ColumnCollection, self).__init__()
        # plain assignment is blocked by __setattr__ below
        object.__setattr__(self, '_all_columns', [])
        for c in columns:
            self.add(c)

    def __str__(self):
        return repr([str(c) for c in self])

    def replace(self, column):
        """Add the given column to this collection, removing unaliased
        versions of this column as well as existing columns with the
        same key.

        E.g.::

            t = Table('sometable', metadata, Column('col1', Integer))
            t.columns.replace(Column('col1', Integer, key='columnone'))

        will remove the original 'col1' from the collection, and add
        the new column under the key 'columnone'.

        Used by schema.Column to override columns during table reflection.

        """
        # columns evicted from _data, in the order they were found
        removed = []
        if column.name in self and column.key != column.name:
            other = self[column.name]
            if other.name == other.key:
                removed.append(other)
                del self._data[other.key]

        if column.key in self._data:
            removed.append(self._data[column.key])

        self._data[column.key] = column
        if removed:
            # The new column takes the list slot of the last evicted column
            # (the one sharing its key, when there is one).  Any other
            # evicted column is dropped as well so that _all_columns does
            # not keep a column that is no longer present in _data.
            slot = removed[-1]
            new_all_columns = []
            for c in self._all_columns:
                if c is slot:
                    new_all_columns.append(column)
                elif not any(c is r for r in removed):
                    new_all_columns.append(c)
            self._all_columns[:] = new_all_columns
        else:
            self._all_columns.append(column)

    def add(self, column):
        """Add a column to this collection.

        The key attribute of the column will be used as the hash key
        for this dictionary.

        :raises ArgumentError: if the column has no key.
        """
        if not column.key:
            raise exc.ArgumentError(
                "Can't add unnamed column to column collection")
        self[column.key] = column

    def __delitem__(self, key):
        raise NotImplementedError()

    def __setattr__(self, key, value):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        if key in self:

            # this warning is primarily to catch select() statements
            # which have conflicting column names in their exported
            # columns collection

            existing = self[key]
            if not existing.shares_lineage(value):
                util.warn('Column %r on table %r being replaced by '
                          '%r, which has the same key.  Consider '
                          'use_labels for select() statements.' %
                          (key, getattr(existing, 'table', None), value))

            # pop out memoized proxy_set as this
            # operation may very well be occurring
            # in a _make_proxy operation
            util.memoized_property.reset(value, "proxy_set")

        self._all_columns.append(value)
        self._data[key] = value

    def clear(self):
        raise NotImplementedError()

    def remove(self, column):
        del self._data[column.key]
        self._all_columns[:] = [
            c for c in self._all_columns if c is not column]

    def update(self, iter):
        """Add ``(label, column)`` pairs, storing each column under its
        label; columns already in the list are not appended again."""
        cols = list(iter)
        all_col_set = set(self._all_columns)
        self._all_columns.extend(
            c for label, c in cols if c not in all_col_set)
        self._data.update((label, c) for label, c in cols)

    def extend(self, iter):
        """Add columns keyed by their ``.key``; columns already in the list
        are not appended again."""
        cols = list(iter)
        all_col_set = set(self._all_columns)
        self._all_columns.extend(c for c in cols if c not in all_col_set)
        self._data.update((c.key, c) for c in cols)

    __hash__ = None

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        """Produce an AND of ``c == local`` for every pair of columns that
        share lineage between ``other`` and this collection."""
        l = []
        for c in getattr(other, "_all_columns", other):
            for local in self._all_columns:
                if c.shares_lineage(local):
                    l.append(c == local)
        return elements.and_(*l)

    def __contains__(self, other):
        # only string keys are meaningful; column membership is tested
        # with contains_column()
        if not isinstance(other, util.string_types):
            raise exc.ArgumentError("__contains__ requires a string argument")
        return util.OrderedProperties.__contains__(self, other)

    def __getstate__(self):
        return {'_data': self._data,
                '_all_columns': self._all_columns}

    def __setstate__(self, state):
        object.__setattr__(self, '_data', state['_data'])
        object.__setattr__(self, '_all_columns', state['_all_columns'])

    def contains_column(self, col):
        return col in set(self._all_columns)

    def as_immutable(self):
        """Return an :class:`.ImmutableColumnCollection` that shares this
        collection's ``_data`` and ``_all_columns`` (no copy is made)."""
        return ImmutableColumnCollection(self._data, self._all_columns)
586
587
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
    """Read-only view over a column mapping and its ordered column list.

    The ``data`` mapping and ``all_columns`` list passed in are stored as
    given, not copied.

    NOTE(review): only ``extend`` and ``remove`` are explicitly disabled
    here.  ``ColumnCollection.update``/``replace`` write to ``_data``
    directly, so they remain usable unless ``util.ImmutableProperties``
    overrides them -- confirm whether that is intended.
    """

    def __init__(self, data, all_columns):
        util.ImmutableProperties.__init__(self, data)
        # bypass the immutable __setattr__ to install the shared list
        object.__setattr__(self, '_all_columns', all_columns)

    remove = util.ImmutableProperties._immutable
    extend = remove
594
595
class ColumnSet(util.ordered_column_set):
    """An ordered set of columns whose ``==`` builds a SQL expression."""

    def contains_column(self, col):
        """Return True if ``col`` is a member of this set."""
        return col in self

    def extend(self, cols):
        """Add each column from ``cols`` in order."""
        for column in cols:
            self.add(column)

    def __add__(self, other):
        # concatenation yields a plain list, not a ColumnSet
        combined = list(self)
        combined.extend(other)
        return combined

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        """AND together ``c == local`` for every lineage-sharing pair."""
        clauses = [c == local
                   for c in other
                   for local in self
                   if c.shares_lineage(local)]
        return elements.and_(*clauses)

    def __hash__(self):
        return hash(tuple(self))
618
619
def _bind_or_error(schemaitem, msg=None):
    """Return ``schemaitem.bind``, raising if it is falsy.

    :param schemaitem: object exposing a ``bind`` attribute; its optional
     ``fullname`` (preferred) or ``name`` attribute labels the error.
    :param msg: custom error text; a default message is built when None.
    :raises UnboundExecutionError: if ``schemaitem.bind`` is falsy.
    """
    bind = schemaitem.bind
    if bind:
        return bind

    cls_name = schemaitem.__class__.__name__
    label = getattr(schemaitem, 'fullname',
                    getattr(schemaitem, 'name', None))
    item = ('%s object %r' % (cls_name, label)) if label \
        else ('%s object' % cls_name)
    if msg is None:
        msg = ("%s is not bound to an Engine or Connection.  "
               "Execution can not proceed without a database to execute "
               "against." % item)
    raise exc.UnboundExecutionError(msg)
636